The issue that pushed me to find a better solution
My research involves working with cosmological models — dark energy equations of state, modified gravity theories, tachyonic fields — and answering a key question: what do the observational data actually tell us about the model parameters? The standard approach for answering this is Bayesian inference. Typically, I run dynesty nested sampling, performing anywhere from a few thousand to several hundred thousand likelihood evaluations, depending on how complex the model is.
Throughout most of my PhD, I never gave much thought to the ODE solver embedded within the likelihood function, since solve_ivp did the job. It was dependable, so I simply used it and focused on other things.
Then I began working on a tachyonic DBI dark energy model, where the dark energy field is described by a non-standard kinetic term, and the background and perturbation equations form a coupled, somewhat stiff system. Each likelihood evaluation required solving those ODEs, computing the comoving distance, and evaluating the distance modulus at the redshifts of 30 supernovae.
I ran a performance profile. The ODE solve by itself was consuming 0.4 ms per call. In a nested sampling run with 10⁵ evaluations, that adds up to 40 seconds, purely on ODE calls and not counting any additional overhead. And for a 10-parameter model, computing a gradient via central finite differences requires 20 extra forward solves, pushing that 0.4 ms up to 8 ms per gradient evaluation. Over 10⁵ evaluations, that totals 800 seconds, roughly 13 minutes, just for gradient computations in a single nested sampling run.
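The budget above is simple enough to write down as an explicit cost model. This back-of-envelope sketch assumes, for illustration, one central-difference gradient per likelihood evaluation; the per-solve time is the 0.4 ms measured above.

```python
# Back-of-envelope cost model for one nested-sampling run
t_forward = 0.4e-3   # seconds per ODE solve (measured: 0.4 ms)
n_evals = 100_000    # likelihood evaluations in one run
n_params = 10        # model parameters

# Central finite differences: two extra forward solves per parameter
solves_per_grad = 2 * n_params
t_grad = solves_per_grad * t_forward

total_forward = n_evals * t_forward  # seconds spent on forward solves
total_grad = n_evals * t_grad        # seconds spent on FD gradients (assumes one per evaluation)

print(f"forward: {total_forward:.0f} s, FD gradients: {total_grad:.0f} s "
      f"(~{total_grad / 60:.0f} min)")
```

The finite-difference term scales with n_params, which is why it dominates as soon as the model has more than a couple of parameters.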
Clearly, something needed to change.
What I discovered: diffrax
After spending a day searching, I came across diffrax [1], a numerical ODE solver library built entirely in JAX. It is not a neural network surrogate or any kind of approximation. It implements the same embedded Runge–Kutta algorithms I was already using in scipy — Tsit5 in place of RK45, but from the same family of methods — now compiled, differentiable, and vectorizable.
Three key capabilities stem from the “built entirely in JAX” design:
- JIT compilation – The full adaptive-stepping loop compiles into a single XLA program. After the first call, each solve runs as compiled code, with no per-step Python overhead.
- Autodiff – Since every operation inside the solver is a JAX primitive, jax.grad propagates gradients through the entire solve. Reverse mode yields the exact gradient of the numerical solution in a single backward pass, no matter how many parameters are involved.
- vmap – A whole batch of parameter vectors can be solved simultaneously using jax.vmap. This pays off whenever a sampler can evaluate many likelihoods at once, as batched nested sampling can.
Getting it installed takes about 10 seconds:

```bash
pip install jax diffrax
```
The benchmark problem: flat ΛCDM with supernovae data
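To make the comparison tangible, here is the exact problem I was tackling. In a flat ΛCDM universe, the comoving distance obeys:

$$\frac{d\chi}{dz} = \frac{c}{H(z)}, \qquad H(z) = H_0\sqrt{\Omega_m (1+z)^3 + (1-\Omega_m)}, \qquad \chi(0) = 0.$$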
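The distance modulus is then given by μ(z) = 5 log₁₀[(1+z)χ(z) / 10 pc], where (1+z)χ(z) is the luminosity distance of a flat universe. Because χ comes out in Mpc (c in km/s, H₀ in km/s/Mpc), dividing by 10 pc is the factor 10⁵ that appears in the code. The goal is to infer (Ωₘ, H₀) from 30 mock SNIa distance-modulus measurements.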
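The two formulas can be checked without any ODE solver by direct quadrature. This sketch uses scipy.integrate.quad and the Einstein–de Sitter case Ωₘ = 1, where the integral has the closed form χ = (2c/H₀)(1 − 1/√(1+z)); the helper names are hypothetical and the check is mine, not part of the benchmark.

```python
import math
from scipy.integrate import quad

C_KMS = 299792.458  # speed of light [km/s]

def hubble(z, Om, H0):
    """Flat LCDM expansion rate H(z) in km/s/Mpc."""
    return H0 * math.sqrt(Om * (1 + z) ** 3 + (1 - Om))

def chi_quad(z, Om, H0):
    """Comoving distance in Mpc: integral of c / H(z') from 0 to z."""
    value, _abserr = quad(lambda zp: C_KMS / hubble(zp, Om, H0), 0.0, z, epsrel=1e-12)
    return value

def mu_from_dL(dL_mpc):
    """Distance modulus 5 log10(d_L / 10 pc), using 1 Mpc / 10 pc = 1e5."""
    return 5 * math.log10(dL_mpc * 1e5)

# Einstein-de Sitter check (Om = 1): chi(z) = (2c/H0) (1 - 1/sqrt(1+z))
z, H0 = 1.0, 70.0
chi_eds_exact = 2 * C_KMS / H0 * (1 - 1 / math.sqrt(1 + z))
chi_eds_num = chi_quad(z, 1.0, H0)

# d_L = 1000 Mpc is 1e8 times 10 pc, so mu = 5 * 8 = 40
mu_1000 = mu_from_dL(1000.0)

# Distance modulus at z = 1 in the fiducial model (Om, H0) = (0.3, 70)
mu_fid = mu_from_dL((1 + z) * chi_quad(z, 0.3, 70.0))
print(chi_eds_exact, chi_eds_num, mu_1000, mu_fid)
```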
The traditional approach: SciPy
```python
from scipy.integrate import solve_ivp
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def rhs(z, chi, Om, H0):
    return C_KMS / (H0 * np.sqrt(Om*(1+z)**3 + (1-Om)))

def forward_scipy(Om, H0, z_obs):
    sol = solve_ivp(rhs, t_span=(0, z_obs[-1]),
                    y0=[0.0], t_eval=z_obs,
                    args=(Om, H0), method="RK45",
                    rtol=1e-8, atol=1e-10)
    chi = sol.y[0]
    return 5 * np.log10((1 + z_obs) * chi * 1e5)  # distance modulus
```

The modern approach: Diffrax
```python
import jax, jax.numpy as jnp
import diffrax as dfx

# Non-negotiable: enable 64-bit (more on this below)
jax.config.update("jax_enable_x64", True)

def H_jax(z, Om, H0):
    return H0 * jnp.sqrt(Om*(1+z)**3 + (1-Om))

@jax.jit  # compile once per input shape, then reuse the compiled program
def forward_diffrax(theta, z_obs):
    Om, H0 = theta[0], theta[1]
    sol = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a: C_KMS / H_jax(z, a[0], a[1])),
        dfx.Tsit5(),
        t0=0.0, t1=z_obs[-1],        # integration limits in z (t1 stays a traced value)
        dt0=1e-3,                    # initial step size
        y0=jnp.array(0.0),           # initial condition chi(0) = 0
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    )
    chi = sol.ys
    return 5 * jnp.log10((1 + z_obs) * chi * 1e5)
```

The underlying physics is exactly the same. The solver algorithm is closely related: Tsit5 and RK45 (Dormand–Prince) are both fifth-order embedded Runge–Kutta pairs with a fourth-order error estimate. The only structural changes are @jax.jit and the diffrax API. One detail matters inside jit: t1 must stay a JAX value (z_obs[-1]), because calling float() on a traced array raises an error. Let us see what those two modifications deliver.
Surprise 1: the speedup
solve_ivp: 404 μs per call. diffrax post-JIT:
59 μs per call. That’s 6.8× faster.
When I first saw this figure, I paused for a moment to take it in. Let me be upfront about where the speed improvement actually comes from — it’s not some kind of trick.
With solve_ivp, the RK45 stepping loop is written in Python on top of NumPy. Every step makes several Python-level calls to the right-hand side, allocates fresh temporary arrays, and runs the adaptive logic through the interpreter: “Is the local error estimate too large? If so, reject the step and shrink it; otherwise accept it and possibly grow the step size; repeat.” For a solve that takes 12 steps, that means dozens of Python function calls, many small memory allocations, and 12 error-estimate calculations, all bottlenecked by the interpreter.
With diffrax, the first call to a @jax.jit function traces the entire computation, including the adaptive while-loop, which diffrax expresses as a lax.while_loop; XLA then compiles the whole solve into one optimised program. Every later call with the same input shapes runs that compiled program directly. Python is involved once per call to dispatch it rather than once per step, so there is no per-step interpreter overhead and no per-step allocation.
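To make the lax.while_loop point concrete, here is a deliberately tiny adaptive integrator written directly with jax.lax.while_loop. It is a toy Heun–Euler 2(1) pair on dy/dt = −y with a simple step-size rule, an illustrative assumption rather than Tsit5 or diffrax's actual implementation, but it shows the pattern: the accept/reject logic and step-size control live inside one traced loop that jit compiles once.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(t, y):
    return -y  # toy ODE dy/dt = -y, exact solution y0 * exp(-t)

@jax.jit
def integrate(y0, t1):
    rtol, atol, max_steps = 1e-6, 1e-9, 100_000

    def cond(state):
        t, y, h, n = state
        return (t < t1) & (n < max_steps)

    def body(state):
        t, y, h, n = state
        h = jnp.minimum(h, t1 - t)              # never step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_high = y + 0.5 * h * (k1 + k2)        # Heun, 2nd order
        y_low = y + h * k1                      # Euler, 1st order (for the error estimate)
        err = jnp.abs(y_high - y_low) / (atol + rtol * jnp.abs(y_high))
        accept = err <= 1.0
        # Grow or shrink h by (1/err)^(1/2) with a safety factor, clipped to [0.2, 5]
        factor = jnp.clip(0.9 * err ** -0.5, 0.2, 5.0)
        t = jnp.where(accept, t + h, t)          # rejected steps leave t and y unchanged
        y = jnp.where(accept, y_high, y)
        return t, y, h * factor, n + 1

    init = (jnp.asarray(0.0, dtype=jnp.float64), jnp.asarray(y0, dtype=jnp.float64),
            jnp.asarray(1e-3, dtype=jnp.float64), jnp.asarray(0, dtype=jnp.int32))
    _, y, _, n = jax.lax.while_loop(cond, body, init)
    return y, n

y_end, n_steps = integrate(1.0, 1.0)  # first call traces and compiles the whole loop
print(float(y_end), int(n_steps))
```

Later calls to integrate reuse the compiled loop; only the initial dispatch goes through Python.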

Across 100,000 likelihood evaluations, 404 μs versus 59 μs adds up to 40.4 seconds compared with 5.9 seconds. Because scipy's overhead is paid on every step, I'd expect the absolute gap to grow for models whose solves need more steps.
Surprise 2: gradients for the price of a few forward solves
This was the part that didn’t just change my workflow — it fundamentally shifted how I approach inference. Using scipy, computing a single gradient of the log-likelihood with respect to two parameters (Ωₘ, H₀) requires 4 forward solves (via central finite differences). As you scale up, the cost climbs quickly: 10 parameters demand 20 forward solves, 50 parameters demand 100. The expense grows linearly with the number of parameters.
With diffrax, all I need to write is:
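```python
def loss(theta):
    mu_pred = forward_diffrax(theta, z_obs)
    return 0.5 * jnp.sum(((mu_pred - mu_obs) / sigma_mu)**2)

grad_fn = jax.jit(jax.grad(loss))    # the only addition needed
g = grad_fn(jnp.array([0.3, 70.0]))  # gradient w.r.t. (Om, H0)
```

Behind the scenes, JAX's reverse-mode automatic differentiation propagates derivatives backward through every step of the solve. Diffrax's default adjoint (RecursiveCheckpointAdjoint) differentiates the solver's own operations, using checkpointing to bound memory; the continuous adjoint method of [2] is also available as dfx.BacksolveAdjoint. Either way, I never have to derive or code the adjoint equations myself. The result is the exact gradient of the numerical solution at a cost of a small constant multiple of one forward pass (about 3.4× in my benchmark: 195 μs versus 57 μs), regardless of how many parameters there are.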

How to pick the right solver
Choosing an appropriate solver requires a bit of thought. I defaulted to Tsit5 for nearly everything, and it handled about 95% of my use cases without any issues. Here’s the full decision process:
- Non-stiff ODE (most cosmological problems) → dfx.Tsit5() ← start here
- Very tight tolerances (< 10⁻⁹) → dfx.Dopri8()
- Stiff ODE (many steps, solver feels sluggish) → dfx.Kvaerno5()
- Stiff + non-stiff terms (IMEX) → dfx.KenCarp4()
- SDE → dfx.EulerHeun() or dfx.SPaRK()
A quick test for stiffness: print sol.stats["num_steps"]. If the step count is 10–100× higher than you would expect for the accuracy requested, the problem is probably stiff, and an implicit solver is worth trying.
The payoff: end-to-end cosmological inference
Now let me walk through the full inference comparison. Both pipelines start from the same poor initial guess (Ωₘ, H₀) = (0.10, 60), far from the true values (0.30, 70), and run for 350 gradient steps.
- Scipy pipeline: gradients via central finite differences, plain gradient descent, fixed learning rate.
- Diffrax pipeline: gradients via autodiff, Adam optimizer with a cosine-decay learning-rate schedule.
```python
import optax  # optimizers for JAX

# Rescale parameters so Adam treats them on equal footing:
# Om ~ 0.3 and h = H0/100 ~ 0.7 are both O(1)
def loss_scaled(theta_s):
    theta = jnp.array([theta_s[0], 100.0 * theta_s[1]])
    return loss(theta)

grad_scaled = jax.jit(jax.grad(loss_scaled))

# Adam with a cosine-decay learning-rate schedule
schedule = optax.cosine_decay_schedule(init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(schedule)

theta = jnp.array([0.10, 0.60])  # scaled start: Om = 0.10, H0 = 60
state = opt.init(theta)
for step in range(350):
    g = grad_scaled(theta)
    updates, state = opt.update(g, state)
    theta = optax.apply_updates(theta, updates)
    if (step + 1) % 50 == 0:
        print(f"Step {step+1}: Om={float(theta[0]):.3f} H0={100*float(theta[1]):.2f}")
```
While the diffrax pipeline recovers physically meaningful parameters, the scipy pipeline fails to adjust both parameters at once, a classic example of plain gradient descent struggling with a poorly scaled problem. On the diffrax side, the rescaling to h = H₀/100 and Adam's per-parameter adaptive learning rates handle this, and autodiff keeps every one of those 350 gradients cheap and exact.
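The scaling problem can be reproduced on a toy quadratic with very different curvatures along two directions. In this sketch the curvatures (1 and 10⁻⁴), learning rates, and start point are hypothetical stand-ins for Ωₘ and H₀ in unscaled units, and Adam is written out by hand in NumPy (the same update rule as optax.adam with its default betas) so the per-parameter step size is visible:

```python
import numpy as np

# Hypothetical badly scaled quadratic: curvature 1 along p0, 1e-4 along p1
curv = np.array([1.0, 1e-4])
p_true = np.array([0.30, 70.0])

def grad(p):
    return curv * (p - p_true)

def gradient_descent(p, lr, n_steps):
    for _ in range(n_steps):
        p = p - lr * grad(p)
    return p

def adam(p, lr, n_steps, b1=0.9, b2=0.999, eps=1e-8):
    m = np.zeros_like(p)
    v = np.zeros_like(p)
    for t in range(1, n_steps + 1):
        g = grad(p)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g ** 2
        m_hat = m / (1 - b1 ** t)          # bias-corrected first moment
        v_hat = v / (1 - b2 ** t)          # bias-corrected second moment
        p = p - lr * m_hat / (np.sqrt(v_hat) + eps)  # per-parameter step size
    return p

p0 = np.array([0.10, 60.0])
p_gd = gradient_descent(p0, lr=1.0, n_steps=350)  # any lr above 2 diverges along p0
p_adam = adam(p0, lr=0.1, n_steps=350)
print("GD:  ", p_gd)
print("Adam:", p_adam)
```

Plain gradient descent fixes p0 in one step but, with its learning rate capped by the high-curvature direction, moves p1 only from 60 to about 60.3 in 350 steps; Adam normalises each coordinate by its own gradient history and moves both.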
Three mistakes I made (so you can avoid them)

Pitfall 1: neglecting 64-bit precision. JAX uses 32-bit floats by default, and float32 machine epsilon is about 1.2×10⁻⁷, so a relative tolerance below that (rtol < 10⁻⁷) asks for more precision than the numbers can represent. This can produce unexpected behaviour: on my ODE the solver takes 69 steps in 32-bit mode versus just 12 in 64-bit. Push the tolerances even tighter and it may fail entirely. The remedy is straightforward. Activate 64-bit mode before doing anything else:
```python
jax.config.update("jax_enable_x64", True)  # must be set first
```

Pitfall 2: timing without a warm-up run. The first invocation of a @jax.jit-decorated function carries a one-time compilation overhead, roughly 90–100 ms for my forward model. If you include that in your benchmarks, diffrax will appear slower than scipy for the wrong reason. The solution is to run once for warm-up and discard that initial result:

```python
_ = forward_diffrax(theta, z_obs).block_until_ready()  # trigger compilation
# NOW time it -- this reflects the true speed
```

Also: JAX dispatches work asynchronously. Always use .block_until_ready() inside timing loops, or you'll measure only the time to queue the computation rather than complete it.
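A minimal timing harness along these lines separates the one-off compile cost from the steady-state cost. The jitted function work is a hypothetical stand-in; substitute your own forward model:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    # Hypothetical stand-in for a jitted forward model such as forward_diffrax
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 10_000)

# First call: includes tracing and XLA compilation
t0 = time.perf_counter()
out = work(x).block_until_ready()
t_first = time.perf_counter() - t0

# Steady state: block on every call so we time completed work, not just dispatch
n_calls = 1_000
t0 = time.perf_counter()
for _ in range(n_calls):
    out = work(x).block_until_ready()
t_steady = (time.perf_counter() - t0) / n_calls

print(f"first call: {t_first * 1e3:.2f} ms, steady state: {t_steady * 1e6:.1f} us per call")
```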
Pitfall 3: the argument-order trap. scipy.integrate.odeint expects f(y, t) (state first, time second). Nearly everything else puts time first: solve_ivp uses f(t, y) and diffrax uses f(t, y, args). If you migrate old odeint code to diffrax without swapping the arguments, you end up solving a different ODE, and you typically won't see an error. You'll simply get the wrong result.
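The trap is easy to reproduce with scipy alone, since solve_ivp shares diffrax's time-first convention. This sketch uses a test equation of my choosing, dy/dt = −y + sin t with y(0) = 0, whose exact solution is y = (sin t − cos t + e^(−t))/2:

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp

def f_odeint(y, t):
    # odeint convention: state first, time second; dy/dt = -y + sin(t)
    return -y + np.sin(t)

t1 = 1.0
exact = 0.5 * (np.sin(t1) - np.cos(t1) + np.exp(-t1))  # solution with y(0) = 0

# Correct: odeint calls f(y, t)
y_odeint = odeint(f_odeint, [0.0], [0.0, t1], rtol=1e-10, atol=1e-12)[-1, 0]

# Correct migration: wrap the function so the new solver sees f(t, y)
y_wrapped = solve_ivp(lambda t, y: f_odeint(y, t), (0.0, t1), [0.0],
                      rtol=1e-10, atol=1e-12).y[0, -1]

# Wrong: passing the odeint-style function unchanged silently swaps t and y
y_naive = solve_ivp(f_odeint, (0.0, t1), [0.0], rtol=1e-10, atol=1e-12).y[0, -1]

print(exact, y_odeint, y_wrapped, y_naive)
```

The naive call runs without any warning, because the swapped arguments still broadcast to a valid shape; it simply integrates dy/dt = sin(y) − t instead.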
Should you switch?
The honest take: if you’re solving a single ODE and don’t need gradients, solve_ivp works just fine — there’s no reason to learn a new API. But if you’re doing inference (repeated likelihood evaluations, parameter gradients, or batched solves), the switch is well worth the effort.
| Situation | solve_ivp | odeint | diffrax |
|---|---|---|---|
| One-off solve, no inference | ✓ | ✓ | fine too |
| Nested sampling / MCMC | slow | slow | YES |
| Need gradients | FD only | FD only | exact (autodiff) |
| Batch over parameter grid | for-loop | for-loop | vmap |
| Stiff system | Radau | auto (LSODA) | Kvaerno5 |
| SDE or Neural ODE | no | no | YES |
| GPU/TPU | no | no | YES |
The migration itself is minimal. The forward model requires roughly six changed lines. The gradient appears with just one additional line. The rest of the inference code remains unchanged.
One important clarification: diffrax is not “ML-based” in the sense of employing a neural network. It uses the same classical Runge–Kutta mathematics, implemented in JAX. The “ML acceleration” comes from JIT compilation and autodiff — both infrastructure tools borrowed from the ML world and applied to a classical numerical solver. The only truly ML-based approach would be a neural surrogate that learns θ → μ(z) from training data — a separate and more advanced topic.
The complete working code
Everything above condensed into a single self-contained script (pip install jax diffrax optax):

```python
"""
flat_lcdm_inference.py
Infer (Omega_m, H0) from 30 mock supernovae using diffrax + Adam.
pip install jax diffrax optax
"""
import jax, jax.numpy as jnp, numpy as np
import diffrax as dfx, optax
from scipy.integrate import solve_ivp  # only for generating mock data

jax.config.update("jax_enable_x64", True)

# -- Constants and data -----------------------------------------------
C_KMS = 299792.458
z_obs = jnp.linspace(0.05, 1.5, 30)
SIGMA = 0.10

# Mock data at truth (Om=0.30, H0=70)
def chi_np(Om, H0):
    sol = solve_ivp(lambda z, y: C_KMS/(H0*np.sqrt(Om*(1+z)**3+(1-Om))),
                    (0, 1.5), [0.], t_eval=np.array(z_obs), rtol=1e-10)
    return sol.y[0]

mu_true = 5*np.log10((1+np.array(z_obs))*chi_np(0.3, 70.)*1e5)
mu_obs = jnp.array(mu_true + SIGMA*np.random.default_rng(42).standard_normal(30))

# -- diffrax forward model --------------------------------------------
@jax.jit
def forward(theta):  # theta = (Om, H0) in physical units
    Om, H0 = theta[0], theta[1]
    chi = dfx.diffeqsolve(
        dfx.ODETerm(lambda z, chi, a:
                    C_KMS/(a[1]*jnp.sqrt(a[0]*(1+z)**3+(1-a[0])))),
        dfx.Tsit5(),
        t0=0., t1=1.5, dt0=1e-3, y0=jnp.array(0.),
        args=(Om, H0),
        saveat=dfx.SaveAt(ts=z_obs),
        stepsize_controller=dfx.PIDController(rtol=1e-8, atol=1e-10),
        max_steps=10_000,
    ).ys
    return 5*jnp.log10((1+z_obs)*chi*1e5)

# -- Loss and gradient ------------------------------------------------
def loss(th_s):  # optimise in scaled coords (Om, h=H0/100)
    mu = forward(jnp.array([th_s[0], 100.*th_s[1]]))
    return 0.5*jnp.sum(((mu - mu_obs)/SIGMA)**2)

grad_fn = jax.jit(jax.grad(loss))

# Warm up the JIT compiler
theta_init = jnp.array([0.10, 0.60])
_ = forward(jnp.array([0.3, 70.0])).block_until_ready()
_ = grad_fn(theta_init).block_until_ready()

# -- Adam optimiser with cosine LR schedule ---------------------------
sched = optax.cosine_decay_schedule(init_value=0.05, decay_steps=350, alpha=0.04)
opt = optax.adam(sched)
theta = theta_init
state = opt.init(theta)

print(f"{'Step':>5} {'Om':>7} {'H0':>7} {'Loss':>8}")
for step in range(350):
    g = grad_fn(theta)
    upd, state = opt.update(g, state)
    theta = optax.apply_updates(theta, upd)
    if (step + 1) % 70 == 0 or step == 0:
        L = float(loss(theta))
        print(f"{step+1:5d} {float(theta[0]):7.4f} {100*float(theta[1]):7.3f} {L:8.2f}")

Om_fit, H0_fit = float(theta[0]), 100*float(theta[1])
print(f"\nFinal: Om = {Om_fit:.3f}  H0 = {H0_fit:.2f}")
print(f"Truth: Om = 0.300  H0 = 70.00")
```
Key Results Summary
| Measurement | scipy | diffrax | Speedup |
|---|---|---|---|
| Single forward call | 0.4 ms | 57 μs | ~7× |
| Gradient (2 params) | 1.62 ms | 195 μs | ~8× |
| 10⁵ forward calls | 40 s | 5.9 s | ~7× |
| 10⁵ gradient calls | ~162 s | ~19.6 s | ~8× |
| Final Ωₘ (350 steps) | 0.652 (wrong) | 0.270 | — |
| Final H₀ (350 steps) | 60.10 (stuck) | 70.94 | — |
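The incorrect scipy outcome isn't due to solver failure. It stems from plain gradient descent with finite-difference gradients struggling with the roughly 200× scale difference between Ωₘ and H₀, which the diffrax pipeline sidesteps by optimising in rescaled coordinates (Ωₘ, h = H₀/100) with Adam.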
Closing Reflection
Migrating my forward model to diffrax left the physics and inference approach untouched. What it did change was whether performing that inference was even practical. A nested-sampling run that was trending toward a massive forward-model time budget shrank to under a minute. Gradients that previously demanded 20 additional solves per step now cost a small multiple of a single forward solve.
The learning curve took roughly one afternoon. Most of the troubleshooting involved the 64-bit precision caveat and initial JIT warmup confusion. The benefits have been tangible and instant.
If you’re a physicist relying on scipy for repeated likelihood evaluations and haven’t explored diffrax yet, I hope this gives you a compelling reason to start.
A note on reproducibility: the precise timings you observe will vary on your machine and even between runs on the same system. On my Mac (MacBook Air M3 Base Model), the diffrax forward call ranged from 55 μs to 62 μs across sessions, while scipy ranged from 400 μs to 407 μs. This variation is expected: CPU thermal conditions, OS scheduling, and memory cache states all shift absolute numbers by roughly 10–15%. What remains consistent is the ratio: diffrax is reliably about 7× faster than scipy for a forward call on this problem (roughly 6.5–7.5× across those ranges). Focus on the ratio, not the absolute time, as the key takeaway.
The Python code that produced every figure in this article is available at: github.com/Samit1424/ODE_solver_comparison
Note: Excluding the featured image, which was generated using an AI tool, all illustrations are the author’s original work.
References
[1] P. Kidger, On Neural Differential Equations, DPhil thesis, University of Oxford, 2021. docs.kidger.site/diffrax/
[2] R. T. Q. Chen, Y. Rubanova, J. Bettencourt, D. Duvenaud, Neural Ordinary Differential Equations, NeurIPS 2018.



