JAXbind: Bind any function to JAX

JAX is widely used in machine learning and scientific computing, the latter of which often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing the existing code in JAX is often impractical, and the existing interface in JAX for binding custom code either limits the user to a single Jacobian product or requires deep knowledge of JAX and its C++ backend for general Jacobian products. With JAXbind, we drastically reduce the effort required to bind custom functions implemented in other programming languages to JAX, with full support for Jacobian-vector products and vector-Jacobian products. Specifically, JAXbind provides an easy-to-use Python interface for defining custom, so-called JAX primitives. Via JAXbind, any function callable from Python can be exposed as a JAX primitive. JAXbind allows a user to interface the JAX function transformation engine with custom derivatives and batching rules, enabling all JAX transformations for the custom primitive.


Statement of Need
The use of JAX (Bradbury et al., 2018) is widespread in the natural sciences. Of particular interest is JAX's powerful transformation system. It enables a user to retrieve arbitrary derivatives of functions, batch computations, and just-in-time compile code for additional performance. Its transformation system requires that all components of the computation are written in JAX.
A plethora of high-performance code is not written in JAX and thus not accessible from within JAX. Rewriting these codes is often infeasible and/or inefficient. Ideally, we would like to mix existing high-performance code with JAX code. However, connecting code to JAX requires knowledge of the internals of JAX and its C++ backend.
In this paper, we present JAXbind, a package for bridging any function to JAX without in-depth knowledge of JAX's transformation system. The interface is accessible from Python without requiring any development in C++. The package is able to register any function, its partial derivatives, and their transposes as a JAX-native call, a so-called primitive.
We believe JAXbind to be highly useful in scientific computing. We intend to use this package to connect the Hartley transform and the spherical harmonic transform from DUCC (Reinecke, 2024) to the probabilistic programming package NIFTy (Edenhofer et al., 2024), as well as the radio interferometry response from DUCC to the radio astronomy package resolve (Arras et al., 2024). Furthermore, we intend to connect the nonuniform FFT from DUCC with JAX for applications in strong-lensing astrophysics. We envision many further applications within and outside of astrophysics.
The functionality of JAXbind extends the external callback functionality in JAX. Currently, JAXbind, akin to the external callback functions in JAX, briefly requires Python's global interpreter lock (GIL) to call the user-specified Python function. In contrast to JAX's external callback functions, JAXbind allows for both a custom Jacobian-vector product and a custom vector-Jacobian product. To the best of our knowledge, no other code currently exists for easily binding generic functions and both of their Jacobian products to JAX without the need for C++ or LLVM. The package that comes closest is Enzyme-JAX (W. S. Moses & Zinenko, 2024), which allows one to bind arbitrary LLVM/MLIR, including C++, to JAX with automatically generated (W. S. Moses et al., 2021, 2022; W. Moses & Churavy, 2020) or manually defined derivatives.
PyTorch (Ansel et al., 2024) and TensorFlow (Abadi et al., 2015) also provide interfaces for custom extensions. PyTorch has an extensively documented Python interface for wrapping custom Python functions as PyTorch functions. This interface connects the custom function to PyTorch's automatic differentiation engine, allowing for custom applications of the Jacobian and its transpose, similar to what is possible with JAXbind. Additionally, PyTorch allows a user to interface its C++ backend with custom C++ or CUDA extensions. JAXbind, in contrast, currently only supports functions executed on the CPU, although the built-in JAX C++ interface also allows for custom GPU kernels. TensorFlow includes a C++ interface for custom functions that can be executed on the CPU or GPU. Custom gradients can be added to these functions.

Automatic Differentiation and Code Example
Automatic differentiation is a core feature of JAX and often one of the main reasons for using it. Thus, it is essential that custom functions registered with JAX support automatic differentiation. In the following, we outline which functions our package requires to enable automatic differentiation via JAX. For simplicity, we assume that we want to connect the nonlinear function $f(x_1, x_2) = x_1 x_2^2$ to JAX. The JAXbind package expects the Python function for $f$ to take three positional arguments. The first argument, out, is a tuple into which the results are written. The second argument is also a tuple and contains the inputs to the function, in our case $x_1$ and $x_2$. Via the third argument, kwargs_dump, any keyword arguments given to the registered JAX primitive can be forwarded to $f$ in serialized form.
JAX's automatic differentiation engine can compute the Jacobian-vector product (jvp) and the vector-Jacobian product (vjp) of JAX primitives. The Jacobian-vector product in JAX is a function applying the Jacobian of $f$ at a position $x$ to a tangent vector. In mathematical nomenclature, this operation is called the pushforward of $f$ and can be denoted as $\partial f(x): T_x X \to T_{f(x)} Y$, with $T_x X$ and $T_{f(x)} Y$ being the tangent spaces of $X$ and $Y$ at the positions $x$ and $f(x)$. As the implementation of $f$ is not JAX-native, JAX cannot automatically compute the jvp. Instead, an implementation of the pushforward has to be provided, which JAXbind will register as the jvp of the JAX primitive of $f$. For our example, this Jacobian-vector-product function is given by $\partial f(x_1, x_2)(dx_1, dx_2) = x_2^2\, dx_1 + 2 x_1 x_2\, dx_2$.
The vector-Jacobian product (vjp) in JAX is the linear transpose of the Jacobian-vector product. In mathematical nomenclature, this is the pullback $(\partial f(x))^T: T_{f(x)} Y \to T_x X$ of $f$. Analogously to the jvp, the user has to implement this function, as JAX cannot automatically construct it. For our example function, the vector-Jacobian product is $(\partial f(x_1, x_2))^T(dy) = (x_2^2\, dy,\ 2 x_1 x_2\, dy)$.
To just-in-time compile the function, JAX needs to abstractly evaluate the code, i.e., it needs to be able to infer the shape and dtype of the output of the function given only the shape and dtype of the input. We have to provide these abstract evaluation functions, which return the output shape and dtype given an input shape and dtype, for $f$ as well as for the vjp application. The output shape of the jvp is identical to the output shape of $f$ itself and does not need to be specified again. The abstract evaluation functions take normal positional and keyword arguments. We have now defined all ingredients necessary to register a JAX primitive for our function $f$ using the JAXbind package:
f_jax = jaxbind.get_nonlinear_call(f, (f_jvp, f_vjp), f_abstract, f_abstract_T)
f_jax is a JAX primitive registered via the JAXbind package that supports all JAX transformations. We can now compute the jvp and vjp of the new JAX primitive and even jit-compile and batch it.
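The callbacks for $f(x_1, x_2) = x_1 x_2^2$ can be sketched in NumPy so that they run without JAX or JAXbind. This is a minimal illustration under stated assumptions, not a verbatim copy of the JAXbind interface: we assume that out holds preallocated arrays that are filled in place, that the jvp receives the primal inputs followed by their tangents, that the vjp receives the primal inputs followed by the output cotangent, and that each abstract evaluation function returns one (shape, dtype) pair per output. The helper eval_f and the random test vectors are hypothetical and only serve the numerical checks at the end.

```python
import numpy as np

# Sketch of the callbacks for f(x1, x2) = x1 * x2**2.
# Assumption: `out` is a tuple of preallocated arrays that each callback fills
# in place; `kwargs_dump` is ignored because f takes no keyword arguments.


def f(out, args, kwargs_dump):
    x1, x2 = args
    out[0][()] = x1 * x2**2


def f_jvp(out, args, kwargs_dump):
    # Pushforward: x2**2 * dx1 + 2 * x1 * x2 * dx2.
    # Assumed order: primal inputs first, then their tangents.
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2


def f_vjp(out, args, kwargs_dump):
    # Pullback: (x2**2 * dy, 2 * x1 * x2 * dy), one cotangent per input.
    # Assumed order: primal inputs first, then the output cotangent.
    x1, x2, dy = args
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy


def f_abstract(*args, **kwargs):
    # Abstract evaluation of f: one (shape, dtype) pair for its single output.
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)


def f_abstract_T(*args, **kwargs):
    # Abstract evaluation of the vjp: one (shape, dtype) pair per input of f.
    # Assumption: the first argument is shaped like x1 (and x2).
    return ((args[0].shape, args[0].dtype), (args[0].shape, args[0].dtype))


# Hypothetical helper that calls f outside of JAX and returns its output.
def eval_f(x1, x2):
    out = (np.empty_like(x1),)
    f(out, (x1, x2), None)
    return out[0]


rng = np.random.default_rng(0)
x1, x2, dx1, dx2, dy = (rng.normal(size=3) for _ in range(5))

# jvp vs. central finite differences along the direction (dx1, dx2).
jvp_out = (np.empty(3),)
f_jvp(jvp_out, (x1, x2, dx1, dx2), None)
eps = 1e-6
fd = (eval_f(x1 + eps * dx1, x2 + eps * dx2)
      - eval_f(x1 - eps * dx1, x2 - eps * dx2)) / (2 * eps)
jvp_ok = np.allclose(jvp_out[0], fd, atol=1e-6)

# Dot-product test: <jvp(dx), dy> must equal <dx, vjp(dy)>.
vjp_out = (np.empty(3), np.empty(3))
f_vjp(vjp_out, (x1, x2, dy), None)
vjp_ok = np.isclose(jvp_out[0] @ dy, dx1 @ vjp_out[0] + dx2 @ vjp_out[1])

print(jvp_ok, vjp_ok)

# With JAXbind installed, the callbacks would then be registered as in the text:
# f_jax = jaxbind.get_nonlinear_call(f, (f_jvp, f_vjp), f_abstract, f_abstract_T)
```

The two checks are a cheap safeguard before registration: the finite-difference comparison catches a wrong pushforward, and the dot-product test $\langle \partial f(x)\, dx, dy \rangle = \langle dx, (\partial f(x))^T dy \rangle$ catches a vjp that is not the transpose of the jvp.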
In scientific computing, linear functions such as spherical harmonic transforms are widespread. If the function $f$ is linear, differentiation becomes trivial. Specifically, for a linear function $f$, the pushforward or jvp of $f$ is identical to $f$ itself and independent of the position at which it is computed. Expressed in formulas, $\partial f(x)(dx) = f(dx)$ if $f$ is linear in $x$. Analogously, the pullback or vjp becomes independent of the initial position and is given by the linear transpose of $f$, thus $(\partial f(x))^T(dy) = f^T(dy)$. Also, all higher-order derivatives can be expressed in terms of $f$ and its transpose. To make use of these simplifications, JAXbind provides a special interface for linear functions that supports higher-order derivatives and only requires an implementation of the function and its transpose.
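For a linear function, the two identities $\partial f(x)(dx) = f(dx)$ and $(\partial f(x))^T(dy) = f^T(dy)$ can be checked numerically. As a minimal sketch, assume a hypothetical linear operator, the cumulative sum along the last axis, whose transpose is the reversed cumulative sum $(f^T(y))_i = \sum_{j \ge i} y_j$. The callbacks reuse the (out, args, kwargs_dump) convention from the nonlinear case; that JAXbind's linear interface expects exactly this form is our assumption. The names cumsum_op, cumsum_op_T, and apply_op are illustrative only.

```python
import numpy as np

# Hypothetical linear operator f(x) = cumsum(x) along the last axis.
# Assumption: same (out, args, kwargs_dump) callback convention as for
# nonlinear functions; kwargs_dump is unused here.


def cumsum_op(out, args, kwargs_dump):
    (x,) = args
    out[0][()] = np.cumsum(x, axis=-1)


def cumsum_op_T(out, args, kwargs_dump):
    # Linear transpose: (f^T y)_i = sum_{j >= i} y_j, a reversed cumulative sum.
    (y,) = args
    out[0][()] = np.flip(np.cumsum(np.flip(y, axis=-1), axis=-1), axis=-1)


# Hypothetical helper that calls a single-input, single-output callback.
def apply_op(fn, arr):
    out = (np.empty_like(arr),)
    fn(out, (arr,), None)
    return out[0]


rng = np.random.default_rng(1)
x, dx, dy = (rng.normal(size=5) for _ in range(3))

# jvp of a linear map: a finite difference at any position x reproduces f(dx).
h = 1e-3
jvp_fd = (apply_op(cumsum_op, x + h * dx) - apply_op(cumsum_op, x)) / h
jvp_ok = np.allclose(jvp_fd, apply_op(cumsum_op, dx))

# vjp of a linear map is its transpose: <f(dx), dy> == <dx, f^T(dy)>.
adjoint_ok = np.isclose(apply_op(cumsum_op, dx) @ dy,
                        dx @ apply_op(cumsum_op_T, dy))

print(jvp_ok, adjoint_ok)
```

Because the jvp and vjp of a linear map reduce to $f$ and $f^T$, these two callbacks are all that such an interface needs; the dot-product test is the usual way to confirm that a hand-written transpose is correct.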

Platforms
Currently, JAXbind only supports primitives that act on CPU memory. In the future, GPU support could be added, which should work analogously to the CPU support in most respects. The automatic differentiation in JAX is backend-agnostic and would thus not require any additional bindings to work on the GPU.