
Up and Running with JAX – JIT Compilation, Vectorizing Transformations and autodiff


I first learned about JAX a few years back in an article on functorch, a library that brought JAX-like composable function transformations to PyTorch. Initially developed as a separate project, functorch has since been fully integrated into PyTorch’s core (as of PyTorch 2.0).

I’ve recently invested time in learning JAX, which has proven incredibly worthwhile. The clean functional approach makes my code more maintainable and reproducible, while delivering significant performance and efficiency improvements with surprisingly minimal changes to existing codebases.

So, what is JAX? It is a high-performance numerical computing library developed by Google Research. It combines the ease of use of Python and Numpy with the speed and efficiency of XLA (Accelerated Linear Algebra), making it particularly well-suited for machine learning research and numerical computing that requires high performance.

At its core, JAX extends Numpy’s functionality with automatic differentiation capabilities. This is essential for gradient-based optimization in machine learning. JAX also excels at just-in-time compilation, which translates Python functions into optimized machine code at runtime.

JAX takes a functional programming approach (no side-effects), emphasizing immutability and pure functions. Operations don’t modify their inputs but instead return new values. This is particularly valuable for numerical computations since it enables better parallelization and optimization. Rather than changing arrays in place, JAX functions create new arrays with updated values, resulting in code that is more composable and reproducible. As highlighted in JAX: The Sharp Bits:

JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
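
As a quick illustration of this immutability (a minimal sketch): in-place assignment is rejected, and the functional .at syntax returns an updated copy instead.

import jax.numpy as jnp

x = jnp.arange(3.)
# x[0] = 10.           # TypeError: JAX arrays are immutable
y = x.at[0].set(10.)   # returns a new array; x is unchanged
print(x)   # [0. 1. 2.]
print(y)   # [10.  1.  2.]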

Of the JAX features I researched, vmap is the most readily applicable to the work I do. vmap is a vectorizing transformation that automatically adds a batch dimension to calculations. It stands for “vectorized map” and lets you run a function across multiple inputs in a vectorized fashion without explicitly writing code for batch processing. This enables writing simple, single-example functions while taking advantage of the performance benefits of vectorized execution.
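
As a tiny taste of the idea, using jnp.dot (introduced below): a function written for one pair of vectors maps cleanly over a batch of pairs.

import jax.numpy as jnp
from jax import vmap

a = jnp.ones((10, 3))   # batch of 10 vectors
b = jnp.ones((10, 3))

# jnp.dot handles a single pair; vmap maps it over the leading batch axis.
vmap(jnp.dot)(a, b)     # shape (10,): one dot product per row pair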

The JAX numpy submodule can often be used as a drop-in replacement for Numpy since the API is almost identical: functions like jnp.array, jnp.sin, jnp.dot, jnp.mean, and many others work just as they do in standard Numpy. The key difference is that JAX arrays are immutable and are optimized for GPU execution.
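
For instance (a minimal sketch), familiar Numpy expressions carry over directly:

import numpy as np
import jax.numpy as jnp

x = np.linspace(0., np.pi, 5)
print(np.sin(x).mean())                 # standard Numpy
print(jnp.mean(jnp.sin(jnp.array(x))))  # same API; returns an immutable JAX Array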

This post highlights three features of JAX: Just-in-time compilation, vectorizing transformations and automatic differentiation. In a future post, I’ll walk through the forward and backward pass for a fully-connected neural network implemented entirely in JAX. Those future posts will make heavy use of the content covered here.

Just-in-time Compilation

JIT compilation in JAX speeds up computations by transforming Python functions into optimized machine code. When you apply jax.jit or the jit decorator to a function, JAX traces its operations and compiles them into an efficient, reusable representation. This means that instead of executing Python loops and function calls directly, JAX compiles them into a single, optimized computation graph that runs much faster on virtually any hardware.

The first time a JIT-compiled function is called, there’s a slight overhead as JAX compiles it, but subsequent calls run much faster since the compiled version is reused. JIT works best when inputs have a fixed shape and type, since changing them can trigger a recompilation. It can also be used in conjunction with grad, vmap, and pmap for even greater performance gains.
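
One way to observe the compile-then-cache behavior (a small sketch; the print is a side effect that fires only while JAX traces the function):

import jax.numpy as jnp
from jax import jit

@jit
def add_one(x):
    print("tracing...")   # executes during tracing only, not in compiled runs
    return x + 1

add_one(jnp.arange(3))   # prints "tracing...", compiles, runs
add_one(jnp.arange(3))   # cached compilation reused: no print
add_one(jnp.arange(4))   # new input shape triggers a retrace and recompile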

To demonstrate, we’ll implement a function that computes the great circle distance between two sets of coordinate pairs. The Haversine formula is defined as

$$
d = 2R \arcsin\left(\sqrt{\sin^2\left(\frac{\Delta\varphi}{2}\right) + \cos\varphi_0 \cos\varphi_1 \sin^2\left(\frac{\Delta\lambda}{2}\right)}\right)
$$

where:

  • $\Delta\varphi = \varphi_1 - \varphi_0$ (latitude difference in radians)
  • $\Delta\lambda = \lambda_1 - \lambda_0$ (longitude difference in radians)
  • $R \approx 6371$ km (globally averaged value of the radius of the Earth in kilometers)
  • $d$ is the great-circle distance

A simple implementation of the Haversine formula using trigonometric functions from jax.numpy is provided below. I originally attempted to use Python’s builtin trigonometric functions, but this caused JIT compilation to fail – the issue resolved itself when using the JAX-native variants.

get_haversine accepts an array of [lon0, lat0, lon1, lat1] and returns the great circle distance between (lon0, lat0) and (lon1, lat1):

import jax.numpy as jnp


def get_haversine(coords):
    """
    Calculate the great circle distance between two points on the earth 
    (specified in decimal degrees).

    Parameters
    ----------
    coords: array-like
        Array containing the longitude and latitude of two points
        arranged as [lon0, lat0, lon1, lat1].
   
    Returns
    -------
    Great circle distance in km.
    """

    # Average Earth radius in km.
    R = 6371.0

    lon0, lat0, lon1, lat1 = coords

    # Convert degree latitudes and longitudes to radians.
    rlon0 = jnp.radians(lon0)
    rlat0 = jnp.radians(lat0)
    rlon1 = jnp.radians(lon1)
    rlat1 = jnp.radians(lat1)
    dlon, dlat = rlon1 - rlon0, rlat1 - rlat0
    a = jnp.sin(dlat / 2)**2 + jnp.cos(rlat0) * jnp.cos(rlat1) * jnp.sin(dlon / 2)**2
    c = 2 * jnp.asin(jnp.sqrt(a))
    return R * c

To calculate the Haversine distance in kilometers between two points, say Durkin Park on the Southside of Chicago and Nectar’s in Burlington, Vermont, simply run:

lon0, lat0 = -87.7295, 41.7390 # Durkin Park, Chicago, IL
lon1, lat1 = -73.2117, 44.4762 # Nectar's, Burlington, VT

# Put coordinates in JAX array.
coords = jnp.array([lon0, lat0, lon1, lat1])

d = get_haversine(coords)

print(f"Distance between Durkin Park and Nectar's: {d:,.0f} km")

Distance between Durkin Park and Nectar's: 1,215 km

We can JIT-compile get_haversine and compare the run-time against the original implementation. Notice that we call get_haversine_jit once outside of timeit to avoid the overhead associated with the initial compilation:

from jax import jit

# Create jit-compiled version of get_haversine. 
get_haversine_jit = jit(get_haversine)
_ = get_haversine_jit(jnp.array([-80., 40., -85., 45.]))  # compiles on first call.

# Time the original.
%timeit -n100 get_haversine(coords).block_until_ready()

# Time the jit-compiled function.
%timeit -n100 get_haversine_jit(coords).block_until_ready()

265 μs ± 38.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.46 μs ± 1.03 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

We realized about a 75x speedup for very little work on our end. For this example, we created a new function get_haversine_jit so the jit-compiled runtime could be compared against the original non-JITed version. It is possible to instead use the @jit decorator, allowing the original function name to be re-used:

# Using jit decorator instead.

@jit
def get_haversine(coords):
    """
    Calculate the great circle distance between two points on the earth 
    (specified in decimal degrees).

    Parameters
    ----------
    coords: array-like
        Array containing the longitude and latitude of two points
        arranged as [lon0, lat0, lon1, lat1]
   
    Returns
    -------
    Great circle distance in km.
    """
    R = 6371.0

    # Convert degree latitudes and longitudes to radians.
    lon0, lat0, lon1, lat1 = coords
    rlon0 = jnp.radians(lon0)
    rlat0 = jnp.radians(lat0)
    rlon1 = jnp.radians(lon1)
    rlat1 = jnp.radians(lat1)
    dlon, dlat = rlon1 - rlon0, rlat1 - rlat0
    a = jnp.sin(dlat / 2)**2 + jnp.cos(rlat0) * jnp.cos(rlat1) * jnp.sin(dlon / 2)**2
    c = 2 * jnp.asin(jnp.sqrt(a))
    return R * c

There are some limitations to JIT-compilation in JAX. In particular, loops, if statements, and other control flow mechanisms may not work as expected. Refer to JAX: The Sharp Bits for additional gotchas.
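
As a minimal sketch of the most common gotcha (my own example, not from the post): a Python if on a traced value fails under jit, while jnp.where expresses the same branch in a traceable way.

import jax.numpy as jnp
from jax import jit

@jit
def absolute_bad(x):
    # Fails at trace time: the truth value of a traced x is unknown.
    if x < 0:
        return -x
    return x

@jit
def absolute_good(x):
    # Traceable alternative: both branches are evaluated, the result selected at runtime.
    return jnp.where(x < 0, -x, x)

# absolute_bad(-3.0)   # raises a trace-time concretization error
absolute_good(-3.0)    # Array(3., dtype=float32)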

Vectorizing Transformations

vmap automatically vectorizes operations, enabling the application of a function across multiple inputs without writing explicit loops. It enables batch processing while taking advantage of JAX’s optimized execution. Instead of using loops, vmap efficiently maps a function over an array along a pre-specified axis.

To demonstrate, I’ll apply vmap to get_haversine, allowing it to accept coordinate arrays of shape n x 4 as opposed to 1 x 4. We will generate a random 10,000 x 4 coordinate array using JAX’s random generator utilities.

In JAX, random number generation is handled a bit differently than in Numpy to ensure functional purity. JAX uses explicit PRNG keys to generate random numbers instead of relying on global state. A “key” is a special array that acts as a seed, and every time you use it, JAX produces the same random numbers for the same key.

Since JAX enforces immutability, you can’t reuse a key for multiple random calls without getting the same result. Instead, you split a key using jax.random.split, which deterministically generates new, unique keys from the original one. Each split key is independent, allowing for the generation of different random numbers while maintaining reproducibility. This approach makes JAX’s random functions fully compatible with its JIT compilation and parallelization features.
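
A small sketch of the key discipline: reusing a key reproduces the same draws, while split yields fresh, independent keys.

import jax.random as random

key = random.PRNGKey(0)
a = random.normal(key, (2,))
b = random.normal(key, (2,))   # identical to a: same key, same draws

k1, k2 = random.split(key)     # two new, independent keys derived from key
c = random.normal(k1, (2,))    # different draws than a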

In the next cell, we create a 10,000 x 4 array of random coordinate pairs. We are interested in computing the Haversine distance for each pair of coordinates, but don’t want to rewrite get_haversine to process more than a single pair of points at a time.

import jax.random as random

# Create a 10,000 x 4 array of random longitudes and latitudes.
# Longitudes are in the range -175 to 175.
# Latitudes are in the range -85 to 85.
n = 10_000

# Seed for reproducibility. Split key for different random sequences.
key = random.PRNGKey(516)  
keys = random.split(key, 4)

lon0 = random.uniform(keys[0], shape=(n,), minval=-175., maxval=175.)
lat0 = random.uniform(keys[1], shape=(n,), minval=-85., maxval=85.)
lon1 = random.uniform(keys[2], shape=(n,), minval=-175., maxval=175.)
lat1 = random.uniform(keys[3], shape=(n,), minval=-85., maxval=85.)
coords = jnp.stack([lon0, lat0, lon1, lat1], axis=1)  # Shape (n, 4)

print(f"coords.shape: {coords.shape}")

coords.shape: (10000, 4)

Then applying vectorization to get_haversine is as simple as wrapping the original function with vmap:

from jax import vmap

# Vectorize get_haversine.
get_haversine_vmap = vmap(get_haversine)

# Calculate distances between 10k coordinate pairs.
d = get_haversine_vmap(coords)

d[:10]
Array([ 8697.932 ,  6449.9453,  8237.629 ,  7416.7593,  9463.392 ,
        3435.566 ,  8059.4575, 10055.319 , 16480.527 ,  6943.8413],      dtype=float32)

Not surprisingly, vectorization provides a massive speedup vs. native looping:

%timeit -n1 for c in coords: get_haversine(c)

%timeit -n1 get_haversine_vmap(coords)

2.42 s ± 20.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.45 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It is also possible to combine just-in-time compilation along with vectorized transformations for additional performance gains.
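
For instance, a minimal sketch reusing get_haversine and coords from above:

from jax import jit, vmap

# Compile the batched computation end-to-end.
get_haversine_fast = jit(vmap(get_haversine))

d = get_haversine_fast(coords)   # first call compiles; later calls reuse the cached kernel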

Refer to the JAX documentation on automatic vectorization for more advanced use cases of vmap, specifically how to apply a vectorized transformation along a specific axis of a multi-dimensional array.
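
As a brief sketch of that mechanism using in_axes (the weighted_sum function below is invented for illustration): None marks an argument shared across the batch, while an integer selects which axis to map over.

import jax.numpy as jnp
from jax import vmap

def weighted_sum(w, x):
    # Written for a single weight vector and a single data vector.
    return jnp.dot(w, x)

w = jnp.array([0.2, 0.3, 0.5])
X = jnp.arange(12.).reshape(4, 3)   # batch of 4 data vectors

# Share w across the batch; map over rows (axis 0) of X.
batched = vmap(weighted_sum, in_axes=(None, 0))
batched(w, X)   # shape (4,)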

Automatic Differentiation

Automatic differentiation (autodiff) in JAX is a method for computing derivatives efficiently and accurately. Unlike numerical differentiation, which relies on finite differences and can be prone to error, or symbolic differentiation, which can become computationally expensive, autodiff works by systematically applying the chain rule to the elementary operations that make up a computation.

JAX provides grad for computing gradients of scalar-valued functions, jacfwd and jacrev for Jacobians, and hessian for second-order derivatives. It uses forward-mode autodiff for computing derivatives of functions with a small number of inputs, while reverse-mode autodiff is well-suited for functions with many inputs but a single output, which is ideal for deep learning applications.
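
A quick sketch of these helpers on a toy cubic function (my own example):

import jax.numpy as jnp
from jax import jacfwd, jacrev, hessian

def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1., 2., 3.])

jacfwd(f)(x)    # forward mode: [3., 12., 27.] (i.e., 3 * x**2)
jacrev(f)(x)    # reverse mode: same values, different computation strategy
hessian(f)(x)   # 3 x 3 matrix with 6 * x on the diagonal, zeros elsewhere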

As a simple example of using grad for a scalar-valued function: given a continuous random variable $X$ with CDF $F(x)$, the PDF $f(x)$ is obtained by differentiating $F$:

$$
f(x) = \frac{d}{dx} F(x)
$$

For the exponential distribution, the CDF and PDF are given by

$$
F(x) = 1 - e^{-\lambda x}, \qquad f(x) = \lambda e^{-\lambda x}, \qquad x \geq 0, \; \lambda > 0.
$$
The exact value of the exponential PDF can be compared with the result returned by grad applied to the CDF to verify they are the same. The result will also be compared against the PDF at a given value of $x$ returned by scipy.stats.expon. For the purposes of demonstration, we set $\lambda = 0.10$, which is hard-coded within expon_cdf and expon_pdf:

from scipy.stats import expon


def expon_cdf(x):
    """
    Exponential distribution CDF.
    """
    return 1 - jnp.exp(-.10 * x)


def expon_pdf(x):
    """
    Exponential distribution PDF.
    """

    return .10 * jnp.exp(-.10 * x)

# Exponential distribution with mean 10. 
r = expon(scale=10.)

In order to obtain the derivative of expon_cdf using JAX, pass expon_cdf into grad. The result is a callable that accepts any scalar value on $[0, \infty)$ and returns the exponential PDF at that point:

from jax import grad

# Compute derivative of exponential CDF.
jax_expon_pdf = grad(expon_cdf)

Comparing the analytical PDF, jax_expon_pdf, and the Scipy-generated PDF evaluated at 4.5:

v0 = expon_pdf(4.5)
v1 = jax_expon_pdf(4.5)
v2 = r.pdf(4.5)

print(f"Exact PDF : {v0:.8f}")
print(f"JAX PDF   : {v1:.8f}")
print(f"Scipy PDF : {v2:.8f}")
Exact PDF : 0.063763
JAX PDF   : 0.063763
Scipy PDF : 0.063763

A particularly useful feature of grad is that we can pass jax_expon_pdf into grad and obtain the second derivative of the exponential CDF. Again we compare the JAX result against the exact analytical solution:

# Analytical solution for comparison.
def expon_cdf_second_deriv(x):
    """
    Second derivative of exponential CDF.
    """
    return -(.10**2) * jnp.exp(-.10 * x)


# Compute second derivative of exponential CDF using grad.
jax_expon_cdf_second_deriv = grad(jax_expon_pdf)

v0 = expon_cdf_second_deriv(4.5)
v1 = jax_expon_cdf_second_deriv(4.5)

print(f"Exact second derivative : {v0:.8f}")
print(f"JAX second derivative   : {v1:.8f}")
Exact second derivative : -0.00637628
JAX second derivative   : -0.00637628

It is not an exaggeration to say vmap, jit, and grad have transformed my machine learning workflows. Vectorization without loops, lightning-fast compilation, and flexible gradient computation let me build cleaner, faster models with less code. I’m finding new applications all the time and will continue to explore additional ways to leverage JAX and related libraries like Flax.
