GPJax: A Gaussian Process Framework in JAX

Summary
Gaussian processes (GPs, Rasmussen & Williams, 2006) are Bayesian nonparametric models that have been successfully used in applications such as geostatistics (Matheron, 1963), Bayesian optimisation (Mockus et al., 1978), and reinforcement learning (Deisenroth & Rasmussen, 2011). GPJax is a didactic GP library targeted at researchers who wish to develop novel GP methodology. The scope of GPJax is to provide users with a set of composable objects for constructing GP models that closely resemble the underlying maths that one would write on paper. Furthermore, by virtue of being written in JAX (Bradbury et al., 2018), GPJax natively supports CPUs, GPUs and TPUs through efficient compilation to XLA, automatic differentiation and vectorised operations. Consequently, GPJax provides a modern GP package that can readily be tailored, extended and interleaved with other libraries to meet the individual needs of researchers and scientists.
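As an illustration of code that mirrors the equations one would write on paper, the following minimal sketch computes the exact GP posterior predictive distribution under a squared exponential kernel and a Gaussian likelihood: mean K*x (Kxx + σ²I)⁻¹ y and covariance K** − K*x (Kxx + σ²I)⁻¹ Kx*. It is written in plain JAX rather than with GPJax's own API; the function names `rbf`, `gram` and `posterior_predictive` are hypothetical, and `jax.vmap` is used to assemble the Gram matrix from a pairwise kernel function.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(x, y, lengthscale=1.0, variance=1.0):
    """Squared exponential kernel: variance * exp(-||x - y||^2 / (2 * lengthscale^2))."""
    return variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)


def gram(kernel, X, Y):
    """Cross-covariance matrix K(X, Y), built by vmapping the pairwise kernel over both inputs."""
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(Y))(X)


def posterior_predictive(X, y, X_test, lengthscale=1.0, variance=1.0, noise=1e-2):
    """Exact GP posterior mean and covariance at X_test given data (X, y) and Gaussian noise."""
    kernel = partial(rbf, lengthscale=lengthscale, variance=variance)
    Kxx = gram(kernel, X, X) + noise * jnp.eye(X.shape[0])  # K + sigma^2 I
    Ksx = gram(kernel, X_test, X)                            # K_{*x}
    Kss = gram(kernel, X_test, X_test)                       # K_{**}
    chol = cho_factor(Kxx, lower=True)                       # Cholesky factor, reused below
    mean = Ksx @ cho_solve(chol, y)                          # K_{*x} (K + sigma^2 I)^{-1} y
    cov = Kss - Ksx @ cho_solve(chol, Ksx.T)                 # K_{**} - K_{*x} (K + sigma^2 I)^{-1} K_{x*}
    return mean, cov


# Usage: condition on five observations and predict back at the same inputs.
X = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
y = jnp.sin(2.0 * jnp.pi * X[:, 0])
mean, cov = posterior_predictive(X, y, X, lengthscale=0.1, noise=1e-4)
```

With a small noise variance, the posterior mean interpolates the observations and the posterior variance at the training inputs collapses towards the noise level.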

Statement of Need
From both an applied and methodological perspective, GPs are widely employed in the statistics and machine learning communities. High-quality software packages that promote GP modelling are responsible for much of their success. However, the JAX ecosystem currently lacks a Gaussian process package that incorporates scalable inference techniques. GPJax seeks to fill this gap.
GPJax has been carefully tailored to integrate with the JAX ecosystem. For efficient Markov chain Monte Carlo inference, GPJax can utilise samplers from BlackJax (BlackJax, 2021) and TensorFlow Probability (Abadi et al., 2016). For gradient-based optimisation, GPJax integrates seamlessly with Optax (Babuschkin et al., 2020), providing a vast suite of optimisers and learning rate schedules. To efficiently represent probability distributions, GPJax leverages Distrax (Babuschkin et al., 2020) and TensorFlow Probability (Abadi et al., 2016). To combine GPs with deep learning methods, GPJax can incorporate the functionality provided within Haiku (Babuschkin et al., 2020). The GPJax documentation includes examples of each of these integrations.
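Gradient-based optimisation of a GP usually means minimising the negative log marginal likelihood, 0.5 yᵀ(K + σ²I)⁻¹y + 0.5 log|K + σ²I| + (n/2) log 2π, with respect to the kernel and likelihood hyperparameters. The sketch below, assuming a zero-mean GP with a squared exponential kernel and hypothetical function names (`nlml`, `fit`), obtains exact gradients with `jax.value_and_grad`. An Optax optimiser such as `optax.adam` would consume the same gradient PyTree; to keep the example dependency-free, it applies a hand-written gradient-descent update instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def nlml(params, X, y):
    """Negative log marginal likelihood of a zero-mean GP with a squared exponential kernel."""
    # Hyperparameters are stored on the log scale so unconstrained updates keep them positive.
    lengthscale = jnp.exp(params["log_lengthscale"])
    variance = jnp.exp(params["log_variance"])
    noise = jnp.exp(params["log_noise"])
    n = X.shape[0]
    sq_dist = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    # A small fixed jitter is added for numerical stability of the Cholesky factorisation.
    K = variance * jnp.exp(-0.5 * sq_dist / lengthscale**2) + (noise + 1e-6) * jnp.eye(n)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)  # (K + sigma^2 I)^{-1} y
    # 0.5 * log|K| equals the sum of the log-diagonal of its Cholesky factor.
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L))) + 0.5 * n * jnp.log(2.0 * jnp.pi)


loss_and_grad = jax.jit(jax.value_and_grad(nlml))


def fit(params, X, y, learning_rate=0.01, num_steps=200):
    """Plain gradient descent on the log-hyperparameters (an Optax optimiser could replace the update)."""
    for _ in range(num_steps):
        _, grads = loss_and_grad(params, X, y)
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params


# Usage: noisy observations of a sine wave.
X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(6.0 * X[:, 0]) + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (20,))
init = {"log_lengthscale": jnp.array(0.0), "log_variance": jnp.array(0.0), "log_noise": jnp.array(0.0)}
fitted = fit(init, X, y)
```

Because the parameters are held in a dictionary, the gradients come back as a dictionary of the same shape, which is the structure that PyTree-aware optimisers expect.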
The foundation of each abstraction given in GPJax is a Chex (Babuschkin et al., 2020) dataclass object. These require significantly less boilerplate code than regular Python classes, leading to a more readable codebase. Moreover, Chex dataclasses are registered as PyTree nodes, which allows JAX transformations such as just-in-time compilation and automatic differentiation to be applied directly to any GPJax object.
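The PyTree registration described above is what lets JAX transformations act on model objects. GPJax itself builds on Chex dataclasses; the sketch below instead uses a standard-library dataclass registered via `jax.tree_util.register_pytree_node_class` as a stand-in, and the `RBF` class is hypothetical. Because the array-valued fields are PyTree leaves, `jax.grad` returns an `RBF` instance holding the gradient for each hyperparameter, and `jax.jit` compiles a function that takes the kernel object as its argument.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class RBF:
    """Hypothetical squared exponential kernel whose hyperparameters are PyTree leaves."""

    lengthscale: jnp.ndarray
    variance: jnp.ndarray

    def __call__(self, x, y):
        return self.variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / self.lengthscale**2)

    def tree_flatten(self):
        # Array fields become leaves; there is no static auxiliary data.
        return (self.lengthscale, self.variance), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


kernel = RBF(lengthscale=jnp.array(1.0), variance=jnp.array(2.0))
x, z = jnp.array([0.0]), jnp.array([1.0])

# jax.jit and jax.grad accept the kernel object directly because it is a registered PyTree.
value = jax.jit(lambda k: k(x, z))(kernel)
grads = jax.grad(lambda k: k(x, z))(kernel)  # an RBF instance holding d(value)/d(hyperparameter)
```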
The close correspondence between GPJax and the underlying maths also makes GPJax an excellent package for people new to GP modelling. Being able to cross-reference the contents of a textbook with the code one is writing is invaluable when building intuition for a new statistical method. We further support this effort through documentation that provides detailed explanations of the operations conducted within each notebook.

Wider Software Ecosystem
Within the Python community, the three most popular packages for GP modelling are GPflow (Matthews et al., 2017), GPyTorch (Gardner et al., 2018), and GPy (GPy, 2012). Although these packages are indispensable tools for the community, none of them supports integration with a JAX-based workflow. On the other hand, the BayesNewton (Wilkinson et al., 2021) and TinyGP (Foreman-Mackey, 2021) packages use a JAX backend. However, BayesNewton is built on top of Objax (Objax Developers, 2020), which makes integration with the broader JAX ecosystem challenging. Meanwhile, TinyGP offers excellent integration with inference frameworks such as NumPyro (Phan et al., 2019) but does not yet support inducing point frameworks (e.g., Hensman et al., 2013). GPJax exists to resolve these issues. Furthermore, recent methods from the GP literature, such as graph kernels (Borovitskiy et al., 2021) and Wasserstein barycentres for GPs (Mallasto & Feragen, 2017), are supported within GPJax but absent from these packages. Finally, the Stheno package (Bruinsma, 2022) supports a JAX backend alongside TensorFlow, PyTorch and NumPy. Whilst this integrates GPs into an extensive JAX workflow, GPJax has the advantage of being a pure JAX codebase, whereas Stheno relies on a custom linear algebra framework.
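To give a flavour of what a graph kernel computes, the sketch below follows the Matérn construction of Borovitskiy et al. (2021), in which the covariance over the nodes of a graph is (2ν/κ² I + L)^(−ν), with L the graph Laplacian, ν the smoothness and κ the lengthscale. It evaluates this through an eigendecomposition of L. The function name `graph_matern_kernel` and the choice to rescale so that the average marginal variance equals `variance` are assumptions made for illustration; this is not GPJax's implementation.

```python
import jax.numpy as jnp


def graph_matern_kernel(adjacency, lengthscale=1.0, smoothness=1.5, variance=1.0):
    """Hypothetical Matérn kernel over graph nodes: (2*nu/kappa^2 * I + L)^(-nu), rescaled."""
    degree = jnp.diag(jnp.sum(adjacency, axis=1))
    laplacian = degree - adjacency                       # combinatorial graph Laplacian L = D - A
    eigvals, eigvecs = jnp.linalg.eigh(laplacian)        # L = U diag(lambda) U^T, lambda >= 0
    spectrum = (2.0 * smoothness / lengthscale**2 + eigvals) ** (-smoothness)
    K = (eigvecs * spectrum) @ eigvecs.T                 # U diag(spectrum) U^T
    # Assumed normalisation: average marginal variance across nodes equals `variance`.
    return variance * K / jnp.mean(jnp.diag(K))


# Usage: a path graph with four nodes, 0 - 1 - 2 - 3.
A = jnp.array(
    [[0.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 0.0]]
)
K = graph_matern_kernel(A)
```

Since every eigenvalue of the Laplacian is non-negative, each spectral weight is strictly positive, so the resulting matrix is a valid (positive definite) covariance over the nodes.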

External Usage
Two recent research papers (Pinder et al., 2021) utilise the graph kernel functionality provided by GPJax. Furthermore, GPJax is being used to build probabilistic ensembles of climate models (Amos et al., 2022) and to perform adaptive sampling in deep-sea environments (Dodd et al., 2022).

Acknowledgments
As an open-source project, GPJax has benefitted from contributions made by the wider community. We especially thank Juan Emmanuel Johnson and are grateful for the thoughts and advice from the wider GP community.

Funding Statement
TP is supported by the Data Science for the Natural Environment project (EPSRC grant number EP/R01860X/1). DD is supported by the STOR-i Centre for Doctoral Training (EPSRC grant number EP/S022252/1) and the Research Hub for Transforming Energy Infrastructure through Digital Engineering (ARC grant number IH200100009).