Re-Envisioning Numerical Information Field Theory (NIFTy.re): A Library for Gaussian Processes and Variational Inference

Imaging is the process of transforming noisy, incomplete data into a space that humans can interpret. NIFTy is a Bayesian framework for imaging and has already been applied successfully to many fields in astrophysics. Previous design decisions held the performance and the development of methods in NIFTy back. We present a rewrite of NIFTy, coined NIFTy.re, which reworks the modeling principle, extends the inference strategies, and outsources much of the heavy lifting to JAX. The rewrite dramatically accelerates models written in NIFTy, lays the foundation for new types of inference machineries, improves maintainability

that was used in NIFTy to JAX's automatic differentiation engine. This lays the foundation for new types of inference machineries that make use of the higher order derivatives provided by JAX. Through these changes, we envision significant gains in maintainability of NIFTy.re compared to NIFTy and a faster development cycle for new features.
We expect NIFTy.re to be highly useful for many imaging applications and envision many applications within and outside of astrophysics (Arras, Frank, et al., 2019; Arras et al., 2022; Eberle et al., 2022, 2023; Frank et al., 2017; S. Hutschenreuter et al., 2022; Sebastian Hutschenreuter et al., 2023; Leike et al., 2020; Leike & Enßlin, 2019; Mertsch & Phan, 2023; J. Roth et al., 2023; Jakob Roth et al., 2023; Scheel-Platz et al., 2023; Tsouros et al., 2024; Welling et al., 2021; Westerkamp et al., 2023). NIFTy.re has already been successfully used in two galactic tomography publications (Edenhofer et al., 2024; Leike et al., 2022). A very early version of NIFTy.re enabled a 100-billion-dimensional reconstruction using a maximum posterior inference. In a newer publication, NIFTy.re was used to infer a 500-million-dimensional posterior distribution using VI (Knollmüller & Enßlin, 2019). The latter publication extensively used NIFTy.re's GPU support to reduce the runtime by two orders of magnitude compared to the CPU. With NIFTy.re bridging ideas from NIFTy to JAX, we envision many new possibilities for inferring classical machine learning models with NIFTy's inference methods and a plethora of opportunities to use NIFTy components such as the GP models in classical neural network frameworks.

Core Components
NIFTy.re brings tried and tested structured GP models and VI algorithms to JAX. GP models are highly useful for imaging problems, and VI algorithms are essential to probe high-dimensional posteriors, which are often encountered in imaging problems. NIFTy.re infers the parameters of interest from noisy data via a stochastic mapping that goes in the opposite direction, from the parameters of interest to the data.
NIFTy and NIFTy.re build up hierarchical models for the posterior inference. The log-posterior function reads $\ln p(\theta|d) := \ell(d, f(\theta)) + \ln p(\theta) + \mathrm{const}$ with log-likelihood $\ell$, forward model $f$ mapping the parameters of interest $\theta$ to the data space, and log-prior $\ln p(\theta)$. The goal of the inference is to draw samples from the posterior $p(\theta|d)$.
What is considered part of the likelihood versus part of the prior is ill-defined. Without loss of generality, NIFTy and NIFTy.re re-formulate models such that the prior is always standard Gaussian. They implicitly define a mapping from a new latent space with a priori standard Gaussian parameters $\xi$ to the parameters of interest $\theta$. The mapping $\theta(\xi)$ is incorporated into the forward model $f(\theta(\xi))$ in such a way that all relevant details of the prior model are encoded in the forward model. This choice of re-parameterization (Rezende & Mohamed, 2015) is called standardization. It is often carried out implicitly in the background without user input.
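As a minimal sketch of standardization, assume a parameter of interest with a log-normal prior. The helper `to_lognormal` and its arguments are hypothetical names chosen for illustration and are not part of NIFTy.re; the sketch only shows how a standard Gaussian latent variable can be pushed through a deterministic map so that the prior information lives entirely in the forward model.

```python
import numpy as np


def to_lognormal(xi, mean, std):
    """Map standard-normal xi to a log-normal variable with the given mean and std.

    Hypothetical helper for illustration only.
    """
    # Moment matching for the underlying Gaussian in log-space
    sigma2 = np.log1p((std / mean) ** 2)
    mu = np.log(mean) - 0.5 * sigma2
    return np.exp(mu + np.sqrt(sigma2) * xi)


rng = np.random.default_rng(0)
# Latent parameters: a priori standard Gaussian
xi = rng.standard_normal(200_000)
# Parameters of interest: log-normal distributed, strictly positive
theta = to_lognormal(xi, mean=2.0, std=0.5)
```

Any inference scheme that acts on `xi` only ever sees a standard Gaussian prior, while a forward model composed with `to_lognormal` still operates on positive, log-normally distributed values.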

Gaussian Processes
One standard tool from the NIFTy.re toolbox is the so-called correlated field GP model from NIFTy. This model relies on the harmonic domain being easily accessible. For example, for pixels spaced on a regular Cartesian grid, the natural choice to represent a stationary kernel is the Fourier domain. In the generative picture, a realization $s$ drawn from a GP then reads $s = \mathrm{FT} \circ \sqrt{P} \circ \xi$ with FT the (fast) Fourier transform, $\sqrt{P}$ the square-root of the power-spectrum in harmonic space, and $\xi$ standard Gaussian random variables. In the implementation in NIFTy.re and NIFTy, the user can choose between two adaptive kernel models, a non-parametric kernel $\sqrt{P}$ and a Matérn kernel $\sqrt{P}$ (see Arras et al., 2022; Guardiani et al., 2022 for details on their implementation). A code example that initializes a non-parametric GP prior for a 128 × 128 space with unit volume is shown in the following.

```python
from nifty8 import re as jft

dims = (128, 128)
cfm = jft.CorrelatedFieldMaker("cf")
cfm.set_amplitude_total_offset(offset_mean=2, offset_std=(1e-1, 3e-2))
# Parameters for the kernel and the regular 2D Cartesian grid for which
# it is defined
cfm.add_fluctuations(
    dims,
    distances=tuple(1.0 / d for d in dims),
    fluctuations=(1.0, 5e-1),
    loglogavgslope=(-3.0, 2e-1),
    flexibility=(1e0, 2e-1),
    asperity=(5e-1, 5e-2),
    prefix="ax1",
    non_parametric_kind="power",
)
# Get the forward model for the GP prior
correlated_field = cfm.finalize()
```

Not all problems are well described by regularly spaced pixels. For more complicated pixel spacings, NIFTy.re features Iterative Charted Refinement (Edenhofer et al., 2022), a GP model for arbitrarily deformed spaces. This model exploits nearest neighbor relations on various coarsenings of the discretized modeled space and runs very efficiently on GPUs. For one-dimensional problems with arbitrarily spaced pixels, NIFTy.re also implements multiple flavors of Gauss-Markov processes.
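The generative picture for a stationary kernel on a regular grid can be sketched in plain NumPy. This is an illustration under stated assumptions and not the correlated field model itself: the power spectrum is a fixed, hypothetical power law rather than an adaptive kernel, and a Hartley transform built from the FFT is assumed as the harmonic transform so that the realization is real-valued.

```python
import numpy as np


def hartley(x):
    """Orthonormal Hartley transform of a real array, built from the FFT."""
    f = np.fft.fftn(x, norm="ortho")
    return f.real - f.imag


dims = (128, 128)
distances = tuple(1.0 / d for d in dims)  # unit volume, as in the example above

# Length of the harmonic-space wave vector for every mode of the grid
freqs = np.meshgrid(
    *(np.fft.fftfreq(n, d=dx) for n, dx in zip(dims, distances)),
    indexing="ij",
)
k = np.sqrt(sum(f**2 for f in freqs))

# Hypothetical power-law spectrum with slope -3; the zero mode is switched off
power = np.zeros_like(k)
power[k > 0] = k[k > 0] ** -3.0

# Standard Gaussian excitations, scaled by sqrt(P) and transformed
rng = np.random.default_rng(42)
xi = rng.standard_normal(dims)
s = hartley(np.sqrt(power) * xi)
```

Because the spectrum depends only on the length of the wave vector, the resulting field `s` has stationary and isotropic correlations; steeper slopes yield smoother realizations.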

Building Up Complex Models
Models are rarely just a GP prior. Commonly, a model contains at least a few non-linearities that transform the GP prior or combine it with other random variables. For building more complex models, NIFTy.re provides a Model class that offers a somewhat familiar object-oriented design yet is fully JAX compatible and functional under the hood. The following code shows how to build a slightly more complex model using the objects from the previous example.
for the purpose of compiling. Depending on the value, JAX will either treat the attribute as an unknown placeholder or as a known concrete attribute and potentially inline it during compilation. This mechanism is extensively used in likelihoods to avoid inlining large constants such as the data and to avoid expensive re-compilations whenever possible.
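The distinction between traced and static attributes can be illustrated with plain JAX. The sketch below does not use the NIFTy.re Model class; the class `Scale` and its attributes are hypothetical and are registered as a pytree by hand via `jax.tree_util.register_pytree_node_class`. The attribute `factor` is exposed as a pytree leaf and is traced as a placeholder, whereas `name` is passed as auxiliary data and is treated as a compile-time constant.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scale:
    """Hypothetical model-like object that JAX can pass through jit."""

    def __init__(self, factor, name):
        self.factor = factor  # non-static: traced as an abstract placeholder
        self.name = name  # static: known concretely at compile time

    def __call__(self, x):
        return self.factor * jnp.exp(x)

    def tree_flatten(self):
        # (children that JAX traces, auxiliary data that JAX treats as static)
        return (self.factor,), self.name

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)


model = Scale(jnp.array(2.0), "demo")
# The object itself is a valid argument to a jit-compiled function
out = jax.jit(lambda m, x: m(x))(model, jnp.zeros(3))
```

Changing `factor` reuses the compiled function, while changing the static `name` triggers a re-compilation. Large arrays such as data are therefore better kept as leaves than as static auxiliary data.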

Variational Inference
NIFTy.re is built for models with millions to billions of degrees of freedom. To probe the posterior efficiently and accurately, NIFTy.re relies on VI. Specifically, NIFTy.re implements Metric Gaussian Variational Inference (MGVI) and its successor geometric Variational Inference (geoVI) (Frank et al., 2021; Frank, 2022; Knollmüller & Enßlin, 2019). At the core of both MGVI and geoVI lies an alternating procedure in which one switches between optimizing the Kullback-Leibler divergence for a specific shape of the variational posterior and updating the shape of the variational posterior. MGVI and geoVI define the variational posterior via samples, specifically, via samples drawn around an expansion point. The samples in MGVI and geoVI exploit model-intrinsic knowledge of the posterior's approximate shape, encoded in the Fisher information metric and the prior curvature (Frank et al., 2021).
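The alternating procedure can be sketched for a toy problem in NumPy. The sketch assumes a hypothetical linear forward model `R` with Gaussian noise, for which the MGVI approximation is exact, and it builds the metric as a dense matrix. It is not the NIFTy.re implementation, which works with implicit operators and iterative solvers; all names below are chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_lat, n_dat, noise_std = 4, 6, 0.5

# Hypothetical linear forward model and synthetic data
R = rng.normal(size=(n_dat, n_lat))
d = R @ rng.normal(size=n_lat) + noise_std * rng.normal(size=n_dat)


def grad_neg_log_post(xi):
    """Gradient of the negative log-posterior with a standard Gaussian prior."""
    return R.T @ (R @ xi - d) / noise_std**2 + xi


def metric(xi):
    """Fisher metric of the likelihood plus the identity from the prior."""
    return R.T @ R / noise_std**2 + np.eye(n_lat)


def draw_residuals(xi, n_draws):
    """Draw residuals with covariance metric(xi)^-1 and mirror them."""
    chol = np.linalg.cholesky(np.linalg.inv(metric(xi)))
    r = (chol @ rng.normal(size=(n_lat, n_draws))).T
    return np.concatenate([r, -r])  # antithetic mirroring


xi = np.zeros(n_lat)  # expansion point
for _ in range(5):
    # Update the shape of the variational posterior around the expansion point
    residuals = draw_residuals(xi, n_draws=3)
    # Optimize the sample-averaged KL divergence with one Newton-like step
    g = np.mean([grad_neg_log_post(xi + r) for r in residuals], axis=0)
    xi = xi - np.linalg.solve(metric(xi), g)

samples = xi + residuals  # six mirrored posterior samples
```

For this linear Gaussian toy the expansion point converges to the exact posterior mean after the first step. In non-linear models the metric depends on the expansion point, which is why sampling and optimization have to alternate.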
NIFTy.re allows for much finer control over the way samples are drawn and updated compared to NIFTy. NIFTy.re exposes stand-alone functions for drawing MGVI and geoVI samples from any arbitrary model with a likelihood from NIFTy.re and a forward model that is differentiable by JAX. In addition to stand-alone sampling functions, NIFTy.re provides tools to configure and execute the alternating Kullback-Leibler divergence optimization and sample adaption at a lower abstraction level. These tools are provided in a JAXopt/Optax-style optimizer class (Blondel et al., 2022; DeepMind et al., 2020).
A typical minimization with NIFTy.re is shown in the following. It retrieves six independent, antithetically mirrored samples from the approximate posterior via 25 iterations of alternating between optimization and sample adaption. The final result is stored in the samples variable. A convenient one-shot wrapper for the code below is jft.optimize_kl. By virtue of all modeling tools in NIFTy.re being written in JAX, it is also possible to combine NIFTy.re tools with BlackJAX (Cabezas & Louf, 2023) or any other posterior sampler in the JAX ecosystem.
Figure 1 shows an exemplary posterior reconstruction employing the above model. The posterior mean agrees with the data but removes noisy structures. The posterior standard deviation is approximately equal to typical differences between the posterior mean and the data.

Performance of NIFTy.re compared to NIFTy
We test the performance of NIFTy.re against NIFTy for the simple yet representative model from above. To assess the performance, we compare the time required to apply $M_p := F_p + \mathbb{1}$ to random input with $F_p$ denoting the Fisher metric of the overall likelihood at position $p$ and $\mathbb{1}$ the identity matrix. Within NIFTy.re, the Fisher metric of the overall likelihood is decomposed into $J_{f,p}^\dagger N^{-1} J_{f,p}$, with $J_{f,p}$ the implicit Jacobian of the forward model $f$ at $p$ and $N^{-1}$ the Fisher metric of the Poisson likelihood. We choose to benchmark $M_p$ as a typical VI minimization in NIFTy.re and NIFTy is dominated by calls to this function.
Figure 2 shows the median evaluation time in NIFTy of applying $M_p$ to new, random tangent positions and the evaluation time in NIFTy.re of building $M_p$ and applying it to new, random tangent positions for exponentially larger models. The 16% quantiles and the 84% quantiles of the timings are obscured by the marker symbols. We chose to exclude the build time of $M_p$ in NIFTy from the comparison, putting NIFTy at an advantage, as its automatic differentiation is built around calls to $M_p$ with $p$ rarely varying. We ran the benchmark on one CPU core, eight CPU cores, and on a GPU on a compute node with an Intel Xeon Platinum 8358 CPU clocked at 2.60 GHz and an NVIDIA A100 SXM4 80 GB HBM2 GPU. The benchmark used jax==0.4.23 and jaxlib==0.4.23+cuda12.cudnn89. We vary the size of the model by increasing the size of the two-dimensional square image grid.
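The decomposition of the metric can be sketched with JAX's forward- and reverse-mode differentiation. The forward model below is a hypothetical one-dimensional toy, not the benchmarked model, and `apply_metric` is an illustrative name rather than a NIFTy.re function. The sketch assumes a Poisson likelihood in the rate parameterization, for which the Fisher metric is diagonal with entries one over the rate.

```python
import jax
import jax.numpy as jnp


def forward(xi):
    """Hypothetical toy forward model: smooth the latent field, then exponentiate."""
    kernel = jnp.array([0.25, 0.5, 0.25])
    return jnp.exp(jnp.convolve(xi, kernel, mode="same"))


def apply_metric(p, t):
    """Apply M_p = J^T N^{-1} J + 1 to the tangent t without building a matrix."""
    # Forward mode: evaluate the Poisson rate f(p) and the product J t
    rate, jt = jax.jvp(forward, (p,), (t,))
    # Reverse mode: apply J^T to the N^{-1}-weighted tangent, N^{-1} = diag(1 / rate)
    _, vjp = jax.vjp(forward, p)
    (fisher_t,) = vjp(jt / rate)
    # Add the identity from the standard Gaussian prior
    return fisher_t + t


p = jnp.linspace(-1.0, 1.0, 16)  # position at which the metric is evaluated
t = jnp.ones(16)  # tangent the metric is applied to
out = jax.jit(apply_metric)(p, t)
```

Only Jacobian-vector and vector-Jacobian products are evaluated, so the cost of one application scales with the cost of the forward model rather than with the square of the number of parameters.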
For small image sizes, NIFTy.re on the CPU is about one order of magnitude faster than NIFTy. Both reach about the same performance at an image size of roughly 15,000 pixels and continue to perform roughly the same for larger image sizes. The performance increases by a factor of three to four with eight cores for NIFTy.re and NIFTy, although NIFTy.re is slightly better at using the additional cores. On the GPU, NIFTy.re is consistently about one to two orders of magnitude faster than NIFTy for images larger than 100,000 pixels.
We believe the performance benefits of NIFTy.re on the CPU for small models stem from the reduced Python overhead by just-in-time compiling computations. At image sizes larger than roughly 15,000 pixels, both evaluation times are dominated by the fast Fourier transform and are hence roughly the same as both use the same underlying implementation (Reinecke, 2024). Models in NIFTy.re and NIFTy are often well aligned with GPU programming models and thus consistently perform well on the GPU. Modeling components such as the new GP models implemented in NIFTy.re are even better aligned with GPU programming paradigms and yield even higher performance gains (Edenhofer et al., 2022).

Conclusion
NIFTy.re implements the core GP and VI machinery of the Bayesian imaging package NIFTy in JAX. The rewrite moves much of the heavy lifting from home-grown solutions to JAX, and we envision significant gains in maintainability of NIFTy.re and a faster development cycle moving forward. The rewrite accelerates typical models written in NIFTy by one to two orders of magnitude, lays the foundation for new types of inference machineries by enabling higher order derivatives via JAX, and enables the interoperability of NIFTy's VI and GP methods with the JAX machine learning ecosystem.

Figure 2: Median evaluation time of applying the Fisher metric plus the identity metric to random input for NIFTy.re and NIFTy on the CPU (one and eight core(s) of an Intel Xeon Platinum 8358 CPU clocked at 2.60 GHz) and the GPU (A100 SXM4 80 GB HBM2). The quantile range from the 16% to the 84% quantile is obscured by the marker symbols.
All GP models in NIFTy.re as well as all likelihoods behave like instances of jft.Model, meaning that JAX understands what it means if a computation involves self, other jft.Model instances, or their attributes.In other words, correlated_field, forward, and lh from the code snippets shown here are all so-called pytrees in JAX, and, for example, the following is valid code jax.