Design and Implementation
Chromatix’s architecture draws significant inspiration from modern deep learning frameworks. In this section, we outline the core principles guiding our design and implementation. We believe that effective computational optics frameworks—much like their deep learning counterparts—must prioritize three essential features: differentiability, composability, and scalability. We also describe how Chromatix’s high-level implementation reflects these principles.
Differentiability
Differentiability refers to the ability to compute gradients, which are essential for gradient-based optimization—such as tuning parameters in an optical simulation. For simple, low-dimensional inputs, numerical differentiation may suffice. However, for high-dimensional data like spatial light modulator (SLM) pixel arrays, this approach becomes prohibitively slow. Instead, backpropagation—efficiently propagating gradients through each simulation step—is preferred. When paired with diverse optical models or neural networks, automatic differentiation becomes highly valuable.
Languages like MATLAB and C lack built-in automatic differentiation, forcing developers to manually derive and code gradients—a tedious, error-prone, and inflexible process. Modern deep learning frameworks24,25,26, by contrast, automatically compute gradients for any differentiable function with respect to its parameters. Chromatix leverages JAX to offer the same capability: once a simulation is defined, its gradients are automatically available.
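To illustrate what "automatically available" means in practice, the following is a minimal sketch in plain JAX rather than Chromatix code. The toy simulation (a phase mask followed by a far-field Fourier transform), the grid size `N`, the single-pixel target, and all function names are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

N = 32  # pixels per side of the hypothetical phase mask


def far_field_intensity(phase):
    """Toy simulation: unit plane wave -> phase mask -> far-field intensity."""
    field = jnp.exp(1j * phase)                     # phase-only modulation
    far_field = jnp.fft.fft2(field, norm="ortho")   # Fraunhofer pattern, up to scaling
    return jnp.abs(far_field) ** 2


def loss(phase, target):
    """Mean squared error between simulated and target intensity."""
    return jnp.mean((far_field_intensity(phase) - target) ** 2)


# Target: all of the energy focused into a single far-field pixel.
target = jnp.zeros((N, N)).at[4, 4].set(float(N * N))
phase = jax.random.uniform(jax.random.PRNGKey(0), (N, N), maxval=2 * jnp.pi)

# Reverse-mode differentiation returns the gradient for every pixel at once.
grad_fn = jax.jit(jax.grad(loss))
g = grad_fn(phase, target)
print(g.shape)  # (32, 32)
```

Only the forward model is written by hand; the gradient with respect to all 1,024 mask pixels comes from a single `jax.grad` call, whereas numerical differentiation would need at least one extra simulation per pixel.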
Differentiability has already enabled end-to-end design of computational optics systems across numerous applications3,9,18,19,20,21,22,23. It also enhances solutions to inverse optical problems. Traditional methods often oversimplify both the sample and optical setup33,34,35,36, whereas differentiable models can incorporate realistic complexity. Automatic differentiation unlocks advanced optimizers like Adam37, enabling high-fidelity reconstructions even when forward models include complex physics such as scattering16,38 or sample deformation39. Crucially, this comes at almost no extra coding effort: defining the forward model automatically defines its gradients. This also enables self-calibrating algorithms40, where physical parameters (e.g., illumination angle in tomography) are jointly optimized with the sample to improve accuracy41,42.
Another emerging approach replaces discrete voxel-based representations with continuous neural network-based models—known as implicit neural representations (INRs) or neural radiance fields39,43,44,45. INRs have been used to separate motion artifacts from sample dynamics39, estimate dynamic aberrations44, reconstruct 3D quantitative phase in scattering samples46, and correct aberrations without wavefront sensors or calibration43. All of these rely on differentiable simulations for training.
Composability
Differentiability naturally supports composability: the gradient of a composite function can be derived from the gradients of its parts. More broadly, composability means being able to easily swap or replace components—such as changing an activation function in a neural network—without rewriting the entire system. This is possible thanks to standardization in machine learning, allowing researchers to rapidly integrate others’ work by simply plugging in new functions.
Optics, however, lags behind. Implementations are often custom-built for specific projects, with inconsistent conventions and no shared benchmarks for accuracy or speed. This leads to wasted effort, increased errors, and difficulty reproducing results. Chromatix addresses this by proposing a standardized framework for wave-optics simulations, enabling seamless integration of diverse optical models (Fig. 1). All experiments in this paper use a shared, rigorously tested codebase, and many additional components are documented. We believe that providing both a standard library and reference implementations can significantly accelerate research and improve reproducibility.
a, Chromatix integrates wave-optics modeling, GPU acceleration, and differentiability into a single library, offering a unified framework for diverse applications. b, It supports a broad range of optical elements—including lenses, sensors, scalar and vectorial free-space propagation models70,71,72, and complex scattering samples16,38. c, These components can be flexibly combined to simulate various experimental setups and tackle a wide array of computational optics challenges. Green-highlighted elements indicate which component or sample is optimized in each use case. DMD: digital micromirror device; SLM: spatial light modulator; f: focal length; zn: propagation distance.
Scalability
Optical systems are increasingly demanding larger fields of view (FOVs) and higher resolutions, often requiring large compute clusters for sample reconstruction. Thus, scalability is a critical requirement for modern optics simulation tools. Researchers need the flexibility to prototype quickly on a laptop and then seamlessly scale up to GPU clusters for large-scale reconstructions.
Previous tools have made this difficult: NumPy runs only on CPUs47; MATLAB requires specialized code for GPU support and lacks general-purpose automatic differentiation; PyTorch24 and TensorFlow25 simplify GPU programming with automatic differentiation but struggle with multi-GPU support and perform poorly on optical simulation tasks that differ significantly from typical neural network operations. Writing device-specific code demands substantial effort and locks in limitations, an approach that falls short of the rapid iteration cycles required by modern scientific research. Chromatix addresses this by leveraging JAX26 and its built-in XLA (Accelerated Linear Algebra) compiler, enabling high-speed optical simulations across CPUs, GPUs, and tensor processing units (TPUs), all from a single codebase. There is no need to write specialized low-level GPU code to achieve performance gains, as is often necessary with tools like PyTorch24. Additionally, JAX provides built-in utilities for automatically vectorizing computations (such as processing an entire batch on one GPU) or distributing work across multiple GPUs, regardless of how the optical components in a system are described26. For instance, with just a few minor code adjustments, a basic 2D, single-wavelength simulation can be expanded into a full 3D, multi-wavelength simulation that runs efficiently across several GPUs simultaneously (see Extended Data Fig. 1 for concrete examples).
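As a rough illustration of the vectorization described above, the sketch below writes an angular spectrum propagator for a single wavelength and then extends it to several wavelengths with `jax.vmap`. This is plain JAX, not the Chromatix API; the function `propagate`, the Gaussian test beam, and every numerical parameter are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def propagate(field, wavelength, dx, z):
    """Angular spectrum propagation of one square 2D field at one wavelength.

    All lengths share one unit (micrometres here).
    """
    n = field.shape[-1]
    fx = jnp.fft.fftfreq(n) / dx                       # spatial frequencies
    fxx, fyy = jnp.meshgrid(fx, fx, indexing="ij")
    arg = 1.0 / wavelength**2 - fxx**2 - fyy**2
    kz = 2 * jnp.pi * jnp.sqrt(jnp.maximum(arg, 0.0))
    # Keep propagating waves, discard evanescent ones.
    transfer = jnp.where(arg > 0, jnp.exp(1j * kz * z), 0.0)
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)


# A Gaussian beam on a 64 x 64 grid (hypothetical parameters).
n, dx, z = 64, 0.25, 10.0
x = (jnp.arange(n) - n // 2) * dx
xx, yy = jnp.meshgrid(x, x, indexing="ij")
field = jnp.exp(-(xx**2 + yy**2) / (2 * 2.0**2)).astype(jnp.complex64)

# vmap adds a wavelength axis without touching `propagate`.
propagate_multi = jax.vmap(propagate, in_axes=(None, 0, None, None))
wavelengths = jnp.array([0.45, 0.55, 0.65])
out = propagate_multi(field, wavelengths, dx, z)
print(out.shape)  # (3, 64, 64)
```

The single-wavelength function is left unchanged; the batch axis is added from the outside. JAX's parallel transforms follow the same pattern for spreading such a batch across devices, although that step is not shown here.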
Implementation
In deep learning, models are typically built as chains of operations called layers. Optics follows a similar pattern: optical systems are made up of a series of components and light propagation steps. However, there’s a crucial distinction—the “state” of an optical system isn’t abstract; it corresponds directly to the physical complex light field traveling through the setup. To fully capture this field—and thus the complete state of the system at any moment—you also need details like wavelength, polarization, and spatial sampling resolution. The foundational concept in Chromatix is that all these pieces of information can be unified within a single, core data structure. Each optical element then becomes a transformation applied to this structured field, and any full optical system is simply a sequence of such transformations. This design allows Chromatix to handle a broad range of optical setups through one consistent interface, making it easy to extend and adapt.
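The idea of a structured field passed through a sequence of transformations can be sketched as follows. This is a hypothetical stand-in written in plain JAX, not the Chromatix data structure or API: the `Field` tuple, the element functions `phase_mask` and `ff_lens`, the helper `run`, and all parameter values are assumptions for illustration only.

```python
from functools import partial
from typing import NamedTuple

import jax.numpy as jnp


class Field(NamedTuple):
    """Hypothetical structured field: complex amplitude plus its metadata."""
    u: jnp.ndarray       # complex amplitude, shape (n, n)
    dx: float            # sampling interval in micrometres
    wavelength: float    # wavelength in micrometres


def phase_mask(field, phase):
    """Element: multiply the field by a phase-only mask."""
    return field._replace(u=field.u * jnp.exp(1j * phase))


def ff_lens(field, f):
    """Element: lens with input and output at its focal planes.

    Acts as a Fourier transform (constant factors omitted) and changes the
    sampling interval, which is why the metadata travels with the field.
    """
    n = field.u.shape[-1]
    u = jnp.fft.fftshift(jnp.fft.fft2(jnp.fft.ifftshift(field.u), norm="ortho"))
    new_dx = field.wavelength * f / (n * field.dx)
    return field._replace(u=u, dx=new_dx)


def run(field, elements):
    """An optical system is a sequence of transformations of the field."""
    for element in elements:
        field = element(field)
    return field


n = 64
start = Field(u=jnp.ones((n, n), dtype=jnp.complex64), dx=10.0, wavelength=0.5)
elements = [
    partial(phase_mask, phase=jnp.zeros((n, n))),
    partial(ff_lens, f=100_000.0),   # 100 mm focal length, in micrometres
]
out = run(start, elements)
print(out.dx)  # 78.125
```

Because each element takes a field and returns a field, elements can be reordered or swapped without changing the rest of the system. A `NamedTuple` is also a JAX pytree, so a structure like this can pass through `jax.jit` and `jax.grad`.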
Experiments
We conducted six computational experiments to highlight four key strengths of Chromatix: its ability to solve inverse problems for sample reconstruction, accelerate both reconstruction and optical design via deep learning, flexibly combine modular optical elements and models, and boost simulation speed by up to tenfold. These demonstrations include reimplementations of established wave-based computational optics methods, as well as novel in silico solutions to inverse problems involving previously underexplored combinations of optical effects.
Inverse problems for reconstructing samples
Microscope aberrations are often treated as uniform across the field of view (FOV), mainly because modeling and measuring position-dependent point spread functions (PSFs) is computationally and experimentally demanding. However, one research group48,49 pointed out that many imaging systems exhibit rotational symmetry due to the symmetric nature of common optical components. As a result, most aberrations change only with radial distance from the center of the FOV, up to a rotation. Characterizing aberration variation along this radial profile drastically reduces calibration effort and computational cost—scaling linearly rather than quadratically with camera sensor height. This efficiency makes it feasible to perform deconvolution that accounts for spatially varying aberrations. The team introduced “ring deconvolution microscopy”48, a method that exploits rotational invariance in standard microscopes using incoherent illumination (e.g., fluorescence or brightfield with incoherent light) to model spatially varying PSFs efficiently. We implemented this technique in Chromatix (Fig. 2b) using data from a UCLA Miniscope50,51—a compact widefield microscope. The system is modeled as a 4f optical setup with rotationally symmetric Seidel aberrations placed in the Fourier plane. After calibrating Seidel coefficients from reference images48, we deconvolved the captured sample image using this rotationally invariant yet spatially varying PSF model.

a, Sample type and representation used in ring deconvolution microscopy. b, Our Chromatix implementation of ring deconvolution microscopy, following the approach in ref. 48. c–f, Zoomed-in views from the center (red border) and edge (blue border) of the FOV are shown in the bottom row. All displayed images have been adjusted to minimize vignetting for clarity; uncorrected versions appear in Extended Data Fig. 2. Panel c shows a raw image of incoherently illuminated rabbit liver tissue taken with a gradient-index lens Miniscope, without correcting for spatially varying aberrations. d, Result of standard spatially invariant deconvolution applied to the same FOV. e,f, Results from Chromatix’s rotationally invariant deconvolution (e) compared to the original PyTorch version, which lacks parallelization and cannot process the full FOV on a single H100 GPU (80 GB memory) (f). g, Sample type and representation for computational aberration correction using implicit neural networks. h, Implementation of the CoCoA self-supervised framework using coordinate-based implicit neural networks to simultaneously infer aberrations and reconstruct a 3D sample from a single 3D measurement. i–k, Visual comparison of a maximum-intensity projection from a 4-µm-thick slice: raw aberrated data (i), reconstruction from the original PyTorch implementation (j), and Chromatix reconstruction of the same dendrite structure (k). l–n, Validation of recovered aberrations: true wavefront measured via direct sensing (l), aberration inferred by the original PyTorch code (m), and aberration inferred by Chromatix (n). o, Sample type and representation for 3D refractive index microscopy. p, Schematic of refractive index microscopy, which recovers the 3D refractive index distribution of a strongly scattering sample from intensity measurements. q–s, Raw multi-angle coherent illumination data (q) are used to reconstruct the full 3D refractive index map of a D. rerio embryo tail at 24 hours post-fertilization (hpf), shown via maximum projections from the original MATLAB reconstruction (r) and the Chromatix reconstruction (s). t,u, Cross-sectional slices through the volume (colored borders indicate slice locations) reveal fewer grid artifacts in the Chromatix output (u(i), u(ii)) compared to the original MATLAB result (t(i), t(ii)). In all panels, green highlights indicate optimized parameters, and dashed gray arrows show gradient flow during iterative optimization.
We show the measured image of incoherently illuminated rabbit liver tissue alongside three reconstructions: standard deconvolution, which improves sharpness only at the center of the field of view (FOV); ring deconvolution in Chromatix, which restores sharpness across the entire FOV; and the original ring deconvolution implementation48,49 (Fig. 2c–f and Extended Data Fig. 2). Unlike the original implementation, ours supports a significantly larger FOV by distributing computations across multiple GPUs. The original implementation cannot reconstruct the full camera FOV without either slowing down dramatically or exceeding the memory limits of a single GPU (e.g., 48 GB on an RTX 8000 or 80 GB on an H100). Chromatix is also faster: it runs 4.5× faster than the original PyTorch code on one GPU and nearly 19× faster when using 8 GPUs (see below).
Beyond voxel grids, Chromatix supports implicit neural representations (INRs), which can simplify the optimization process in certain cases44. For example, CoCoA (Coordinate-based Neural Representations for Computational Adaptive Optics)43 simultaneously reconstructs both the sample (as an INR) and optical aberrations (as Zernike coefficients) without requiring labeled training data. Chromatix’s implementation of CoCoA is illustrated in Fig. 2g,h. In this scenario, aberrations are assumed to be uniform across space, and the sample emits incoherent fluorescence rather than being lit by incoherent transmitted light. We show a maximum-intensity projection of a 4-µm-thick slice from a widefield image of mouse neurons, followed by reconstructions from the original CoCoA method and from Chromatix (Fig. 2i–k). The Chromatix result preserves smoother, more continuous dendrites, whereas the original produces fragmented, spotty structures. Additionally, Chromatix completes the reconstruction twice as fast on a single GPU and nearly 9× faster with 8 GPUs. On a fluorescent bead dataset with known, deliberately introduced aberrations43, Chromatix recovers the Zernike mode coefficients with a root-mean-square error of 3.56 nm—much better than the original method’s 6.97 nm (r.m.s. calculated over the three nonzero Zernike modes used; see Extended Data Fig. 3).
While increasing the number of INR layers might reduce detail loss in the original method43, here we compare both approaches using identical network sizes. Chromatix uses a fully paraxial model for the field at the pupil plane, whereas the original implementation combines an exact (but high-sample-rate) pupil model with a paraxial approximation of the second lens in the 4f setup. This hybrid approach risks aliasing if undersampled. Chromatix’s consistent paraxial approximation leads to superior dendrite reconstruction (Fig. 2i–k), demonstrating not only higher speed but also the advantage of a unified, well-tested modeling framework like Chromatix when avoiding mismatches between simulation and reality.
Researchers in ref. 16 demonstrated that computational imaging can quantitatively map the 3D refractive index of highly scattering samples—beyond what traditional widefield microscopy allows—using intensity measurements alone. Their sample was a D. rerio embryo tail at 24 hours post-fertilization (hpf), illuminated coherently at multiple angles17. The refractive index (Fig. 2r) was inferred by matching real measurements (Fig. 2q) to a differentiable simulation (Fig. 2p) of light propagation through the sample. Because such samples scatter light many times, they require a multislice forward model52. Although we use the exact same physical model as the original MATLAB code17, our Chromatix implementation is 3–13× faster on comparable volumes—cutting reconstruction from hours to minutes. It’s also far more concise (~25 lines vs. ~107 in the original16,17) and more adaptable thanks to automatic differentiation. This speed boost lets us use better reconstruction settings, eliminating the large grid artifacts seen in the original result (Fig. 2r) in favor of a cleaner output (Fig. 2s).
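The structure of a multislice forward model and its use in a gradient-based reconstruction can be sketched as follows. This is a simplified, paraxial, single-illumination version in plain JAX, not the model of refs. 16,17 and not Chromatix code; the function names, the Gaussian test sample, and all numerical values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def multislice(delta_n, field, wavelength, dx, dz):
    """Propagate `field` through thin slices of index contrast `delta_n`.

    delta_n has shape (nz, n, n); the detected intensity is returned.
    """
    n = field.shape[-1]
    fx = jnp.fft.fftfreq(n) / dx
    fxx, fyy = jnp.meshgrid(fx, fx, indexing="ij")
    # Paraxial (Fresnel) transfer function for one slice thickness dz.
    kernel = jnp.exp(-1j * jnp.pi * wavelength * dz * (fxx**2 + fyy**2))
    k0 = 2 * jnp.pi / wavelength

    def step(u, dn_slice):
        u = u * jnp.exp(1j * k0 * dn_slice * dz)       # phase delay of this slice
        u = jnp.fft.ifft2(jnp.fft.fft2(u) * kernel)    # diffraction to the next slice
        return u, None

    u_out, _ = jax.lax.scan(step, field.astype(jnp.complex64), delta_n)
    return jnp.abs(u_out) ** 2


def loss(delta_n, measured, field, wavelength, dx, dz):
    """Mismatch between simulated and measured intensity."""
    simulated = multislice(delta_n, field, wavelength, dx, dz)
    return jnp.mean((simulated - measured) ** 2)


# Hypothetical sample: a Gaussian index bump repeated over four slices.
nz, n = 4, 32
wavelength, dx, dz = 0.5, 0.5, 2.0
x = (jnp.arange(n) - n // 2) * dx
xx, yy = jnp.meshgrid(x, x, indexing="ij")
dn_true = jnp.broadcast_to(0.02 * jnp.exp(-(xx**2 + yy**2) / 2.0), (nz, n, n))
plane_wave = jnp.ones((n, n), dtype=jnp.complex64)
measured = multislice(dn_true, plane_wave, wavelength, dx, dz)

# Gradient of the mismatch with respect to every voxel of the current guess.
dn_guess = jnp.zeros((nz, n, n))
g = jax.grad(loss)(dn_guess, measured, plane_wave, wavelength, dx, dz)
print(g.shape)  # (4, 32, 32)
```

An iterative reconstruction would repeatedly update `dn_guess` using `g`. In a full model, measurements from several illumination angles would be summed in the loss.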
Programmable optics and deep learning
Spatial light modulators (SLMs) now offer precise control over light via millions of individually addressable pixels. While commonly used for holography, this capability also allows engineers to design custom point-spread functions (PSFs) tailored to specific tasks53, such as capturing 3D fluorescent volumes in a single snapshot. Optimizing such complex, high-dimensional optical systems demands gradient-based methods and benefits greatly from GPU acceleration. One team3 developed a deep learning approach that integrates a reconfigurable microscope (the Holoscope) with a neural network to reconstruct 3D fluorescence volumes from single 2D snapshots (Fig. 3a). In this system, both the neural network weights and the SLM’s phase mask pixels are co-optimized using a differentiable model of the microscope. The PSF effectively compresses 3D information into a 2D image, which is then decoded by a FourierNet architecture that leverages structural priors of the sample. This enables rapid, accurate 3D reconstruction that combines hardware and algorithmic intelligence.

a, Holoscope3 implementation in Chromatix showing a programmable 3D snapshot microscope, compressing volumetric information into a 2D representation with subsequent FourierNet reconstruction. b–e, Holoscope demonstration using Chromatix showing the sample-specific PSF (b), simulated 2D image of the simulated 3D sample (c) (approximately 0.01 s to capture), ground truth 3D volume (d) (approximately 7.0 s to capture via confocal microscopy assuming a scan speed of 100 ns per voxel) and Chromatix-enabled 3D reconstruction (e) from the single simulated 2D image. f, DeepCGH architecture using a UNet and a propagation step to directly generate a hologram in a single feedforward step from target 3D patterns. g–l, Demonstration of DeepCGH12 using Chromatix showing requested stimulation patterns at three planes spaced 10 mm apart around the focal plane (g,i,k) and their resulting simulated intensity distributions (h,j,l). Colored insets show detail of the 3D patterns at each plane. Intensity values in g–l are normalized. For all panels, green highlights denote optimized parameters of either neural networks or optical systems.
This microscope design can therefore be programmed to function as a snapshot microscope optimized for various sample types, while using exactly the same hardware. The microscope is modeled using a 4f system with an SLM (phase mask) in the Fourier plane and is optimized for whole-brain imaging of fluorescently labeled D. rerio larvae. The PSF of the 4f system is simulated with coherent propagation, and the image is simulated as the incoherent sum of these PSFs, which is efficiently implemented as a convolution of the PSF and the sample intensity. We show the learned PSF (Fig. 3b), the simulated 2D measurement of a virtual zebrafish volume (Fig. 3c) and the ground truth volume and simulated reconstruction3 (Fig. 3d,e). Chromatix reproduces the original results3 nearly exactly: on a test set of 10 volumes and their simulated images, reconstruction networks trained with identical PSFs achieve a structural similarity index measure of 0.979 ± 0.003 (mean ± standard error; higher is better) for both Chromatix and the original implementation3 (not significantly different at P = 0.695 via two-sided t-test, Extended Data Fig. 4). Chromatix also outperforms the original implementation3 in training speed by a factor of approximately 7× (Fig. 5). Practically, this reduces the optimization time for a single PSF from weeks to days.
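The two-step image formation described above (a coherent PSF simulation followed by an incoherent sum implemented as a convolution) can be sketched in 2D as follows. This is plain JAX for illustration, not the Holoscope or Chromatix code; the function names, the circular aperture, and the single point emitter are assumptions, and the convolution is periodic because it is computed with FFTs.

```python
import jax.numpy as jnp


def psf_from_phase(phase, aperture):
    """Incoherent PSF of a 4f system with a phase mask in the Fourier plane."""
    pupil = aperture * jnp.exp(1j * phase)
    # Coherent step: the second lens Fourier-transforms the pupil field.
    amplitude = jnp.fft.fftshift(jnp.fft.fft2(jnp.fft.ifftshift(pupil)))
    psf = jnp.abs(amplitude) ** 2          # intensity response to a point source
    return psf / jnp.sum(psf)              # normalise to unit energy


def incoherent_image(sample, psf):
    """Sum of shifted, intensity-weighted PSFs, computed as an FFT convolution."""
    otf = jnp.fft.fft2(jnp.fft.ifftshift(psf))   # move the PSF centre to index (0, 0)
    return jnp.real(jnp.fft.ifft2(jnp.fft.fft2(sample) * otf))


n = 64
r = jnp.arange(n) - n // 2
rr = jnp.sqrt(r[:, None] ** 2 + r[None, :] ** 2)
aperture = (rr <= n // 4).astype(jnp.float32)    # circular pupil (hypothetical size)
psf = psf_from_phase(jnp.zeros((n, n)), aperture)

sample = jnp.zeros((n, n)).at[10, 20].set(1.0)   # one point emitter
image = incoherent_image(sample, psf)
```

With a single emitter the image is simply the PSF shifted to the emitter position. For a volumetric sample, one such convolution per depth plane would be summed on the sensor, and the whole chain remains differentiable with respect to `phase`.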
SLMs also enable computer-generated holography systems for optogenetics, where 3D holographic stimulation patterns are used to perturb neural activity in the brain. Most holography systems rely on some form of iterative optimization (for example, refs. 33,34,54,55) to find the phase to display on the SLM. While this produces accurate solutions, iteration becomes problematic when speed is paramount. For optogenetics, point cloud holography can be used to stimulate multiple neurons without iterative optimization of phase patterns, but this only allows for placing copies of a single pattern at the desired locations15. Due to the interest in holography for displays, fast holography algorithms for arbitrary patterns have emerged that use neural networks to quickly generate a hologram given a target pattern42,56. Applied to optogenetics, DeepCGH12 also demonstrates fast computer generation of holograms by training a neural network to generate phase patterns from intensity images of arbitrary 3D patterns in a single feedforward inference step. We implemented DeepCGH12 in Chromatix (Fig. 3f). We show the desired target patterns and the resulting simulated patterns at three different depth planes using the phase pattern produced by the DeepCGH method in Chromatix (Fig. 3g–l). Our results closely match those of the original TensorFlow implementation: on a test set of 16 target patterns, Chromatix achieves a structural similarity index measure of 0.985 ± 0.001 (mean ± standard error; higher is better) versus 0.982 ± 0.001 for the original implementation (significantly different at P = 0.018 < 0.05 via two-sided t-test, Extended Data Fig. 5) and a peak signal-to-noise ratio of 35.40 ± 0.37 (mean ± standard error; higher is better) for Chromatix versus 34.95 ± 0.16 for the original implementation12 (not significantly different at P = 0.177 via two-sided t-test). Our implementation is approximately 17 lines of code for a differentiable hologram simulation versus 33 lines in the original work12. While achieving comparable quality, Chromatix provides a 2.5× performance improvement on a single GPU, which increases to over 10× when using 8 GPUs in parallel (Fig. 5).
Flexible modeling with optical building blocks
Because Chromatix models are constructed from components that can be flexibly combined (Fig. 1), we can straightforwardly construct complex optical models and also optimize them with arbitrary objective functions. We show another programmable microscope modeled as a 4f system with an SLM in the Fourier plane, followed by a neural network-based reconstruction step (Fig. 4a–f). The objective in this demonstration is to optimize the PSF of this programmable microscope to perform spectroscopic single-molecule localization57 from a single snapshot image: that is, to reconstruct multicolor point sources using only a single-channel image. The simulated samples consist of several point sources incoherently emitting fluorescence at 25 wavelengths from 400 nm to 650 nm that are simulated in parallel using Chromatix. We train the neural network to reconstruct the multicolor sample at the corresponding 10-nm intervals, giving us a hyperspectral cube from a single-channel 2D measurement. The optimized PSF allows visual classification of the color of these point sources on a monochrome simulated camera image (Fig. 4c,d) by taking advantage of different fringe patterns for different wavelengths. The reconstruction (Fig. 4f) reasonably matches the true colors of the points in the sample (Fig. 4e). We highlight that here we are optimizing the same programmable microscope model that was used for snapshot microscopy (Fig. 3a), but for an entirely new combination of sample type and objective.

a, Demonstration of PSF engineering for spectroscopic single-molecule localization microscopy using deep learning, where a neural network reconstructs both the structure and spectrum of a sparse 2D point sample from a single-channel image. b, Microscope model for multicolor PSF optimization with a single SLM in the Fourier plane. c, Optimized multicolor PSF for spectral imaging using a neural network, with different wavelengths and colors overlaid. d, Simulated single-channel image of multicolor fluorescent point sources. e, True simulated multicolor point sources, with different wavelengths and colors overlaid. f, Reconstructed multicolor point sources, with different wavelengths and colors overlaid. g, Iterative optimization workflow generating optimal phase masks for scattering-compensated holography. i, Chromatix model of holographic pattern formation through scattering media, which combines the holography model of Fig. 3f and the scattering sample model of Fig. 2p. h, Peak axial intensity distribution (normalized within the range 0–1) along the direction of propagation for the target pattern, an uncorrected pattern and the corrected pattern that are visualized below. j–m, Visual evolution of holographic pattern quality. Target pattern (j), the pattern produced by an optimized hologram transmitted through free space (k), the pattern of the same hologram after passing through a scattering sample with a known 3D refractive index profile (l), and the pattern generated by the Chromatix-corrected hologram after traversing the same scattering medium (m). In every panel, green highlights indicate optimized parameters, while dashed gray arrows show how gradients are propagated during iterative optimization.
Combining various wave-optics models in arbitrary configurations can expand the utility of differentiable simulations in biological research. In optogenetics, researchers use intricate 3D light patterns—often generated via holography—to precisely control neuronal activity. However, achieving accurate 3D holographic control is already difficult, and the challenge intensifies in biological tissue due to light scattering. As light travels through tissue, it scatters, potentially stimulating unintended neurons and raising the risk of phototoxicity. Here, we demonstrate how Chromatix can optimize holographic light patterns in highly scattering tissue by capturing the scattered output (Fig. 4g–m). Our simulation models a plane wave striking a phase mask (SLM), being focused by a thin lens, and then propagating through a scattering volume using the multislice beam propagation method—the same approach used for the scattering sample in Fig. 2p. This allows us to observe intensity distribution throughout the entire volume. Without correction, scattering leads to non-uniform stimulation (Fig. 4h). By incorporating the measured scattered intensity as feedback into the optimization loop, we achieve nearly uniform stimulation (blue line in Fig. 4h) across the full axial range. This in silico experiment illustrates how Chromatix empowers researchers to quickly prototype and refine their concepts, turning theoretical insights into practical outcomes.
High performance through parallelization
To showcase Chromatix’s computational efficiency and scalability, we measured iteration speeds across all training and optimization tasks discussed. Chromatix is faster than every original implementation we compared against, delivering 2–6× speedups on a single GPU and up to 22× on 8 GPUs in the best case (Fig. 5). Gains on a single GPU stem largely from reduced overhead thanks to JAX compilation, compared to implementations in MATLAB, PyTorch, or TensorFlow. Parallelization yields larger, order-of-magnitude accelerations with minimal code changes, thanks to Chromatix’s native integration with JAX. This scalability allows Chromatix to tackle large-scale problems and makes existing inverse problems far more tractable, as seen in the expanded field of view in ring deconvolution microscopy (Fig. 2a–f), the drastically reduced optimization time in refractive index microscopy (Fig. 2o–u), and faster snapshot PSF optimization using deep learning (Fig. 3a–e).

The vertical axis displays the relative speedup of Chromatix on 1–8 GPUs, with original implementations serving as the baseline (gray dashed line at 1×). Each point represents the mean speedup for a given method, with error bars indicating the standard error across individual optimization iterations. Violin plots in lighter shades visualize the full distribution of speedups. All speedups are measured relative to the original single-GPU versions. CGH stands for computer-generated holography.



