Up and Running with JAX – JIT Compilation, Vectorizing Transformations and autodiff
I first learned about JAX a few years back in an article on functorch. Functorch was a library that brought JAX-like composable function transformations to PyTorch; it was initially developed as a separate project but has since been fully integrated into PyTorch's core (as of PyTorch 2.0).
I’ve recently invested time in learning JAX, which has proven incredibly worthwhile. The clean functional approach makes my code more maintainable and reproducible, while delivering significant performance and efficiency improvements with surprisingly minimal changes to existing codebases.
So, what is JAX? It is a high-performance numerical computing library developed by Google Research. It combines the ease of use of Python and Numpy with the speed and efficiency of XLA (Accelerated Linear Algebra), making it particularly well-suited for machine learning research and numerical computing that requires high performance.
At its core, JAX extends Numpy’s functionality with automatic differentiation capabilities. This is essential for gradient-based optimization in machine learning. JAX also excels at just-in-time compilation, which translates Python functions into optimized machine code at runtime.
JAX takes a functional programming approach (no side-effects), emphasizing immutability and pure functions. Operations don’t modify their inputs but instead return new values. This is particularly valuable for numerical computations since it enables better parallelization and optimization. Rather than changing arrays in place, JAX functions create new arrays with updated values, resulting in code that is more composable and reproducible. As highlighted in JAX: The Sharp Bits:
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
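To make the immutability point concrete, here is a minimal sketch (the array names are purely illustrative): JAX arrays cannot be modified in place, so indexed updates go through the .at[...].set(...) syntax, which returns a new array and leaves the original untouched.

import jax.numpy as jnp

x = jnp.arange(5)

# x[0] = 99 would raise a TypeError: JAX arrays are immutable.
# Instead, .at[...].set(...) returns a *new* array with the update applied.
y = x.at[0].set(99)

print(x)   # original array is unchanged: [0 1 2 3 4]
print(y)   # updated copy:                [99 1 2 3 4]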
Of the JAX features I researched, vmap is the most readily applicable to the work I do. vmap is a vectorizing transformation that automatically adds a batch dimension to calculations. It stands for "vectorized map" and lets you run a function across multiple inputs in a vectorized fashion without explicitly writing code for batch processing. This enables writing simple, single-example functions while still taking advantage of the performance benefits of vectorized execution.
The JAX numpy submodule, jax.numpy, can often be used as a drop-in replacement for Numpy since its API is almost identical: functions like jnp.array, jnp.sin, jnp.dot, jnp.mean, and many others work just as they do in standard Numpy. The key difference is that JAX arrays are immutable and are optimized for execution on accelerators such as GPUs and TPUs.
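As a quick illustration of the drop-in nature of jax.numpy (the arrays and values below are arbitrary), the same calls work on both libraries' arrays:

import numpy as np
import jax.numpy as jnp

a_np = np.array([1.0, 2.0, 3.0])
a_jx = jnp.array([1.0, 2.0, 3.0])

# The same function names and signatures work on both arrays.
print(np.dot(a_np, a_np), jnp.dot(a_jx, a_jx))   # 14.0 14.0
print(np.mean(a_np), jnp.mean(a_jx))             # 2.0 2.0
print(np.sin(a_np))
print(jnp.sin(a_jx))                             # same values, returned as a JAX Array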
This post highlights three features of JAX: Just-in-time compilation, vectorizing transformations and automatic differentiation. In a future post, I'll walk through the forward and backward pass for a fully-connected neural network implemented entirely in JAX. That post will make heavy use of the content covered here.
Just-in-time Compilation
JIT compilation in JAX speeds up computations by transforming Python functions into optimized machine code. When you apply jax.jit or the jit decorator to a function, JAX traces its operations and compiles them into an efficient, reusable representation. This means that instead of executing Python loops and function calls directly, JAX compiles them into a single, optimized computation graph that runs much faster on CPUs, GPUs and TPUs.
The first time a JIT-compiled function is called, there's a slight overhead as JAX compiles it, but subsequent calls run much faster since the compiled version is reused. JIT works best when inputs have a fixed shape and type, since changing them can trigger a recompilation. It can also be used in conjunction with grad, vmap and pmap for even greater performance gains.
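As a small sketch of the recompilation behavior (scaled_sum is a made-up function for illustration): Python-level side effects such as print only execute while JAX traces the function, which makes it easy to see when compilation is triggered.

from jax import jit
import jax.numpy as jnp

@jit
def scaled_sum(x):
    # print only runs during tracing, so it shows when (re)compilation happens.
    print(f"tracing for shape {x.shape}")
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))   # prints: tracing for shape (3,)
scaled_sum(jnp.ones(3))   # no print: the cached compilation is reused
scaled_sum(jnp.ones(4))   # prints again: a new shape triggers retracing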
To demonstrate, we'll implement a function that computes the great circle distance between two sets of coordinate pairs. The Haversine formula is defined as

$$a = \sin^2\!\left(\tfrac{\Delta\phi}{2}\right) + \cos(\phi_0)\,\cos(\phi_1)\,\sin^2\!\left(\tfrac{\Delta\lambda}{2}\right), \qquad c = 2\arcsin\!\left(\sqrt{a}\right), \qquad d = R \cdot c,$$

where:
$\Delta\phi = \phi_1 - \phi_0$ (latitude difference in radians)
$\Delta\lambda = \lambda_1 - \lambda_0$ (longitude difference in radians)
$R \approx 6371$ km (globally averaged value of the radius of the Earth in kilometers)
$d$ = great-circle distance
A simple implementation of the Haversine formula using trigonometric functions from jax.numpy is provided below. I originally attempted to use the trigonometric functions from Python's builtin math module, but this caused JIT compilation to fail; the issue resolved itself when switching to the JAX-native variants. get_haversine accepts an array of [lon0, lat0, lon1, lat1] and returns the great circle distance between (lon0, lat0) and (lon1, lat1):
import jax.numpy as jnp


def get_haversine(coords):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    Parameters
    ----------
    coords: array-like
        Array containing the longitude and latitude of two points
        arranged as [lon0, lat0, lon1, lat1].

    Returns
    -------
    Great circle distance in km.
    """
    # Average Earth radius in km.
    R = 6371.0

    lon0, lat0, lon1, lat1 = coords

    # Convert degree latitudes and longitudes to radians.
    rlon0 = jnp.radians(lon0)
    rlat0 = jnp.radians(lat0)
    rlon1 = jnp.radians(lon1)
    rlat1 = jnp.radians(lat1)

    dlon, dlat = rlon1 - rlon0, rlat1 - rlat0
    a = jnp.sin(dlat / 2)**2 + jnp.cos(rlat0) * jnp.cos(rlat1) * jnp.sin(dlon / 2)**2
    c = 2 * jnp.asin(jnp.sqrt(a))
    return R * c
To calculate the Haversine distance in kilometers between two points, say Durkin Park on the Southside of Chicago and Nectar’s in Burlington, Vermont, simply run:
lon0, lat0 = -87.7295, 41.7390  # Durkin Park, Chicago, IL
lon1, lat1 = -73.2117, 44.4762  # Nectar's, Burlington, VT

# Put coordinates in JAX array.
coords = jnp.array([lon0, lat0, lon1, lat1])

d = get_haversine(coords)

print(f"Distance between Durkin Park and Nectar's: {d:,.0f} km")
Distance between Durkin Park and Nectar's: 1,215 km
We can JIT-compile get_haversine and compare the run-time against the original implementation. Notice that we call get_haversine_jit once outside of timeit to avoid the overhead associated with the initial compilation:
from jax import jit

# Create jit-compiled version of get_haversine.
get_haversine_jit = jit(get_haversine)

# Warm-up call with an array of the same shape and dtype as coords;
# this compiles the function so the timings below exclude compilation.
_ = get_haversine_jit(jnp.array([-80., 40., -85., 45.]))

# Time the original.
%timeit -n100 get_haversine(coords).block_until_ready()

# Time the jit-compiled function.
%timeit -n100 get_haversine_jit(coords).block_until_ready()
265 μs ± 38.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.46 μs ± 1.03 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
We realized about a 75x speed-up for very little work on our end. For this example, we created a new function, get_haversine_jit, so the jit-compiled version's runtime could be compared against the original non-JITed version. It is possible to instead use the @jit decorator, allowing for the original function name to be re-used:
# Using jit decorator instead.
@jit
def get_haversine(coords):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    Parameters
    ----------
    coords: array-like
        Array containing the longitude and latitude of two points
        arranged as [lon0, lat0, lon1, lat1]

    Returns
    -------
    Great circle distance in km.
    """
    R = 6371.0

    # Convert degree latitudes and longitudes to radians.
    lon0, lat0, lon1, lat1 = coords
    rlon0 = jnp.radians(lon0)
    rlat0 = jnp.radians(lat0)
    rlon1 = jnp.radians(lon1)
    rlat1 = jnp.radians(lat1)

    dlon, dlat = rlon1 - rlon0, rlat1 - rlat0
    a = jnp.sin(dlat / 2)**2 + jnp.cos(rlat0) * jnp.cos(rlat1) * jnp.sin(dlon / 2)**2
    c = 2 * jnp.asin(jnp.sqrt(a))
    return R * c
There are some limitations to JIT-compilation in JAX. In particular, loops, if statements, and other control flow mechanisms may not work as expected. Refer to JAX: The Sharp Bits for additional gotchas.
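As a hypothetical illustration of the control-flow limitation (relu_bad and relu_ok are made-up names, not functions used elsewhere in this post), a Python if on a traced value fails under jit, while expressing the branch as data flow with jnp.where traces cleanly:

from jax import jit
import jax.numpy as jnp

@jit
def relu_bad(x):
    # Fails under jit: during tracing, `x > 0` is an abstract tracer,
    # so Python cannot resolve it to a concrete True/False.
    if x > 0:
        return x
    return 0.0

@jit
def relu_ok(x):
    # jnp.where expresses the branch as data flow, so it traces cleanly.
    return jnp.where(x > 0, x, 0.0)

print(relu_ok(3.0))   # 3.0
# relu_bad(3.0)       # raises a tracer/concretization error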
Vectorizing Transformations
vmap automatically vectorizes operations, enabling the application of a function across multiple inputs without writing explicit loops. It enables batch processing while taking advantage of JAX's optimized execution. Instead of using loops, vmap efficiently maps a function over an array along a pre-specified axis.
To demonstrate, I'll apply vmap to get_haversine, allowing it to accept coordinate arrays of shape n x 4 as opposed to 1 x 4. We will generate a random coordinate array of shape 10,000 x 4 using JAX's random generator utilities.
In JAX, random number generation is handled a bit differently than in Numpy to ensure functional purity. JAX uses explicit PRNG keys to generate random numbers instead of relying on global state. A “key” is a special array that acts as a seed, and every time you use it, JAX produces the same random numbers for the same key.
Since JAX enforces immutability, you can't reuse a key for multiple random calls without getting the same result. Instead, you split a key using jax.random.split, which deterministically generates new, unique keys from the original one. Each split key is independent, allowing for the generation of different random numbers while maintaining reproducibility. This approach makes JAX's random functions fully compatible with its JIT compilation and parallelization features.
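A brief sketch of this behavior (the seed value is arbitrary): reusing a key reproduces the same draw, while split keys yield independent draws.

import jax.random as random

key = random.PRNGKey(0)

# Reusing the same key reproduces the same draw...
print(random.normal(key, shape=(2,)))
print(random.normal(key, shape=(2,)))   # identical to the line above

# ...while split keys produce independent draws.
key1, key2 = random.split(key)
print(random.normal(key1, shape=(2,)))
print(random.normal(key2, shape=(2,)))  # different from key1's draw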
In the next cell, we create a 10,000 x 4 array of random coordinate pairs. We are interested in computing the Haversine distance for each pair of coordinates, but don't want to rewrite get_haversine to process more than a single pair of points at a time.
import jax.random as random

# Create a 10,000 x 4 array of random longitudes and latitudes.
# Longitudes are in the range -175 to 175.
# Latitudes are in the range -85 to 85.
n = 10_000

# Seed for reproducibility. Split key for different random sequences.
key = random.PRNGKey(516)
keys = random.split(key, 4)

lon0 = random.uniform(keys[0], shape=(n,), minval=-175., maxval=175.)
lat0 = random.uniform(keys[1], shape=(n,), minval=-85., maxval=85.)
lon1 = random.uniform(keys[2], shape=(n,), minval=-175., maxval=175.)
lat1 = random.uniform(keys[3], shape=(n,), minval=-85., maxval=85.)

coords = jnp.stack([lon0, lat0, lon1, lat1], axis=1)  # Shape (n, 4)

print(f"coords.shape: {coords.shape}")
coords.shape: (10000, 4)
Then applying vectorization to get_haversine is as simple as wrapping the original function with vmap:
from jax import vmap

# Vectorize get_haversine.
get_haversine_vmap = vmap(get_haversine)

# Calculate distances between 10k coordinate pairs.
d = get_haversine_vmap(coords)

d[:10]
Array([ 8697.932 , 6449.9453, 8237.629 , 7416.7593, 9463.392 , 3435.566 , 8059.4575, 10055.319 , 16480.527 , 6943.8413], dtype=float32)
Not surprisingly, vectorization provides a massive speedup vs. native looping:
%timeit -n1 for c in coords: get_haversine(c)

%timeit -n1 get_haversine_vmap(coords)
2.42 s ± 20.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.45 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
It is also possible to combine just-in-time compilation along with vectorized transformations for additional performance gains.
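For instance, a minimal sketch of composing the two transformations (get_haversine_jit_vmap is an illustrative name, not defined elsewhere in this post):

from jax import jit, vmap

# Vectorize first, then jit-compile the vectorized function.
get_haversine_jit_vmap = jit(vmap(get_haversine))

# First call compiles; subsequent calls reuse the compiled code.
_ = get_haversine_jit_vmap(coords).block_until_ready()

d = get_haversine_jit_vmap(coords)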
Refer to the JAX documentation on automatic vectorization for more advanced use cases of vmap, specifically how to apply a vectorized transformation along a specific axis of a multi-dimensional array.
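As a small, hypothetical sketch of the in_axes argument, here we map get_haversine over the columns of a transposed 4 x n coordinate array instead of its rows:

# coords has shape (n, 4); its transpose has shape (4, n).
coords_T = coords.T

# in_axes=1 tells vmap to map over axis 1, so each call
# receives a single column of shape (4,).
get_haversine_cols = vmap(get_haversine, in_axes=1)

d_cols = get_haversine_cols(coords_T)   # shape (n,)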
Automatic Differentiation
Automatic differentiation (autodiff) in JAX is a method for computing derivatives efficiently and accurately. Unlike numerical differentiation, which relies on finite differences and can be prone to error, or symbolic differentiation, which can become computationally expensive, autodiff works by systematically applying the chain rule to the sequence of elementary operations that make up a computation.
JAX provides grad for computing gradients of scalar-valued functions, jacfwd and jacrev for Jacobians, and hessian for second-order derivatives. Both modes of autodiff are supported: forward-mode is efficient for functions with a small number of inputs, while reverse-mode is well-suited for functions with many inputs but a single scalar output, which is ideal for deep learning applications.
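A brief sketch of these transformations applied to an arbitrary scalar-valued function f (made up purely for illustration):

import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, hessian

def f(x):
    # Simple scalar-valued function of a vector input.
    return jnp.sum(x**2) + jnp.prod(x)

x = jnp.array([1.0, 2.0, 3.0])

print(grad(f)(x))      # gradient, shape (3,)
print(jacrev(f)(x))    # reverse-mode Jacobian (same as the gradient here)
print(jacfwd(f)(x))    # forward-mode Jacobian
print(hessian(f)(x))   # second-order derivatives, shape (3, 3)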
As a simple example of using grad for a scalar-valued function, given a continuous random variable $X$ with CDF $F(x)$, the PDF $f(x)$ is obtained by differentiating $F$:

$$f(x) = \frac{d}{dx} F(x).$$

For the exponential distribution with rate parameter $\lambda$, the CDF and PDF are given by

$$F(x) = 1 - e^{-\lambda x}, \qquad f(x) = \lambda e^{-\lambda x}, \qquad x \geq 0.$$
The exact value of the exponential PDF can be compared with the result returned by grad applied to the CDF to verify they are the same. The result will also be compared against the PDF at a given value of $x$ returned by scipy.stats.expon. For the purposes of demonstration, we set $\lambda = 0.10$, which is hard-coded within expon_cdf and expon_pdf:
from scipy.stats import expon


def expon_cdf(x):
    """
    Exponential distribution CDF.
    """
    return 1 - jnp.exp(-.10 * x)


def expon_pdf(x):
    """
    Exponential distribution PDF.
    """
    return .10 * jnp.exp(-.10 * x)


# Exponential distribution with rate 0.10 (mean 10).
r = expon(scale=10)
In order to obtain the derivative of expon_cdf using JAX, pass expon_cdf into grad. The result is a callable that can accept any scalar value on $[0, \infty)$ and will return the exponential PDF at that point:
from jax import grad

# Compute derivative of exponential CDF.
jax_expon_pdf = grad(expon_cdf)
Comparing the analytical PDF, jax_expon_pdf, and the Scipy-generated PDF evaluated at 4.5:
v0 = expon_pdf(4.5)
v1 = jax_expon_pdf(4.5)
v2 = r.pdf(4.5)

print(f"Exact PDF : {v0:.6f}")
print(f"JAX PDF   : {v1:.6f}")
print(f"Scipy PDF : {v2:.6f}")
Exact PDF : 0.063763
JAX PDF   : 0.063763
Scipy PDF : 0.063763
A particularly useful feature of grad is that we can pass jax_expon_pdf into grad and obtain the second derivative of the exponential CDF. Again we compare the JAX result against the exact analytical solution:
# Analytical solution for comparison.
def expon_cdf_second_deriv(x):
    """
    Second derivative of exponential CDF.
    """
    return -(.10**2) * jnp.exp(-.10 * x)


# Compute second derivative of exponential CDF using grad.
jax_expon_cdf_second_deriv = grad(jax_expon_pdf)

v0 = expon_cdf_second_deriv(4.5)
v1 = jax_expon_cdf_second_deriv(4.5)

print(f"Exact second derivative : {v0:.8f}")
print(f"JAX second derivative   : {v1:.8f}")
Exact second derivative : -0.00637628
JAX second derivative   : -0.00637628
It is not an exaggeration to say vmap, jit and grad have transformed my machine learning workflows. Vectorization without loops, lightning-fast compilation and flexible gradient computation let me build cleaner, faster models with less code. I'm finding new applications all the time and will continue to explore additional ways to leverage JAX and related libraries like flax.