Up and Running with JAX – Backpropagation and Training Neural Networks
Want to share your content on python-bloggers? click here.
In the third and final installment of the Up and Running with JAX series, we demonstrate the remaining steps required to train and evaluate a simple neural network, specifically the implementation of the loss function, backward pass and training loop. As in Part 2, the focus will be on predicting class labels for the MNIST dataset, which consists of 28×28 pixel images of handwritten digits (0-9). The training loop consists of the following steps:
- Load a batch of training data.
- Obtain model predictions for current batch of images.
- Calculate the loss for current batch predictions vs. targets.
- Calculate backward gradients over the weights and biases.
- Update the weights and biases using the gradient information.
- Evaluate performance (here, accuracy) on a set of data that was not used for training.
We begin by loading the dataset and functions implemented in Part 2 that facilitate weight initialization and the network forward pass:
import warnings
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.transforms import v2
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True, precision=5, linewidth=1000)
pd.options.mode.chained_assignment = None
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option("display.precision", 5)
warnings.filterwarnings("ignore")
# Batch size.
bs = 64
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=v2.Compose([ToTensor()])
)
valid_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=v2.Compose([ToTensor()])
)
# Convert the raw uint8 image tensors to NumPy arrays.
train_data_arr = train_data.data.numpy() / 255.0 # Normalize pixel values to [0, 1]
valid_data_arr = valid_data.data.numpy() / 255.0 # Normalize pixel values to [0, 1]
train_data_arr = train_data_arr.reshape(-1, 28 * 28) # Flatten images to 1D arrays
valid_data_arr = valid_data_arr.reshape(-1, 28 * 28) # Flatten images to 1D arrays
train_labels = train_data.targets.numpy()
valid_labels = valid_data.targets.numpy()
# Create training and validation batches of 64.
train_batches = [
    (train_data_arr[(bs * ii):(bs * (ii + 1))], train_labels[(bs * ii):(bs * (ii + 1))])
    for ii in range(len(train_data_arr) // bs)
]
valid_batches = [
    (valid_data_arr[(bs * ii):(bs * (ii + 1))], valid_labels[(bs * ii):(bs * (ii + 1))])
    for ii in range(len(valid_data_arr) // bs)
]
print(f"train_data_arr.shape: {train_data_arr.shape}")
print(f"valid_data_arr.shape: {valid_data_arr.shape}")
print(f"train_labels.shape : {train_labels.shape}")
print(f"valid_labels.shape : {valid_labels.shape}")
print(f"len(train_batches) : {len(train_batches)}")
print(f"len(valid_batches) : {len(valid_batches)}")
train_data_arr.shape: (60000, 784)
valid_data_arr.shape: (10000, 784)
train_labels.shape : (60000,)
valid_labels.shape : (10000,)
len(train_batches) : 937
len(valid_batches) : 156
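The list comprehensions above keep only full batches, so 60000 - 937 × 64 = 32 training images and 10000 - 156 × 64 = 16 validation images are dropped, and the batches arrive in the same order every epoch. Below is a minimal sketch of a hypothetical `make_batches` helper (not part of the original post) that slices the same way but can optionally shuffle rows first using NumPy's `default_rng`. Toy arrays stand in for the MNIST data so the snippet runs on its own.

```python
import numpy as np

def make_batches(X, y, bs=64, rng=None):
    """Hypothetical helper: split (X, y) into full batches of size bs.

    If a NumPy Generator is passed as rng, rows are shuffled first.
    The trailing partial batch (len(X) % bs rows) is dropped, matching
    the list comprehensions used for train_batches and valid_batches.
    """
    idx = np.arange(len(X))
    if rng is not None:
        idx = rng.permutation(len(X))
    n_full = len(X) // bs
    return [
        (X[idx[i * bs:(i + 1) * bs]], y[idx[i * bs:(i + 1) * bs]])
        for i in range(n_full)
    ]

# Toy data standing in for the flattened MNIST arrays: 1000 samples, 784 features.
X_toy = np.random.default_rng(0).random((1000, 784))
y_toy = np.arange(1000) % 10

batches = make_batches(X_toy, y_toy, bs=64, rng=np.random.default_rng(516))
print(len(batches))              # 15 full batches
print(1000 - len(batches) * 64)  # 40 samples dropped
```

Calling it once per epoch with a fresh generator state would give a different batch order each pass, which is a common choice for SGD; the post itself trains on the fixed order.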
"""
Functions introduced in Part 2. Refer to
https://www.jtrive.com/posts/intro-to-jax-part-2/intro-to-jax-part-2.html
for more information.
"""
from jax import random, vmap
import jax.numpy as jnp
from jax.nn import relu
def initialize_weights(sizes, key, scale=.02):
    """
    Initialize weights and biases for each layer of a simple fully-connected
    network.
    Parameters
    ----------
    sizes : list of int
        List of integers representing the number of neurons in each layer.
    key : jax.random.PRNGKey
        Random key for JAX.
    scale : float
        Multiplier applied to standard normal draws (the standard deviation
        of the initial weights and biases).
    Returns
    -------
    List of initialized weights and biases for each layer.
    """
    keys = random.split(key, len(sizes) - 1)
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        w = scale * random.normal(w_key, (m, n))
        b = scale * random.normal(b_key, (n,))
        params.append((w, b))
    return params
def forward(params, X):
    """
    Forward pass for simple fully-connected network.
    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.
    X : jax.numpy.ndarray
        Input data.
    Returns
    -------
    jax.numpy.ndarray
    """
    a = X
    for W, b in params[:-1]:
        z = jnp.dot(a, W) + b
        a = relu(z)
    W, b = params[-1]
    return jnp.dot(a, W) + b
# Auto-vectorization of forward pass.
batch_forward = vmap(forward, in_axes=(None, 0))
Cross-Entropy Loss and Softmax
Categorical cross-entropy loss is the most commonly used loss function for multi-class classification with mutually exclusive classes. A lower cross-entropy loss means the predicted probabilities are closer to the true labels. A key characteristic of cross-entropy loss is that it depends only on the probability assigned to the correct class: its value is independent of how the remaining probability is split among the incorrect classes.
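For a single sample with $C$ classes, the cross-entropy loss is $-\sum_{c=1}^{C} y_c \log(\hat{p}_c)$; averaged over a batch of $N$ samples, it is given by

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c}\,\log(\hat{p}_{i,c})$$

where:
- $N$ is the batch size.
- $y_{i,c}$ is the true label (1 for the correct class, 0 otherwise).
- $\hat{p}_{i,c}$ is the predicted probability for class $c$ (from softmax).
- The $\log$ ensures the loss is large when the predicted probability for the correct class is low.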
If instead we have a single vector of labels holding the index of the correct class for each sample, we can simply take the negative log of the predicted probability at that index to get the cross-entropy loss for that sample, since every other term in the sum is multiplied by zero.
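To make the equivalence concrete, here is a small worked sketch on a made-up batch of 3 samples and 4 classes (the probabilities are invented for illustration). It computes the batch loss once with one-hot targets, exactly as in the formula, and once by indexing the log-probability of the correct class; `jax.nn.one_hot` is used only to build the one-hot targets.

```python
import jax.numpy as jnp
from jax.nn import one_hot

# Toy batch: 3 samples, 4 classes. Each row of probs sums to 1.
probs = jnp.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.25, 0.25, 0.25, 0.25],
])
labels = jnp.array([0, 1, 3])  # index of the correct class per sample

# Formula version: one-hot targets, sum over classes, average over the batch.
y_onehot = one_hot(labels, 4)
loss_onehot = -(y_onehot * jnp.log(probs)).sum(axis=1).mean()

# Index version: keep only the log-probability of the correct class.
loss_index = -jnp.log(probs[jnp.arange(len(labels)), labels]).mean()

print(loss_onehot, loss_index)  # both ≈ 0.812 = (-ln 0.7 - ln 0.5 - ln 0.25) / 3
```

Both expressions give the same number, but the index version never materializes the mostly-zero one-hot matrix.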
We forgo one-hot encoding our targets, so our loss function accepts the model parameters, a batch of input images and the corresponding targets (labels), each represented as a single integer between 0 and 9; the final layer activations (logits) are computed inside the function. Using a batch size of 64, logits has shape (64, 10) and labels has shape (64,):
from jax.nn import log_softmax
def cross_entropy_loss(params, X, y):
    """
    Compute the mean cross-entropy loss for a batch.
    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.
    X : jax.numpy.ndarray
        Batch of input images, shape (batch_size, 784).
    y : jax.numpy.ndarray
        Batch of true labels, a single integer per sample.
    Returns
    -------
    Computed loss.
    """
    # Compute logits for the batch.
    logits = forward(params, X)
    # Convert logits to log probabilities.
    log_probs = log_softmax(logits)
    # Pick the log probability of the correct class for each sample, then average.
    return -log_probs[jnp.arange(len(y)), y].mean()
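A useful sanity check before training: with small initial weights the logits are close to zero, softmax is close to uniform over the 10 classes, and the loss should therefore be close to $\ln 10 \approx 2.303$. The sketch below repeats compact copies of the Part 2 helpers so it runs on its own, and uses random pixels and labels as a stand-in for a real MNIST batch (an assumption for illustration only). It also confirms that `forward`, which relies on `jnp.dot`, already handles a 2-D batch and agrees with the `vmap`-ed `batch_forward`.

```python
import jax.numpy as jnp
from jax import random, vmap
from jax.nn import relu, log_softmax

# Compact copies of the Part 2 helpers so this snippet is self-contained.
def initialize_weights(sizes, key, scale=.02):
    keys = random.split(key, len(sizes) - 1)
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        params.append((scale * random.normal(w_key, (m, n)),
                       scale * random.normal(b_key, (n,))))
    return params

def forward(params, X):
    a = X
    for W, b in params[:-1]:
        a = relu(jnp.dot(a, W) + b)
    W, b = params[-1]
    return jnp.dot(a, W) + b

def cross_entropy_loss(params, X, y):
    log_probs = log_softmax(forward(params, X))
    return -log_probs[jnp.arange(len(y)), y].mean()

params = initialize_weights([784, 128, 10], key=random.PRNGKey(516))

# Random stand-in for one batch: 64 "images" with pixels in [0, 1), labels 0-9.
X = random.uniform(random.PRNGKey(0), (64, 784))
y = random.randint(random.PRNGKey(1), (64,), 0, 10)

# Near-uniform predictions give a loss close to ln(10).
loss = cross_entropy_loss(params, X, y)
print(loss, jnp.log(10.0))

# forward already works on a whole batch; vmap over single samples agrees.
batch_forward = vmap(forward, in_axes=(None, 0))
print(jnp.allclose(forward(params, X), batch_forward(params, X), atol=1e-5))
```

If the initial loss were far above $\ln 10$, that would usually point to an initialization scale that is too large or to a bug in the loss.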
The softmax function converts a vector of logits into a probability distribution over classes and is commonly used in classification tasks. Logits are the raw, unnormalized values produced by the last layer of a neural network before any activation function is applied. In some deep learning frameworks, cross-entropy loss is combined with softmax within a single function (see, for example, CrossEntropyLoss in PyTorch). For a vector $z$ of length $C$, softmax is defined as:

$$\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}$$

where:
- $e^{z_i}$ exponentiates each element.
- The denominator ensures all probabilities sum to 1.
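A direct translation of the definition, as a minimal sketch: `naive_softmax` is a hypothetical name for illustration, and the result is compared against `jax.nn.softmax` on an arbitrary small logit vector.

```python
import jax.numpy as jnp
from jax.nn import softmax

def naive_softmax(z):
    """Direct translation of the definition: exponentiate, then normalize."""
    exp_z = jnp.exp(z)
    return exp_z / exp_z.sum(axis=-1, keepdims=True)

logits = jnp.array([2.0, 1.0, 0.1])
p = naive_softmax(logits)
print(p)                                 # ≈ [0.659, 0.242, 0.099]
print(p.sum())                           # ≈ 1.0 (up to float32 rounding)
print(jnp.allclose(p, softmax(logits)))  # True
```

The largest logit receives the largest probability, and the ordering of the logits is preserved.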
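If any of the $z_i$ is large, $e^{z_i}$ can become extremely large and overflow. For example:
import numpy as np
z = np.array([1000, 2000, 3000])  # Large values
softmax = np.exp(z) / np.sum(np.exp(z))  # RuntimeWarning (overflow); softmax is [nan nan nan]
Since $e^{3000}$ far exceeds the largest representable 64-bit float (about $1.8 \times 10^{308}$), np.exp returns inf, and inf / inf evaluates to nan. The solution is to subtract the maximum value of each sample before exponentiating, instead of using $z_i$ directly:

$$\sigma(z)_i = \frac{e^{z_i - \max_k z_k}}{\sum_{j=1}^{C} e^{z_j - \max_k z_k}}$$

This shifts all values down so that the largest exponent becomes $e^0 = 1$, without affecting the final probabilities: the factor $e^{-\max_k z_k}$ appears in both the numerator and the denominator and cancels.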
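A minimal NumPy sketch of the max-shift trick (the function name `stable_softmax` is ours, not from a library). It handles the overflowing example from above and matches the unshifted formula on moderate inputs. The last lines check that `jax.nn.log_softmax`, which `cross_entropy_loss` relies on, also stays finite on the same input; as far as we know it applies the same max shift internally.

```python
import numpy as np
import jax.numpy as jnp
from jax.nn import log_softmax

def stable_softmax(z):
    """Softmax with the max-subtraction trick, applied along the last axis.

    Subtracting max(z) multiplies numerator and denominator by the same
    factor exp(-max(z)), so the probabilities are unchanged, but the
    largest exponent becomes exp(0) = 1 and cannot overflow.
    """
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=-1, keepdims=True)
    exp_z = np.exp(shifted)
    return exp_z / exp_z.sum(axis=-1, keepdims=True)

z = np.array([1000, 2000, 3000])
probs = stable_softmax(z)
print(probs)  # [0. 0. 1.]: no overflow, no nan

# For moderate inputs the shifted and unshifted versions agree.
small = np.array([1.0, 2.0, 3.0])
print(np.allclose(stable_softmax(small), np.exp(small) / np.exp(small).sum()))  # True

# jax.nn.log_softmax also stays finite on the large input.
lsm = log_softmax(jnp.array([1000.0, 2000.0, 3000.0]))
print(lsm)  # finite: -2000, -1000, 0
```

Working in log space, as `log_softmax` does, has the added benefit of avoiding `log(0)` when a probability underflows to zero.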
Backward Pass
In order to obtain the gradients of the loss function w.r.t. the model parameters, JAX’s grad function can be used. grad computes the gradient of a scalar-valued function with respect to its inputs. It performs automatic differentiation by tracing the computation and building a backward pass to compute derivatives. grad accepts a Python function and returns a new function that computes the gradient of the original function. The returned function takes the same inputs as the original and returns the derivative w.r.t. the argument specified (the first argument by default).
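A small toy example of how `grad` (and its sibling `value_and_grad`) behaves, using made-up functions rather than the network. The second part differentiates with respect to a list of `(W, b)` tuples, the same pytree structure as `params`, to show that the gradients come back with exactly the same structure.

```python
import jax.numpy as jnp
from jax import grad, value_and_grad

def f(x):
    return x ** 2 + 3.0 * x

df = grad(f)      # new function computing df/dx = 2x + 3
print(df(2.0))    # 7.0

# Differentiating w.r.t. a pytree (a list of (W, b) tuples, like params):
# the gradients mirror the structure of the first argument.
def toy_loss(params, x):
    (W, b), = params
    return jnp.sum((x @ W + b) ** 2)

params = [(jnp.ones((2, 1)), jnp.zeros((1,)))]
x = jnp.array([[1.0, 2.0]])

loss, grads = value_and_grad(toy_loss)(params, x)  # w.r.t. params (argnums=0)
(dW, db), = grads
print(loss)     # 9.0, since x @ W + b = 3
print(dW, db)   # dW = [[6.], [12.]] (= 2 * 3 * x), db = [6.]
```

Because the gradients mirror the parameter structure, an update rule can unpack `params` and `grads` in the same way.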
Note that grad only returns the gradients of the loss function with respect to the parameters, not the loss value itself, which is useful to track during training. We can instead use value_and_grad, which returns the loss value along with the gradients as a tuple. In the next cell, update implements the gradient update. I’ve also included an accuracy function, which is used to evaluate model performance after each epoch:
from jax import value_and_grad
def update(params, X, y, lr=.01):
    """
    Update weights and biases using gradient descent.
    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.
    X : jax.numpy.ndarray
        Input data.
    y : jax.numpy.ndarray
        True labels.
    lr : float
        Learning rate.
    Returns
    -------
    tuple
        Updated list of (weights, biases) and the batch loss.
    """
    # Compute loss and gradients.
    loss, grads = value_and_grad(cross_entropy_loss)(params, X, y)
    # Unpack parameters and gradients.
    (W1, b1), (W2, b2) = params
    (dW1, db1), (dW2, db2) = grads
    # Update weights and biases.
    W1_new = W1 - lr * dW1
    b1_new = b1 - lr * db1
    W2_new = W2 - lr * dW2
    b2_new = b2 - lr * db2
    return [(W1_new, b1_new), (W2_new, b2_new)], loss
def accuracy(logits, labels):
    """
    Compute accuracy.
    Parameters
    ----------
    logits : jax.numpy.ndarray
        Final layer activations.
    labels : jax.numpy.ndarray
        True labels.
    Returns
    -------
    float
        Accuracy.
    """
    preds = jnp.argmax(logits, axis=1)
    return (preds == labels).mean()
We now have everything set up to train the network. The training loop is provided in the next cell, where the network is trained for 25 epochs:
from time import perf_counter
# Layer sizes.
sizes = [784, 128, 10]
# Number of epochs.
n_epochs = 25
# Learning rate.
lr = 0.01
# Store loss, accuracy and runtime.
results = []
# Initialize weights and biases.
params = initialize_weights(sizes, key=random.PRNGKey(516), scale=.02)
for epoch in range(n_epochs):
    start_time = perf_counter()
    losses = []
    for X, y in train_batches:
        # Forward pass, loss, backward pass and parameter update for one batch.
        params, loss = update(params, X, y, lr=lr)
        losses.append(loss.item())
    epoch_time = perf_counter() - start_time
    avg_loss = np.mean(losses)
    train_acc = np.mean([accuracy(forward(params, X), y).item() for X, y in train_batches])
    valid_acc = np.mean([accuracy(forward(params, X), y).item() for X, y in valid_batches])
    results.append((epoch + 1, avg_loss, train_acc, valid_acc, epoch_time))
    print(f"Epoch {epoch + 1}/{n_epochs}: loss: {avg_loss:.4f}, train acc.: {train_acc:.3f}, valid acc.: {valid_acc:.3f}, time: {epoch_time:.2f} sec.")
Epoch 1/25: loss: 1.5768, train acc.: 0.818, valid acc.: 0.824, time: 11.59 sec.
Epoch 2/25: loss: 0.5859, train acc.: 0.873, valid acc.: 0.881, time: 11.40 sec.
Epoch 3/25: loss: 0.4297, train acc.: 0.891, valid acc.: 0.896, time: 15.88 sec.
Epoch 4/25: loss: 0.3753, train acc.: 0.899, valid acc.: 0.904, time: 12.13 sec.
Epoch 5/25: loss: 0.3464, train acc.: 0.905, valid acc.: 0.909, time: 13.94 sec.
Epoch 6/25: loss: 0.3270, train acc.: 0.909, valid acc.: 0.913, time: 14.45 sec.
Epoch 7/25: loss: 0.3123, train acc.: 0.913, valid acc.: 0.918, time: 13.47 sec.
Epoch 8/25: loss: 0.2999, train acc.: 0.916, valid acc.: 0.920, time: 13.78 sec.
Epoch 9/25: loss: 0.2891, train acc.: 0.919, valid acc.: 0.922, time: 13.89 sec.
Epoch 10/25: loss: 0.2792, train acc.: 0.922, valid acc.: 0.925, time: 13.20 sec.
Epoch 11/25: loss: 0.2699, train acc.: 0.925, valid acc.: 0.927, time: 13.02 sec.
Epoch 12/25: loss: 0.2611, train acc.: 0.927, valid acc.: 0.930, time: 13.94 sec.
Epoch 13/25: loss: 0.2529, train acc.: 0.929, valid acc.: 0.931, time: 12.08 sec.
Epoch 14/25: loss: 0.2450, train acc.: 0.932, valid acc.: 0.932, time: 13.29 sec.
Epoch 15/25: loss: 0.2377, train acc.: 0.934, valid acc.: 0.934, time: 12.39 sec.
Epoch 16/25: loss: 0.2306, train acc.: 0.936, valid acc.: 0.936, time: 12.09 sec.
Epoch 17/25: loss: 0.2240, train acc.: 0.938, valid acc.: 0.937, time: 11.55 sec.
Epoch 18/25: loss: 0.2176, train acc.: 0.940, valid acc.: 0.939, time: 11.41 sec.
Epoch 19/25: loss: 0.2116, train acc.: 0.941, valid acc.: 0.940, time: 11.47 sec.
Epoch 20/25: loss: 0.2058, train acc.: 0.943, valid acc.: 0.941, time: 11.74 sec.
Epoch 21/25: loss: 0.2003, train acc.: 0.945, valid acc.: 0.942, time: 11.55 sec.
Epoch 22/25: loss: 0.1951, train acc.: 0.946, valid acc.: 0.943, time: 11.44 sec.
Epoch 23/25: loss: 0.1901, train acc.: 0.948, valid acc.: 0.945, time: 11.19 sec.
Epoch 24/25: loss: 0.1853, train acc.: 0.949, valid acc.: 0.946, time: 11.07 sec.
Epoch 25/25: loss: 0.1807, train acc.: 0.950, valid acc.: 0.947, time: 10.66 sec.
The loss decreases as a function of epoch, and the best validation accuracy was realized on the 25th epoch. We can visualize our training metrics:
v_color = "#8fcdb5"
t_color = "#031f14"
# Unpack results list.
_, loss, tacc, vacc, epoch_time = zip(*results)
xx = range(len(loss))
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3.5), sharex=True, tight_layout=True)
# Training loss.
ax[0].set_title("training loss", fontsize=8)
ax[0].plot(xx, loss, color="#303846")
ax[0].set_xlabel("epoch", fontsize=8)
ax[0].set_ylabel("loss", fontsize=8)
#ax[0].set_xticks(xx)
ax[0].tick_params(axis="x", which="major", direction='in', labelsize=6)
ax[0].tick_params(axis="y", which="major", direction='in', labelsize=6)
ax[0].xaxis.set_ticks_position("none")
ax[0].yaxis.set_ticks_position("none")
ax[0].grid(True)
ax[0].set_axisbelow(True)
# Training and validation accuracy.
ax[1].set_title("train + validation accuracy", fontsize=8)
ax[1].plot(xx, tacc, color="#303846", label="train acc.")
ax[1].plot(xx, vacc, color="#E02C70", label="valid acc.")
ax[1].set_xlabel("epoch", fontsize=8)
ax[1].set_ylabel("accuracy", fontsize=8)
#ax[1].set_xticks(xx)
ax[1].tick_params(axis="x", which="major", direction='in', labelsize=6)
ax[1].tick_params(axis="y", which="major", direction='in', labelsize=6)
ax[1].xaxis.set_ticks_position("none")
ax[1].yaxis.set_ticks_position("none")
ax[1].grid(True)
ax[1].set_axisbelow(True)
ax[1].legend(loc="lower right", fontsize=8, frameon=False)
# Runtime.
ax[2].set_title("epoch runtime", fontsize=8)
ax[2].bar(xx, epoch_time, color=t_color)
ax[2].set_xlabel("epoch", fontsize=8)
ax[2].set_ylabel("runtime", fontsize=8)
#ax[2].set_xticks(xx)
ax[2].tick_params(axis="x", which="major", direction='in', labelsize=6)
ax[2].tick_params(axis="y", which="major", direction='in', labelsize=6)
ax[2].xaxis.set_ticks_position("none")
ax[2].yaxis.set_ticks_position("none")
ax[2].grid(True)
ax[2].set_axisbelow(True)
plt.show()

Given the shape of the training and validation accuracy curves, which are still rising at epoch 25, the network likely had room to improve and would probably have achieved even better performance with additional epochs.
JIT Compilation
On average, it took around 13 seconds for one full pass through the training data on the CPU. We can reduce the runtime drastically by just-in-time compiling the update function. Recall from the first installment of the series that just-in-time (JIT) compilation in JAX traces a Python function and compiles it, using XLA, into highly optimized low-level code that runs much faster. This can be accomplished using the @jit decorator. update now becomes:
from jax import jit
@jit
def update(params, X, y, lr=.01):
    """
    Update weights and biases using gradient descent.
    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.
    X : jax.numpy.ndarray
        Input data.
    y : jax.numpy.ndarray
        True labels.
    lr : float
        Learning rate.
    Returns
    -------
    tuple
        Updated list of (weights, biases) and the batch loss.
    """
    # Compute loss and gradients.
    loss, grads = value_and_grad(cross_entropy_loss)(params, X, y)
    # Unpack parameters and gradients.
    (W1, b1), (W2, b2) = params
    (dW1, db1), (dW2, db2) = grads
    # Update weights and biases.
    W1_new = W1 - lr * dW1
    b1_new = b1 - lr * db1
    W2_new = W2 - lr * dW2
    b2_new = b2 - lr * db2
    return [(W1_new, b1_new), (W2_new, b2_new)], loss
Let’s retrain the network and assess the impact JIT compilation has on per-epoch training time:
"""
Same training loop as before, but now using JIT compilation.
"""
from time import perf_counter
# Layer sizes.
sizes = [784, 128, 10]
# Number of epochs.
n_epochs = 25
# Learning rate.
lr = 0.01
# Store loss, accuracy and runtime.
results = []
# Initialize weights and biases.
params = initialize_weights(sizes, key=random.PRNGKey(516), scale=.02)
for epoch in range(n_epochs):
    start_time = perf_counter()
    losses = []
    for X, y in train_batches:
        # Forward pass, loss, backward pass and parameter update for one batch.
        params, loss = update(params, X, y, lr=lr)
        losses.append(loss.item())
    epoch_time = perf_counter() - start_time
    avg_loss = np.mean(losses)
    train_acc = np.mean([accuracy(forward(params, X), y).item() for X, y in train_batches])
    valid_acc = np.mean([accuracy(forward(params, X), y).item() for X, y in valid_batches])
    results.append((epoch + 1, avg_loss, train_acc, valid_acc, epoch_time))
    print(f"Epoch {epoch + 1}/{n_epochs}: loss: {avg_loss:.4f}, train acc.: {train_acc:.3f}, valid acc.: {valid_acc:.3f}, time: {epoch_time:.2f} sec.")
Epoch 1/25: loss: 1.5768, train acc.: 0.818, valid acc.: 0.824, time: 1.20 sec.
Epoch 2/25: loss: 0.5859, train acc.: 0.873, valid acc.: 0.881, time: 0.85 sec.
Epoch 3/25: loss: 0.4297, train acc.: 0.891, valid acc.: 0.896, time: 0.84 sec.
Epoch 4/25: loss: 0.3753, train acc.: 0.899, valid acc.: 0.904, time: 0.83 sec.
Epoch 5/25: loss: 0.3464, train acc.: 0.905, valid acc.: 0.909, time: 0.82 sec.
Epoch 6/25: loss: 0.3270, train acc.: 0.909, valid acc.: 0.913, time: 0.84 sec.
Epoch 7/25: loss: 0.3123, train acc.: 0.913, valid acc.: 0.918, time: 0.85 sec.
Epoch 8/25: loss: 0.2999, train acc.: 0.916, valid acc.: 0.920, time: 0.83 sec.
Epoch 9/25: loss: 0.2891, train acc.: 0.919, valid acc.: 0.922, time: 0.83 sec.
Epoch 10/25: loss: 0.2792, train acc.: 0.922, valid acc.: 0.925, time: 0.82 sec.
Epoch 11/25: loss: 0.2699, train acc.: 0.924, valid acc.: 0.927, time: 0.85 sec.
Epoch 12/25: loss: 0.2611, train acc.: 0.927, valid acc.: 0.930, time: 0.87 sec.
Epoch 13/25: loss: 0.2529, train acc.: 0.929, valid acc.: 0.931, time: 0.81 sec.
Epoch 14/25: loss: 0.2450, train acc.: 0.932, valid acc.: 0.932, time: 0.81 sec.
Epoch 15/25: loss: 0.2377, train acc.: 0.934, valid acc.: 0.934, time: 0.83 sec.
Epoch 16/25: loss: 0.2307, train acc.: 0.936, valid acc.: 0.936, time: 0.85 sec.
Epoch 17/25: loss: 0.2240, train acc.: 0.938, valid acc.: 0.937, time: 0.82 sec.
Epoch 18/25: loss: 0.2176, train acc.: 0.940, valid acc.: 0.939, time: 0.82 sec.
Epoch 19/25: loss: 0.2116, train acc.: 0.941, valid acc.: 0.940, time: 0.87 sec.
Epoch 20/25: loss: 0.2058, train acc.: 0.943, valid acc.: 0.941, time: 0.84 sec.
Epoch 21/25: loss: 0.2003, train acc.: 0.945, valid acc.: 0.942, time: 0.80 sec.
Epoch 22/25: loss: 0.1951, train acc.: 0.946, valid acc.: 0.943, time: 0.85 sec.
Epoch 23/25: loss: 0.1901, train acc.: 0.948, valid acc.: 0.944, time: 0.84 sec.
Epoch 24/25: loss: 0.1853, train acc.: 0.949, valid acc.: 0.946, time: 0.81 sec.
Epoch 25/25: loss: 0.1807, train acc.: 0.950, valid acc.: 0.947, time: 0.87 sec.
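By simply adding the jit decorator to the update function, the average training time per epoch dropped from around 13 seconds to under a second. Pretty remarkable! The loss and accuracy values are essentially identical to the uncompiled run; the few differences in the last reported digit are at the level of floating-point rounding. The first epoch (1.20 seconds) is somewhat slower than the rest, most likely because it includes the one-time cost of tracing and compiling update.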
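One limitation of update as written is that it unpacks exactly two `(W, b)` pairs, so it breaks if a layer is added to `sizes`. Below is a minimal sketch, assuming the same forward pass and loss as earlier (repeated so the snippet runs on its own), of a hypothetical `update_any_depth` that applies the same gradient descent step leaf by leaf with `jax.tree_util.tree_map`. It is tried on a made-up three-layer network with a random stand-in batch. Timing a few calls also makes the one-off compilation cost visible. Because JAX dispatches work asynchronously, `block_until_ready()` is called before reading the clock. Exact times will depend on your hardware.

```python
from time import perf_counter

import jax
import jax.numpy as jnp
from jax import jit, random, value_and_grad
from jax.nn import relu, log_softmax

def forward(params, X):
    a = X
    for W, b in params[:-1]:
        a = relu(jnp.dot(a, W) + b)
    W, b = params[-1]
    return jnp.dot(a, W) + b

def cross_entropy_loss(params, X, y):
    log_probs = log_softmax(forward(params, X))
    return -log_probs[jnp.arange(len(y)), y].mean()

@jit
def update_any_depth(params, X, y, lr=0.01):
    """Hypothetical variant of update: same SGD step, any number of layers.

    tree_map applies p - lr * g to every (weight or bias) leaf, so the
    (W, b) tuples never need to be unpacked by hand.
    """
    loss, grads = value_and_grad(cross_entropy_loss)(params, X, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Three weight layers this time, to show that depth no longer matters.
sizes = [784, 64, 32, 10]
keys = random.split(random.PRNGKey(516), 2 * (len(sizes) - 1))
params = [
    (0.02 * random.normal(keys[2 * i], (m, n)),
     0.02 * random.normal(keys[2 * i + 1], (n,)))
    for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:]))
]
X = random.uniform(random.PRNGKey(0), (64, 784))
y = random.randint(random.PRNGKey(1), (64,), 0, 10)

# The first call traces and compiles; later calls with the same shapes reuse it.
for call in range(3):
    start = perf_counter()
    params, loss = update_any_depth(params, X, y)
    loss.block_until_ready()  # wait for the asynchronous computation to finish
    print(f"call {call}: loss={float(loss):.4f}, {perf_counter() - start:.4f} sec")
```

The same pattern generalizes to other optimizers: any per-parameter rule can be written as a function of matching leaves and mapped over the parameter and gradient pytrees.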
Conclusion
JAX is a powerful tool for deep learning because it combines NumPy-like syntax with automatic differentiation, just-in-time (JIT) compilation for performance, and seamless GPU/TPU support, all while enabling functional programming patterns that make complex model transformations and optimization easier to express. In this series, we’ve only scratched the surface of what’s possible with JAX. For those eager to explore further, I recommend Deep Learning with JAX by Grigory Sapunov, which dives into more advanced topics and real-world applications of the framework.
I’m currently spending time getting familiar with Flax, a higher-level deep learning library built on JAX that makes it easier to define, train, and manage models without sacrificing flexibility or performance (you can think of JAX as the engine and Flax as the framework that helps you build with that engine). It offers tools to define neural networks and handle parameter initialization and state management, and it integrates nicely with JAX’s functional approach without hiding the JAX underpinnings. More on Flax in a future post.