Up and Running with JAX – Fully-Connected Network Forward Pass
In a previous post, I introduced JAX with particular emphasis on JIT compilation, vectorizing transformations, and automatic differentiation. In this post, we walk through an implementation of the forward pass for a fully-connected neural network with the goal of classifying MNIST handwritten digits, incorporating concepts from the first post.
We begin by loading the MNIST training and validation sets, converting the PIL images to NumPy arrays, and creating image-label batches of size 64:
```python
import warnings

import numpy as np
import pandas as pd
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.transforms import v2
import matplotlib.pyplot as plt

np.set_printoptions(suppress=True, precision=5, linewidth=1000)
pd.options.mode.chained_assignment = None
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option("display.precision", 5)
warnings.filterwarnings("ignore")

# Batch size.
bs = 64

train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=v2.Compose([ToTensor()])
)

valid_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=v2.Compose([ToTensor()])
)

# Convert PIL images to NumPy arrays.
train_data_arr = train_data.data.numpy() / 255.0  # Normalize pixel values to [0, 1]
valid_data_arr = valid_data.data.numpy() / 255.0  # Normalize pixel values to [0, 1]
train_data_arr = train_data_arr.reshape(-1, 28 * 28)  # Flatten images to 1D arrays
valid_data_arr = valid_data_arr.reshape(-1, 28 * 28)  # Flatten images to 1D arrays
train_labels = train_data.targets.numpy()
valid_labels = valid_data.targets.numpy()

# Create training and validation batches of 64.
train_batches = [
    (train_data_arr[(bs * ii):(bs * (ii + 1))], train_labels[(bs * ii):(bs * (ii + 1))])
    for ii in range(len(train_data_arr) // bs)
]

valid_batches = [
    (valid_data_arr[(bs * ii):(bs * (ii + 1))], valid_labels[(bs * ii):(bs * (ii + 1))])
    for ii in range(len(valid_data_arr) // bs)
]

print(f"train_data_arr.shape: {train_data_arr.shape}")
print(f"valid_data_arr.shape: {valid_data_arr.shape}")
print(f"train_labels.shape : {train_labels.shape}")
print(f"valid_labels.shape : {valid_labels.shape}")
print(f"len(train_batches) : {len(train_batches)}")
print(f"len(valid_batches) : {len(valid_batches)}")
```
```
train_data_arr.shape: (60000, 784)
valid_data_arr.shape: (10000, 784)
train_labels.shape : (60000,)
valid_labels.shape : (10000,)
len(train_batches) : 937
len(valid_batches) : 156
```
We can visualize a batch of images and labels using matplotlib:
```python
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

X, y = train_batches[0]
X = np.expand_dims(X.reshape(-1, 28, 28), 1)

fig = plt.figure(figsize=(8., 8.), tight_layout=False)
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.20)

for ax, X_ii, y_ii in zip(grid, X, y):
    X_ii = np.transpose(X_ii, (1, 2, 0))
    ax.imshow(X_ii, cmap="gray")
    ax.axis("off")
    ax.set_title(f"{y_ii}", fontsize=8)

plt.show()
```
Our goal is to create a model that accepts a batch of 64 images, and returns a class prediction for each image in the batch. Our architecture is presented in the image below:
The pre-activations for layer $l$ are computed as

$$z^{(l)} = a^{(l-1)} W^{(l)} + b^{(l)}$$

where:
- $z^{(l)}$ = layer $l$ pre-activations (values prior to applying a non-linearity like ReLU).
- $a^{(l-1)}$ = layer $l-1$ activations, with $a^{(0)}$ representing the original input.
- $W^{(l)}$ = the weight matrix for layer $l$.
- $b^{(l)}$ = the bias vector for layer $l$.

For the network shown above, assuming a batch size of 64:
- $a^{(0)}$: Input matrix with dimension 64×784.
- $W^{(1)}$: Weight matrix with dimension 784×128.
- $b^{(1)}$: Bias vector of length 128.
- $z^{(1)}$: Matrix of pre-activations with dimension 64×128.
- $a^{(1)} = \mathrm{ReLU}(z^{(1)})$: Non-linearity applied to $z^{(1)}$. Activation matrix with dimension 64×128.
- $W^{(2)}$: Weight matrix with dimension 128×10.
- $b^{(2)}$: Bias vector of length 10.
- $z^{(2)}$: Matrix of pre-activations with dimension 64×10.
- $a^{(2)} = \mathrm{softmax}(z^{(2)})$: Non-linearity applied to $z^{(2)}$. Activation matrix with dimension 64×10.
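As a quick sanity check on these dimensions, the matrix algebra above can be sketched with throwaway arrays. This is not the model we build below (the weights here are just random placeholders), but it confirms how the shapes propagate through the two layers:

```python
import jax.numpy as jnp
from jax import random
from jax.nn import relu, softmax

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

a0 = random.normal(k1, (64, 784))   # batch of flattened images
W1 = random.normal(k2, (784, 128))  # input -> hidden weights
b1 = jnp.zeros(128)                 # hidden-layer bias
W2 = random.normal(k3, (128, 10))   # hidden -> output weights
b2 = jnp.zeros(10)                  # output-layer bias

z1 = a0 @ W1 + b1          # (64, 128) pre-activations
a1 = relu(z1)              # (64, 128) activations
z2 = a1 @ W2 + b2          # (64, 10) pre-activations
a2 = softmax(z2, axis=1)   # (64, 10) class probabilities

print(z1.shape, a1.shape, z2.shape, a2.shape)  # (64, 128) (64, 128) (64, 10) (64, 10)
```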
The forward pass feeds an image of size 28×28 into the network, which produces a probability distribution over all classes. The class with the highest probability is our class prediction, which for MNIST will be one of 10 digits 0-9. Specifically:
- Each 28×28 image is flattened to have shape 1×784. The input layer has the same size as the flattened image (784,).
- The hidden layer consists of 128 neurons. The matrix of weights projecting from the input layer to the first hidden layer has dimension 784×128, plus a bias vector of length 128.
- The output layer consists of 10 neurons, which is the same as the number of classes in the dataset. The matrix of weights projecting from the hidden layer to the output layer has dimension 128×10, along with a bias vector of length 10.
- Applying softmax to the output layer results in a probability distribution over classes.
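To make that last point concrete, softmax exponentiates each entry of a row and divides by the row's sum, so every row becomes a set of non-negative values summing to 1. Below is a minimal, hand-rolled version for illustration only; later in the post we simply call jax.nn.softmax:

```python
import jax.numpy as jnp
from jax.nn import softmax


def softmax_by_hand(z):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


z = jnp.array([[2.0, 1.0, 0.1]])
print(softmax_by_hand(z))   # roughly [[0.659 0.242 0.099]] -- sums to 1
print(softmax(z, axis=-1))  # jax.nn.softmax gives the same result
```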
Weight initialization is handled automatically in PyTorch, but when working in JAX, the first step is to initialize the network weights ourselves. We can create a helper function to assist with randomly assigning values to the weight matrices and bias vectors.
In JAX, random number generation is handled a bit differently than in NumPy to ensure functional purity. JAX uses explicit PRNG keys to generate random numbers instead of relying on global state. A “key” is a special array that acts as a seed: every time you use the same key, JAX produces the same random numbers.
Since JAX enforces immutability, you can’t reuse a key for multiple random calls without getting the same result. Instead, you split a key using
jax.random.split
, which deterministically generates new, unique keys from the original one. Each split key is independent, allowing different random numbers to be generated while maintaining reproducibility. In the next cell, we initialize weights using small random normal values:

```python
from jax import random


def initialize_weights(sizes, key, scale=.02):
    """
    Initialize weights and biases for each layer of a simple fully-connected
    network.

    Parameters
    ----------
    sizes : list of int
        List of integers representing the number of neurons in each layer.
    key : jax.random.PRNGKey
        Random key for JAX.
    scale : float
        Multiplier applied to the random normal draws.

    Returns
    -------
    List of initialized (weights, biases) tuples for each layer.
    """
    keys = random.split(key, len(sizes) - 1)
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        w = scale * random.normal(w_key, (m, n))
        b = scale * random.normal(b_key, (n,))
        params.append((w, b))
    return params


# Initialize weights and biases for each layer.
sizes = [784, 128, 10]
params = initialize_weights(sizes, key=random.PRNGKey(516), scale=.02)

# Print shape of each layer's weights and biases.
print(f"W1 shape: {params[0][0].shape}")
print(f"b1 shape: {params[0][1].shape}")
print(f"W2 shape: {params[1][0].shape}")
print(f"b2 shape: {params[1][1].shape}")
```
```
W1 shape: (784, 128)
b1 shape: (128,)
W2 shape: (128, 10)
b2 shape: (10,)
```
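As an aside, the key semantics described above are easy to verify: calling a sampler twice with the same key reproduces the same draw, while split keys yield different, independent draws. A small sketch (the seed here is arbitrary):

```python
from jax import random

key = random.PRNGKey(0)

# Same key -> same numbers (fully reproducible).
print(random.normal(key, (3,)))
print(random.normal(key, (3,)))  # identical to the line above

# Split keys -> different, independent numbers.
k1, k2 = random.split(key)
print(random.normal(k1, (3,)))
print(random.normal(k2, (3,)))
```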
In PyTorch, models inherit from
nn.Module
and must implement a forward
method that defines the network’s computation flow. The forward method orchestrates how input tensors are transformed through pre-specified operations to produce outputs. For our JAX implementation we’ll create a similar function, but the weights must be explicitly passed as parameters rather than stored as internal state. Unlike PyTorch’s object-oriented approach, where weights are hidden properties of the model instance, JAX follows a functional paradigm that requires all state to be passed explicitly between function calls, eliminating hidden state.
In our
forward
function, we incorporate ReLU activation between layers to introduce non-linearity:

```python
import jax.numpy as jnp
from jax.nn import relu


def forward(params, X):
    """
    Forward pass for simple fully-connected network.

    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.
    X : jax.numpy.ndarray
        Input data.

    Returns
    -------
    jax.numpy.ndarray
    """
    a = X
    for W, b in params[:-1]:
        z = jnp.dot(a, W) + b
        a = relu(z)
    W, b = params[-1]
    return jnp.dot(a, W) + b
```
We can pass a single flattened image array into
forward
, and it should return a vector of 10 activations. The output will not be a probability distribution since softmax hasn’t been applied, but we can still test it to ensure that the shape of the output is consistent with our expectations:

```python
# Get first image from first training batch.
X, y = train_batches[0]

# Convert to JAX array.
X0 = jnp.asarray(X[0].flatten())

# Pass X0 into forward.
ypred = forward(params, X0)

print(f"ypred.shape: {ypred.shape}")
print(f"ypred: {ypred}")
```
```
ypred.shape: (10,)
ypred: [-0.01938 0.04435 0.01545 0.01266 0.00116 -0.07045 -0.03737 0.00276 0.01255 -0.00883]
```
Auto-Vectorizing the Forward Pass
As implemented,
forward
is only capable of processing a single flattened image at a time. However, we can use vmap
, introduced in the first post, to process a batch of images at a time without any modification to forward
. vmap
enables batch processing while taking advantage of JAX’s optimized execution. Instead of using loops, it efficiently maps a function over an array along a pre-specified axis:

```python
from jax import vmap

batch_forward = vmap(forward, in_axes=(None, 0))
```
in_axes
controls which input array axes to vectorize over, and its length must equal the number of positional arguments associated with the original function. In our case, the first argument to forward
is params
, which stays the same within the context of the forward pass. The second argument corresponds to our input image, and the ‘0’ indicates that vectorization should be applied along the 0th axis (which is the batch dimension). We can pass a batch of size 64×784 into
batch_forward
, and get back an output of size 64×10:

```python
# Get first batch of flattened training images.
X, y = train_batches[0]

ypreds = batch_forward(params, X)

print(f"ypreds.shape: {ypreds.shape}")
print(f"ypreds:\n{ypreds}")
```
ypreds.shape: (64, 10) ypreds: [[-0.01938 0.04435 0.01545 0.01266 0.00116 -0.07045 -0.03737 0.00276 0.01255 -0.00883] [ 0.00129 0.0109 0.00602 -0.00877 -0.01916 -0.05659 -0.0483 -0.02185 -0.02209 0.01736] [-0.02005 0.00758 0.02579 -0.03333 0.00352 -0.0578 -0.04556 -0.0556 -0.04529 0.06416] [ 0.02695 0.00363 0.02104 -0.00848 0.03085 -0.06641 -0.03336 -0.01503 -0.04532 0.0081 ] [-0.02473 -0.01859 0.04075 -0.00455 0.01993 -0.04665 -0.03698 -0.0063 -0.03235 0.03909] [-0.02891 0.02183 0.01171 -0.00624 0.02155 -0.07441 -0.04383 -0.01707 -0.01268 0.02351] [ 0.01002 -0.01742 0.01145 -0.03244 0.01756 -0.04202 -0.04045 -0.03138 -0.01564 0.03871] [-0.01686 0.04742 0.04909 0.01424 -0.00733 -0.04773 -0.0763 -0.03617 0.01079 0.0016 ] [ 0.00961 -0.00337 0.01458 -0.01242 0.01239 -0.03938 -0.01167 -0.01025 -0.02025 0.03112] [-0.00639 -0.01172 0.04447 -0.0033 0.0276 -0.08083 -0.0409 -0.023 -0.03266 0.03201] [-0.01249 0.02364 0.03073 -0.0103 0.01524 -0.04002 -0.07046 -0.00143 -0.03037 0.03692] [-0.01155 -0.00816 0.01417 0.01124 0.01693 -0.03797 -0.00644 -0.00504 -0.04066 0.02742] [-0.03043 0.03531 -0.00233 -0.02903 0.02681 -0.07377 -0.09124 -0.0348 -0.02728 0.01781] [-0.03805 0.0015 0.02286 -0.02599 0.02493 -0.02707 -0.0489 0.00085 -0.03588 0.01173] [-0.00322 -0.00982 0.00318 0.00086 0.01338 -0.05927 -0.01745 -0.019 -0.02419 0.04078] [-0.00524 -0.03266 0.04316 -0.00133 0.02968 -0.06335 -0.01557 -0.02004 -0.05118 0.04478] [ 0.01196 0.00738 0.0139 -0.01251 0.02977 -0.0482 -0.04011 -0.04932 -0.01502 0.00449] [ 0.01775 -0.00792 0.02539 0.01883 0.01768 -0.06387 -0.0317 -0.0299 -0.01473 0.02359] [-0.01174 0.01263 0.01114 -0.02254 0.01541 -0.0507 -0.04069 -0.00546 -0.02738 0.03351] [ 0.00519 -0.00914 0.04497 0.01114 0.02668 -0.07109 -0.02906 0.00836 -0.03017 0.02743] [-0.01821 -0.00343 0.01697 -0.00879 0.03184 -0.06576 -0.04461 -0.02077 -0.0576 0.02691] [ 0.02505 -0.00406 0.01754 -0.01299 -0.03357 -0.0621 -0.05533 -0.01607 -0.05241 0.00787] [-0.00548 -0.0151 0.02098 0.00532 -0.0007 -0.06299 -0.02551 0.02606 -0.02592 0.02215] [ 0.02538 -0.00074 0.01906 -0.00907 0.03593 -0.0677 -0.02988 -0.01161 -0.04277 0.01133] [ 0.00403 0.00756 0.02052 0.02553 0.02037 -0.03283 -0.03985 0.01178 -0.02486 -0.00577] [-0.00577 0.01179 -0.00397 -0.01612 0.006 -0.05153 -0.08116 -0.0568 -0.02216 0.02107] [-0.01831 -0.00545 0.03149 -0.00173 0.02892 -0.01969 -0.03263 -0.01362 -0.03298 0.03155] [-0.02663 0.047 0.01241 -0.00937 -0.00653 -0.07457 -0.08388 -0.00652 -0.00656 0.03299] [-0.01582 -0.00341 -0.00549 -0.07723 0.03082 -0.08325 -0.02868 -0.03757 -0.03521 0.05039] [ 0.01516 -0.02163 0.03846 -0.00524 0.03453 -0.07424 -0.00601 -0.01762 -0.04946 0.02805] [ 0.02345 0.03865 0.0027 -0.03011 0.02129 -0.06911 -0.0642 -0.06447 -0.02161 0.0513 ] [-0.00958 0.00089 0.00206 0.00014 -0.03024 -0.05849 -0.03403 -0.04682 -0.02192 0.02192] [ 0.01296 0.01323 -0.00675 -0.02761 0.0001 -0.02599 -0.03824 -0.03056 -0.03029 0.00598] [ 0.015 0.00759 0.03926 0.00927 0.0302 -0.08289 -0.04693 0.00788 -0.02175 0.031 ] [-0.03891 0.03415 0.00166 -0.01509 0.0031 -0.05396 -0.07296 -0.00382 -0.02682 0.05255] [ 0.01267 0.00989 0.02874 0.01091 0.04832 -0.05603 -0.01568 -0.00115 -0.02438 0.00899] [-0.02444 0.03956 -0.00092 -0.01447 -0.00312 -0.06536 -0.05304 -0.00953 -0.0111 0.02981] [-0.0344 0.04178 0.01164 -0.03304 0.01373 -0.06036 -0.07474 0.01716 -0.03821 0.02719] [ 0.02056 -0.04235 0.02223 -0.03208 0.02435 -0.03469 0.0082 -0.01539 -0.02368 0.00596] [-0.01846 0.01328 -0.01612 -0.0376 -0.0026 -0.01918 -0.03758 -0.0154 -0.01747 0.02746] [-0.01283 -0.01919 0.0135 0.00131 
0.01254 -0.04955 -0.02221 -0.01261 -0.02833 0.03626] [ 0.01384 0.017 0.02705 -0.01914 0.03964 -0.05375 -0.05693 -0.01484 -0.01374 0.03448] [-0.00959 -0.03199 0.0208 0.01219 0.01838 -0.04273 0.0039 -0.00971 -0.02165 0.02525] [-0.00754 -0.02387 0.03144 0.01683 0.01158 -0.04538 -0.02697 0.00096 -0.02717 0.0279 ] [-0.00149 -0.01607 0.03226 0.00932 0.03503 -0.04532 -0.02371 -0.04158 -0.02733 0.01295] [-0.00841 -0.02569 0.02426 -0.01079 0.04106 -0.04494 -0.02315 -0.01234 -0.03156 0.03288] [ 0.03858 -0.00225 0.03607 0.01419 0.05866 -0.04707 -0.0479 -0.0585 -0.03581 0.0068 ] [ 0.00396 0.00229 0.03064 0.00296 0.04075 -0.05858 -0.02379 0.01113 -0.01696 0.02085] [-0.0306 0.01159 0.01966 -0.0238 0.0085 -0.04504 -0.07059 -0.02362 -0.02067 0.04078] [-0.03142 0.03721 0.02696 0.00428 0.00507 -0.06188 -0.07031 -0.00547 -0.01252 0.0217 ] [-0.01006 -0.00553 0.009 -0.02372 0.00812 -0.04891 -0.06205 -0.01728 -0.02563 0.02564] [ 0.00656 0.03307 0.01764 0.00273 -0.00087 -0.07939 -0.07847 -0.05009 -0.03572 0.00226] [-0.01152 0.00209 0.01268 -0.02766 0.02629 -0.08758 -0.02191 -0.01157 -0.03755 0.07643] [ 0.00849 -0.00357 0.04292 0.00966 0.01071 -0.06685 -0.01274 -0.03098 -0.04872 0.04148] [-0.02366 0.01482 0.0214 -0.0017 0.03714 -0.06655 -0.00035 -0.03575 -0.00123 0.06236] [-0.00574 -0.01814 0.01388 -0.01289 0.03698 -0.05857 -0.04548 -0.01367 -0.03049 0.05411] [-0.00067 0.03152 0.01473 -0.00112 0.04264 -0.07165 -0.08533 -0.07986 -0.03663 0.02476] [ 0.00238 -0.01577 0.03842 -0.01902 0.03321 -0.05336 -0.03538 0.00514 -0.03226 0.02748] [-0.05187 -0.02097 0.04671 -0.02635 0.03248 -0.0399 -0.06063 -0.01354 -0.03793 0.04167] [ 0.02789 -0.00258 0.01699 -0.01496 0.02993 -0.06447 -0.03116 -0.00742 -0.03525 0.01357] [ 0.03633 0.0008 0.01882 -0.00777 0.00483 -0.0582 -0.05376 -0.08621 -0.0305 0.0121 ] [ 0.00676 0.01654 0.03401 0.00623 0.035 -0.06879 -0.0183 -0.00864 -0.03068 0.02204] [-0.04109 0.01182 0.00634 -0.02956 -0.00965 -0.064 -0.05708 -0.01416 -0.01804 0.03834] [-0.00258 0.03007 0.04005 -0.00407 -0.04672 -0.05536 -0.07473 -0.01553 -0.01838 -0.00027]]
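To convince ourselves that vmap is doing what we expect, we can compare the vectorized output against a plain Python loop over forward. The two should agree up to floating-point tolerance; this is just a one-off check, not something you’d keep in a training loop:

```python
import jax.numpy as jnp

# Loop over the batch, calling the single-example forward once per image.
looped = jnp.stack([forward(params, jnp.asarray(x)) for x in X])

print(looped.shape)                              # (64, 10)
print(jnp.allclose(looped, ypreds, atol=1e-6))   # True
```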
To ensure each row sums to 1, we can apply softmax to
ypreds
:

```python
from jax.nn import softmax

yprobs = softmax(ypreds, axis=1)

print(f"yprobs.shape: {yprobs.shape}")
print(f"yprobs:\n{yprobs}")
```
yprobs.shape: (64, 10) yprobs: [[0.0985 0.10498 0.10199 0.10171 0.10054 0.09359 0.09674 0.1007 0.10169 0.09954] [0.10153 0.10251 0.10201 0.10051 0.09947 0.09582 0.09661 0.0992 0.09918 0.10317] [0.09949 0.10228 0.10416 0.09818 0.10186 0.0958 0.09698 0.09601 0.09701 0.10823] [0.10349 0.1011 0.10288 0.09988 0.10389 0.09426 0.09743 0.09923 0.09627 0.10155] [0.0982 0.09881 0.10485 0.1002 0.10269 0.09607 0.09701 0.10003 0.09746 0.10468] [0.09813 0.10323 0.10219 0.10038 0.1032 0.09376 0.09667 0.09929 0.09973 0.10341] [0.102 0.09924 0.10215 0.09776 0.10277 0.09683 0.09698 0.09787 0.09942 0.10497] [0.09886 0.10543 0.1056 0.10198 0.09981 0.09586 0.09316 0.09697 0.10163 0.1007 ] [0.10125 0.09994 0.10175 0.09904 0.10153 0.09641 0.09911 0.09926 0.09827 0.10345] [0.10024 0.09971 0.10547 0.10055 0.10371 0.09305 0.09684 0.09859 0.09765 0.10417] [0.09929 0.10294 0.10367 0.0995 0.10208 0.09659 0.09369 0.10039 0.09753 0.10432] [0.09923 0.09956 0.10181 0.10151 0.10209 0.09664 0.09973 0.09987 0.09638 0.10317] [0.09897 0.1057 0.1018 0.09911 0.10481 0.09478 0.09314 0.09854 0.09929 0.10387] [0.09734 0.10127 0.10345 0.09852 0.10367 0.09841 0.09629 0.1012 0.09755 0.10231] [0.10039 0.09973 0.10104 0.10081 0.10208 0.09492 0.09898 0.09882 0.09831 0.10491] [0.10013 0.09742 0.1051 0.10052 0.10369 0.09448 0.0991 0.09866 0.09563 0.10527] [0.10216 0.10169 0.10236 0.09969 0.104 0.0962 0.09698 0.09609 0.09944 0.1014 ] [0.10221 0.09962 0.10299 0.10232 0.1022 0.0942 0.09728 0.09745 0.09894 0.10281] [0.09965 0.10211 0.10196 0.09858 0.10239 0.09584 0.09681 0.10028 0.09811 0.10426] [0.10063 0.09919 0.10471 0.10123 0.10281 0.09323 0.09724 0.10094 0.09713 0.10289] [0.09956 0.10104 0.10313 0.1005 0.10467 0.09494 0.09697 0.09931 0.09572 0.10416] [0.10442 0.10142 0.10364 0.10052 0.09847 0.0957 0.09635 0.10021 0.09663 0.10264] [0.10003 0.09907 0.10271 0.10112 0.10051 0.09444 0.09805 0.10324 0.09801 0.10283] [0.10324 0.10058 0.10259 0.09975 0.10434 0.09407 0.09769 0.09949 0.09644 0.1018 ] [0.10052 0.10087 0.10219 0.1027 0.10217 0.09688 0.0962 0.1013 0.09765 0.09953] [0.10137 0.10317 0.10155 0.10033 0.10257 0.09684 0.09401 0.09633 0.09972 0.10413] [0.09848 0.09975 0.1035 0.10012 0.10324 0.09834 0.09708 0.09894 0.09704 0.10351] [0.09849 0.10601 0.10241 0.1002 0.10049 0.09388 0.09301 0.10049 0.10049 0.10454] [0.10039 0.10165 0.10144 0.09441 0.10519 0.09385 0.09911 0.09823 0.09847 0.10727] [0.10206 0.09837 0.10446 0.1 0.10405 0.09333 0.09992 0.09876 0.09567 0.10338] [0.10343 0.10502 0.10131 0.09804 0.10321 0.09429 0.09475 0.09473 0.09887 0.10635] [0.10078 0.10184 0.10196 0.10176 0.09872 0.09597 0.09834 0.09709 0.09954 0.104 ] [0.10258 0.10261 0.10058 0.0985 0.10127 0.09866 0.09746 0.09821 0.09824 0.10187] [0.10156 0.10081 0.10405 0.10098 0.10312 0.09209 0.09546 0.10084 0.0979 0.1032 ] [0.09728 0.10466 0.10131 0.09963 0.10146 0.09583 0.09403 0.10076 0.09846 0.1066 ] [0.10101 0.10073 0.10265 0.10083 0.10468 0.0943 0.09819 0.09963 0.09734 0.10064] [0.09864 0.10516 0.10099 0.09963 0.10077 0.09469 0.09586 0.10013 0.09997 0.10414] [0.0978 0.10555 0.10241 0.09794 0.10263 0.0953 0.09394 0.10298 0.09743 0.10402] [0.10273 0.09647 0.1029 0.09746 0.10312 0.09721 0.10147 0.0991 0.09829 0.10124] [0.09937 0.10258 0.09961 0.09749 0.10096 0.0993 0.09749 0.09968 0.09947 0.10404] [0.0995 0.09887 0.10216 0.10092 0.10206 0.09591 0.09857 0.09952 0.09797 0.10451] [0.10161 0.10193 0.10296 0.09831 0.10426 0.09497 0.09466 0.09873 0.09884 0.10373] [0.09937 0.09717 0.10244 0.10156 0.10219 0.09613 0.10072 0.09936 0.09818 0.10289] [0.09964 0.09802 0.1036 0.1021 0.10156 
0.09594 0.09772 0.10049 0.0977 0.10323] [0.10047 0.09902 0.10392 0.10157 0.10421 0.09617 0.09827 0.09653 0.09791 0.10194] [0.09971 0.098 0.10302 0.09947 0.10476 0.09613 0.09825 0.09932 0.09743 0.10391] [0.10424 0.10007 0.10398 0.10173 0.10635 0.09568 0.0956 0.0946 0.09677 0.10098] [0.10023 0.10006 0.10294 0.10013 0.10398 0.09415 0.09748 0.10095 0.09815 0.10193] [0.09824 0.10248 0.10331 0.09891 0.10216 0.09684 0.09439 0.09893 0.09922 0.10551] [0.09769 0.10463 0.10356 0.10124 0.10132 0.09476 0.09396 0.10026 0.09955 0.10302] [0.10047 0.10092 0.1024 0.0991 0.10231 0.09664 0.09538 0.09974 0.09892 0.10412] [0.10244 0.10519 0.10358 0.10205 0.10168 0.094 0.09409 0.09679 0.0982 0.102 ] [0.09957 0.10093 0.10201 0.09797 0.10341 0.09228 0.09854 0.09956 0.09701 0.10872] [0.1013 0.10008 0.10484 0.10141 0.10152 0.09394 0.09917 0.09738 0.09566 0.10469] [0.09754 0.10137 0.10203 0.0997 0.10365 0.09344 0.09984 0.09637 0.09975 0.1063 ] [0.10017 0.09894 0.10216 0.09946 0.10454 0.09502 0.09627 0.09938 0.09772 0.10635] [0.10145 0.10477 0.10303 0.10141 0.10594 0.0945 0.09322 0.09373 0.09787 0.10407] [0.10069 0.09888 0.10438 0.09856 0.10384 0.09523 0.09696 0.10097 0.09726 0.10325] [0.09612 0.09914 0.10608 0.09861 0.10458 0.09728 0.09528 0.09988 0.09747 0.10555] [0.10348 0.10038 0.10236 0.09914 0.10369 0.09435 0.09755 0.09989 0.09715 0.10201] [0.10534 0.10166 0.10351 0.10079 0.10207 0.09584 0.09626 0.09319 0.09853 0.10282] [0.10069 0.10168 0.10347 0.10064 0.10358 0.09336 0.0982 0.09915 0.09699 0.10224] [0.09764 0.10295 0.10239 0.09878 0.10076 0.09543 0.0961 0.10031 0.09992 0.10572] [0.10117 0.10452 0.10557 0.10102 0.0968 0.09596 0.09412 0.09986 0.09958 0.1014 ]]
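With probabilities in hand, each image’s predicted class is just the argmax of its row, as described earlier. As a quick sanity check, we can confirm that every row sums to 1 and compare the untrained predictions against the true labels; accuracy should sit near chance level (roughly 10%) since the weights are still random:

```python
import jax.numpy as jnp

# Every row of yprobs should sum to (approximately) 1.
print(jnp.allclose(yprobs.sum(axis=1), 1.0))  # True

# Predicted class = index of the largest probability in each row.
ypred_classes = jnp.argmax(yprobs, axis=1)
print(ypred_classes[:10])  # predictions for the first 10 images
print(y[:10])              # true labels for comparison

# Proportion correct for this untrained batch -- expect something near chance.
print(jnp.mean(ypred_classes == jnp.asarray(y)))
```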
At this point, the outputs are essentially meaningless: they are close to uniformly distributed over the classes because the weights are still random. We haven’t yet computed the gradient of the loss function with respect to each weight, which is what allows the network to adjust its weights and biases to minimize prediction errors. In the next post, we’ll implement backpropagation entirely in JAX and walk through how to construct the training and validation loops.