flowMC: Normalizing-flow enhanced sampling package for probabilistic inference in Jax

flowMC is a Python library for accelerated Markov Chain Monte Carlo (MCMC) leveraging deep generative modeling. It is built on top of the machine learning libraries JAX and Flax. At its core, flowMC uses a local sampler and a learnable global sampler in tandem to efficiently sample posterior distributions. While multiple chains of the local sampler generate samples over the region of interest in the target parameter space, the package uses these samples to train a normalizing flow model, then uses it to propose global jumps across the parameter space. The flowMC sampler can handle non-trivial geometry, such as multimodal distributions and distributions with local correlations.

The key features of flowMC are summarized in the following list:

* Since flowMC is built on top of JAX, it supports gradient-based samplers such as MALA and Hamiltonian Monte Carlo (HMC) through automatic differentiation.
* flowMC uses state-of-the-art normalizing flow models such as rational-quadratic splines to power its global sampler. These models are very efficient at capturing important features within a relatively short training time.
* Accelerators such as GPUs and TPUs are natively supported. The code also supports the use of multiple accelerators with SIMD parallelism.
* By default, just-in-time (JIT) compilation is used to further speed up the sampling process.
* We provide a simple black-box interface for users who want to run flowMC with its default parameters, while also providing an extensive guide explaining the trade-offs involved in tuning the sampler parameters.

The tight integration of all the above features makes flowMC a highly performant yet simple-to-use package for statistical inference.
Statement of need
Bayesian inference requires computing expectations with respect to a posterior distribution on parameters θ after collecting observations D. This posterior is given by

p(θ | D) = p(D | θ) p_0(θ) / Z(D),

where p(D | θ) is the likelihood induced by the model, p_0(θ) the prior on the parameters, and Z(D) the model evidence. As soon as the dimension of θ exceeds 3 or 4, it is necessary to resort to a robust sampling strategy such as MCMC. Drastic gains in computational efficiency can be obtained by a careful selection of the MCMC transition kernel, a task that can be assisted by machine learning methods and libraries.
Gradient-based samplers

In a high-dimensional space, sampling methods that leverage gradient information of the target distribution have been shown to be efficient at proposing new samples likely to be accepted. flowMC supports gradient-based samplers such as MALA and HMC through automatic differentiation with JAX. The computational cost of obtaining a gradient in this way is often of the same order as evaluating the target function itself, making gradient-based samplers favorable with respect to the efficiency/accuracy trade-off.
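To make the idea concrete, the sketch below implements a single MALA step in plain JAX on a toy standard-normal target. All names here are illustrative assumptions for exposition, not flowMC's actual API; the point is that the gradient of the log-target is obtained with `jax.grad` rather than written by hand.

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    # Toy target: standard normal. Stands in for a user-supplied log-posterior.
    return -0.5 * jnp.sum(x ** 2)

grad_log_prob = jax.grad(log_prob)  # gradient via automatic differentiation

def mala_step(key, x, step_size):
    """One Metropolis-adjusted Langevin step (illustrative sketch)."""
    key_prop, key_acc = jax.random.split(key)
    noise = jax.random.normal(key_prop, x.shape)
    # Langevin proposal: drift along the gradient plus Gaussian noise.
    x_new = x + step_size * grad_log_prob(x) + jnp.sqrt(2.0 * step_size) * noise

    # Forward and reverse proposal log-densities (up to a shared constant).
    def log_q(xp, xc):
        mu = xc + step_size * grad_log_prob(xc)
        return -jnp.sum((xp - mu) ** 2) / (4.0 * step_size)

    log_alpha = (log_prob(x_new) - log_prob(x)
                 + log_q(x, x_new) - log_q(x_new, x))
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, x_new, x)

key = jax.random.PRNGKey(0)
x = jnp.zeros(3)
for _ in range(100):
    key, subkey = jax.random.split(key)
    x = mala_step(subkey, x, step_size=0.1)
```

Because the proposal uses local gradient information, such a kernel typically achieves much higher acceptance rates than an isotropic random walk at the same step size.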
Learned transition kernels with normalizing flows

The posterior distributions of many real-world problems have non-trivial geometry, such as multimodality and local correlations, which can drastically slow down the convergence of a sampler based only on gradient information.
To address this problem, flowMC also uses a generative model, namely a normalizing flow (NF) (Kobyzev et al., 2021; Papamakarios et al., 2021), that is trained to mimic the posterior distribution and used as a proposal in Metropolis-Hastings MCMC steps. Variants of this idea have been explored in the past few years (e.g., Albergo et al., 2019; Hoffman et al., 2019; Parno & Marzouk, 2018, and references therein). Despite the growing interest in these methods, few implementations accessible to non-experts exist, especially with GPU and TPU support. Notably, a version of the NeuTra sampler (Hoffman et al., 2019) is available in Pyro (Bingham et al., 2019), and the PocoMC package (Karamanis et al., 2022) implements a version of Sequential Monte Carlo including NFs.
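The mechanics of using a learned density as a Metropolis-Hastings proposal can be sketched as follows. For brevity, a broad Gaussian with a known density stands in for the trained flow (in a real run the flow would be fitted to chain samples); the target is a toy bimodal distribution, and all names are illustrative.

```python
import jax
import jax.numpy as jnp

def log_target(x):
    # Toy bimodal posterior; a purely local kernel mixes slowly between modes.
    return jnp.logaddexp(-0.5 * jnp.sum((x - 3.0) ** 2),
                         -0.5 * jnp.sum((x + 3.0) ** 2))

# Stand-in for a trained normalizing flow: a broad Gaussian whose density we
# can evaluate. A real NF provides both sampling and exact log-density too.
def flow_sample(key, shape):
    return 4.0 * jax.random.normal(key, shape)

def flow_log_prob(x):
    return -0.5 * jnp.sum((x / 4.0) ** 2) - x.size * jnp.log(4.0)

def global_step(key, x):
    """Independence Metropolis-Hastings step with the 'flow' as proposal."""
    key_prop, key_acc = jax.random.split(key)
    x_new = flow_sample(key_prop, x.shape)
    # Acceptance ratio for an independent proposal q:
    # alpha = p(x') q(x) / (p(x) q(x')).
    log_alpha = (log_target(x_new) - log_target(x)
                 + flow_log_prob(x) - flow_log_prob(x_new))
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, x_new, x)

key = jax.random.PRNGKey(0)
x = jnp.zeros(2)
for _ in range(50):
    key, sub = jax.random.split(key)
    x = global_step(sub, x)
```

Because the proposal is independent of the current state, an accepted move can jump directly between modes, which is exactly what gradient-based local kernels struggle to do.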
flowMC implements the method proposed by Gabrié et al. (2021). As individual chains explore their local neighborhood through gradient-based MCMC steps, the samples from multiple chains can be used to train the NF so that it learns the global landscape of the posterior distribution. In turn, the chains can be propagated with a Metropolis-Hastings kernel that uses the NF to propose jumps globally in the parameter space. The cycle of local sampling, NF tuning, and global sampling is repeated until chains of the desired length are obtained. The entire algorithm belongs to the class of adaptive MCMC methods (Andrieu & Thoms, 2008), which collect information from the chains' previous steps to progressively improve the transition kernel. Usual MCMC diagnostics can be applied to assess the robustness of the inference results, thereby avoiding the common concern of validating the NF model. If further sampling from the posterior is necessary, the flow trained during a previous run can be reused without further training. The mathematical details of the method are explained in Gabrié et al. (2021).
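The overall cycle can be summarized as a skeleton loop. The three functions below are placeholders (a local move, a flow-training step, and a flow-based global move), labeled as hypothetical stand-ins for the corresponding components described above rather than flowMC internals.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the three phases of the algorithm.
def local_sampling(key, chains):
    # Placeholder local move; flowMC would use MALA/HMC steps here.
    return chains + 0.1 * jax.random.normal(key, chains.shape)

def train_flow(flow_params, chains):
    # Placeholder: fit the normalizing flow to the pooled chain samples.
    return flow_params

def global_sampling(key, chains, flow_params):
    # Placeholder: Metropolis-Hastings jumps proposed by the flow.
    return chains

def run(key, chains, flow_params, n_loops=5):
    """Skeleton of the adaptive cycle: local moves, flow tuning, global jumps."""
    for _ in range(n_loops):
        key, k_local, k_global = jax.random.split(key, 3)
        chains = local_sampling(k_local, chains)
        flow_params = train_flow(flow_params, chains)
        chains = global_sampling(k_global, chains, flow_params)
    return chains

chains = run(jax.random.PRNGKey(0), jnp.zeros((10, 2)), flow_params=None)
```

The adaptivity lies in `train_flow` consuming the chains' own history, so the global proposal improves as sampling progresses.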
Use of accelerators

Modern accelerators such as GPUs and TPUs are designed to execute dense computation in parallel. Due to the sequential nature of MCMC, a common approach to leveraging accelerators is to run multiple chains in parallel, then combine their results to obtain the posterior distribution. However, a large portion of the computation time comes from the burn-in phase, for which chain parallelization provides no speed-up. flowMC is built on top of JAX, so it supports the use of GPU and TPU accelerators by default. Users can write code in the same way as they would on a CPU, and the library will automatically detect the available accelerators and use them at run time. Furthermore, the library leverages just-in-time compilation to further improve the performance of the sampler.
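In JAX, running many chains in parallel amounts to vectorizing a single-chain transition with `jax.vmap` and compiling the batched update with `jax.jit`; the same code then runs on CPU, GPU, or TPU. The toy random-walk kernel below is only there to have something to vectorize.

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    return -0.5 * jnp.sum(x ** 2)

def rw_step(key, x):
    # Toy random-walk Metropolis step, used only to illustrate vectorization.
    key_prop, key_acc = jax.random.split(key)
    x_new = x + 0.5 * jax.random.normal(key_prop, x.shape)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_prob(x_new) - log_prob(x)
    return jnp.where(accept, x_new, x)

# vmap maps the step over a batch of chains; jit compiles the batched update
# once and reuses it, on whatever accelerator JAX finds at run time.
batched_step = jax.jit(jax.vmap(rw_step))

n_chains, dim = 16, 2
keys = jax.random.split(jax.random.PRNGKey(1), n_chains)
chains = jnp.zeros((n_chains, dim))
chains = batched_step(keys, chains)
```

Note that this parallelism only multiplies throughput across chains; it is the flow-based global kernel described earlier that shortens each chain's burn-in.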

Simplicity and extensibility
We provide a black-box interface with a few tuning parameters for users who intend to use flowMC without much customization on the sampler side. The only inputs we require from users are the log-posterior function and the initial positions of the chains. On top of the black-box interface, the package offers automatic tuning for the local samplers, in order to reduce the number of hyperparameters users have to manage.
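To illustrate what "only a log-posterior and initial positions" means in practice, here is a deliberately minimal toy wrapper with that same two-input contract. `ToySampler` is a hypothetical class written for this sketch, not flowMC's interface, and it runs unadjusted Langevin steps for brevity where flowMC's defaults include Metropolis corrections and flow-based global moves.

```python
import jax
import jax.numpy as jnp

class ToySampler:
    """Hypothetical black-box wrapper: user supplies only a log-posterior
    and initial chain positions; everything else has defaults."""

    def __init__(self, log_posterior, n_steps=50, step_size=0.1):
        self.log_posterior = log_posterior
        self.n_steps = n_steps
        self.step_size = step_size

    def sample(self, key, initial_positions):
        grad_fn = jax.grad(self.log_posterior)

        def step(key, x):
            # Unadjusted Langevin move (no accept/reject), kept short on purpose.
            noise = jax.random.normal(key, x.shape)
            return (x + self.step_size * grad_fn(x)
                    + jnp.sqrt(2.0 * self.step_size) * noise)

        trajectory = [initial_positions]
        x = initial_positions
        for _ in range(self.n_steps):
            key, sub = jax.random.split(key)
            x = jax.vmap(step)(jax.random.split(sub, x.shape[0]), x)
            trajectory.append(x)
        return jnp.stack(trajectory)  # (n_steps + 1, n_chains, dim)

def log_posterior(x):
    return -0.5 * jnp.sum(x ** 2)

sampler = ToySampler(log_posterior)
samples = sampler.sample(jax.random.PRNGKey(0), jnp.zeros((4, 2)))
```

Keeping the required inputs down to these two objects is what makes a sampler usable as a drop-in black box, with tuning parameters exposed only to users who want them.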
While we provide a high-level API for most users, the code is also designed to be extensible. In particular, custom local and global sampling kernels can be integrated into the sampler module.