
Up and Running with JAX – Fully-Connected Network Forward Pass

This article was first published on The Pleasure of Finding Things Out: A blog by James Triveri, and kindly contributed to python-bloggers.

In a previous post, I introduced JAX with particular emphasis on JIT compilation, vectorizing transformations and automatic differentiation. In this post, we walk through an implementation of the forward pass for a fully-connected neural network with the goal of classifying MNIST handwritten digits, incorporating concepts from the first post.

We begin by loading the MNIST training and validation sets, converting the PIL images to NumPy arrays, and creating image-label batches of size 64:

import warnings

import numpy as np
import pandas as pd

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.transforms import v2
import matplotlib.pyplot as plt

np.set_printoptions(suppress=True, precision=5, linewidth=1000)
pd.options.mode.chained_assignment = None
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option("display.precision", 5)
warnings.filterwarnings("ignore")

# Batch size.
bs = 64

train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=v2.Compose([ToTensor()])
)

valid_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=v2.Compose([ToTensor()])
)

# Convert PIL images to NumPy arrays.
train_data_arr = train_data.data.numpy() / 255.0      # Normalize pixel values to [0, 1]
valid_data_arr = valid_data.data.numpy() / 255.0      # Normalize pixel values to [0, 1] 
train_data_arr = train_data_arr.reshape(-1, 28 * 28)  # Flatten images to 1D arrays
valid_data_arr = valid_data_arr.reshape(-1, 28 * 28)  # Flatten images to 1D arrays
train_labels = train_data.targets.numpy()
valid_labels = valid_data.targets.numpy()

# Create training and validation batches of 64.
train_batches = [
    (train_data_arr[(bs * ii):(bs * (ii + 1))], train_labels[(bs * ii):(bs * (ii + 1))]) 
    for ii in range(len(train_data_arr) // bs)
]
valid_batches = [
    (valid_data_arr[(bs * ii):(bs * (ii + 1))], valid_labels[(bs * ii):(bs * (ii + 1))]) 
    for ii in range(len(valid_data_arr) // bs)
]

print(f"train_data_arr.shape: {train_data_arr.shape}")
print(f"valid_data_arr.shape: {valid_data_arr.shape}")
print(f"train_labels.shape  : {train_labels.shape}")
print(f"valid_labels.shape  : {valid_labels.shape}")
print(f"len(train_batches)  : {len(train_batches)}")
print(f"len(valid_batches)  : {len(valid_batches)}")
train_data_arr.shape: (60000, 784)
valid_data_arr.shape: (10000, 784)
train_labels.shape  : (60000,)
valid_labels.shape  : (10000,)
len(train_batches)  : 937
len(valid_batches)  : 156
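
Note that the integer division above drops any final partial batch: 60,000 // 64 = 937 training batches and 10,000 // 64 = 156 validation batches, leaving the last 32 training images and 16 validation images unused.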

We can visualize a batch of images and labels using matplotlib:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

X, y = train_batches[0]
X = np.expand_dims(X.reshape(-1, 28, 28), 1)

fig = plt.figure(figsize=(8., 8.), tight_layout=False)

grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.20)

for ax, X_ii, y_ii in zip(grid, X, y):
    X_ii = np.transpose(X_ii, (1, 2, 0))
    ax.imshow(X_ii, cmap="gray")
    ax.axis("off")
    ax.set_title(f"{y_ii}", fontsize=8)

plt.show()

Our goal is to create a model that accepts a batch of 64 images, and returns a class prediction for each image in the batch. Our architecture is presented in the image below:

The pre-activations for layer $l$ are computed as $z^{(l)} = a^{(l-1)} W^{(l)} + b^{(l)}$, where:

  • $z^{(l)}$ = layer $l$ pre-activations (values prior to applying a non-linearity like ReLU).
  • $a^{(l)}$ = layer $l$ activations, with $a^{(0)}$ representing the original input.
  • $W^{(l)}$ = the weight matrix for layer $l$.
  • $b^{(l)}$ = the bias vector for layer $l$.

For the network shown above, assuming a batch size of 64:

  • $a^{(0)}$: Input matrix with dimension 64×784.
  • $W^{(1)}$: Weight matrix with dimension 784×128.
  • $b^{(1)}$: Bias vector of length 128.
  • $z^{(1)} = a^{(0)} W^{(1)} + b^{(1)}$: Matrix of pre-activations with dimension 64×128.
  • $a^{(1)} = \mathrm{ReLU}(z^{(1)})$: Non-linearity applied to $z^{(1)}$. Activation matrix with dimension 64×128.
  • $W^{(2)}$: Weight matrix with dimension 128×10.
  • $b^{(2)}$: Bias vector of length 10.
  • $z^{(2)} = a^{(1)} W^{(2)} + b^{(2)}$: Matrix of pre-activations with dimension 64×10.
  • $a^{(2)} = \mathrm{softmax}(z^{(2)})$: Non-linearity applied to $z^{(2)}$. Activation matrix with dimension 64×10.

The forward pass feeds an image of size 28×28 into the network, which produces a probability distribution over all classes. The class with the highest probability is our class prediction, which for MNIST will be one of the 10 digits 0-9. Specifically (a quick shape-check sketch follows this list):

  • Each 28×28 image is flattened to have shape 1×784. The input layer has the same size as the flattened image (784,).
  • The hidden layer consists of 128 neurons. The matrix of weights projecting from the input layer to the first hidden layer has dimension 784×128, plus a bias vector of length 128.
  • The output layer consists of 10 neurons, which is the same as the number of classes in the dataset. The matrix of weights projecting from the hidden layer to the output layer has dimension 128×10, along with a bias vector of length 10.
  • Applying softmax to the output layer results in a probability distribution over classes.
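
To make these dimensions concrete, here is a minimal shape-check sketch. It is not part of the implementation that follows; the keys and randomly filled arrays are throwaway placeholders used only to verify shapes:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

A0 = random.normal(k1, (64, 784))    # flattened input batch
W1 = random.normal(k2, (784, 128))   # input -> hidden weights
b1 = jnp.zeros(128)                  # hidden bias
W2 = random.normal(k3, (128, 10))    # hidden -> output weights
b2 = jnp.zeros(10)                   # output bias

Z1 = A0 @ W1 + b1                    # (64, 128) pre-activations
A1 = jnp.maximum(Z1, 0)              # (64, 128) ReLU activations
Z2 = A1 @ W2 + b2                    # (64, 10) pre-activations

print(Z1.shape, A1.shape, Z2.shape)  # (64, 128) (64, 128) (64, 10)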

Weight initialization is handled automatically in PyTorch, but when working in JAX, the first step is to initialize the network weights ourselves. We can create a helper function to assist with randomly assigning values to the weight matrices and bias vectors.

In JAX, random number generation is handled a bit differently than in Numpy to ensure functional purity. JAX uses explicit PRNG keys to generate random numbers instead of relying on global state. A “key” is a special array that acts as a seed, and every time you use it, JAX produces the same random numbers for the same key.

Since JAX enforces immutability, you can’t reuse a key for multiple random calls without getting the same result. Instead, you split a key using jax.random.split, which deterministically generates new, unique keys from the original one. Each split key is independent, allowing for the generation of different random numbers while maintaining reproducibility.
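
A quick demonstration of this behavior, as a standalone sketch (the seed and shapes here are arbitrary):

from jax import random

key = random.PRNGKey(0)

# Reusing the same key reproduces the same draw.
print(random.normal(key, (3,)))
print(random.normal(key, (3,)))     # identical to the line above

# Splitting yields fresh, independent keys for new draws.
key, subkey = random.split(key)
print(random.normal(subkey, (3,)))  # different values

In the next cell, we initialize the network weights using small random normal values: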

from jax import random


def initialize_weights(sizes, key, scale=.02):
    """
    "Initialize weights and biases for each layer for simple fully-connected 
    network.

    Parameters
    ----------
    sizes : list of int
        List of integers representing the number of neurons in each layer.

    key : jax.random.PRNGKey
        Random key for JAX.

    Returns
    -------
    List of iniitialized weights and biases for each layer.
    """
    keys = random.split(key, len(sizes) - 1)
    params = []
    for m, n, k in zip(sizes[:-1], sizes[1:], keys):
        w_key, b_key = random.split(k)
        w = scale * random.normal(w_key, (m, n))
        b = scale * random.normal(b_key, (n,))
        params.append((w, b))
    return params


# Initialize weights and biases for each layer.
sizes = [784, 128, 10]

params = initialize_weights(sizes, key=random.PRNGKey(516), scale=.02)   

# Print shape of each layer's weights and biases.
print(f"W1 shape: {params[0][0].shape}")
print(f"b1 shape: {params[0][1].shape}")
print(f"W2 shape: {params[1][0].shape}")  
print(f"b2 shape: {params[1][1].shape}")
W1 shape: (784, 128)
b1 shape: (128,)
W2 shape: (128, 10)
b2 shape: (10,)

In PyTorch, models inherit from nn.Module and must implement a forward method that defines the network’s computation flow. The forward method orchestrates how input tensors transform through pre-specified operations to produce outputs.

For our JAX implementation we’ll create a similar function, but the weights must be explicitly passed as parameters rather than stored as internal state. Unlike PyTorch’s object-oriented approach where weights are hidden properties of the model instance, JAX follows a functional paradigm that requires all state to be passed explicitly between function calls, eliminating hidden state.
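
Because all of the state lives in params, the parameters are nothing more than a nested Python structure (a pytree) of arrays. As a small illustrative sketch, standard JAX tree utilities can operate on them directly, for example to inspect every leaf’s shape:

import jax

# params is a list of (W, b) tuples; tree_map applies a function to each leaf array.
print(jax.tree_util.tree_map(lambda p: p.shape, params))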

In our forward function, we incorporate ReLU activation between layers to introduce non-linearity:

import jax.numpy as jnp
from jax.nn import relu


def forward(params, X):
    """
    Forward pass for simple fully-connected network.

    Parameters
    ----------
    params : list of tuples
        List of tuples containing weights and biases for each layer.

    X : jax.numpy.ndarray
        Input data.

    Returns
    -------
    jax.numpy.ndarray
        Output of the final layer (pre-softmax logits).
    """
    a = X
    for W, b in params[:-1]:
        z = jnp.dot(a, W) + b
        a = relu(z)
    W, b = params[-1]
    return jnp.dot(a, W) + b

We can pass a single flattened image array into forward, and it should return a vector of 10 activations. The output will not be a probability distribution since softmax hasn’t been applied, but we can still test it to ensure that the shape of the output is consistent with our expectations:

# Get first image from first training batch.
X, y = train_batches[0]

# Convert to JAX array.
X0 = jnp.asarray(X[0].flatten())

# Pass X0 into forward.
ypred = forward(params, X0)

print(f"ypred.shape: {ypred.shape}")
print(f"ypred: {ypred}")
ypred.shape: (10,)
ypred: [-0.01938  0.04435  0.01545  0.01266  0.00116 -0.07045 -0.03737  0.00276  0.01255 -0.00883]

Auto-Vectorizing the Forward Pass

As implemented, forward is only capable of processing a single flattened image at a time. However, we can use vmap, introduced in the first post, to process a batch of images at a time without any modification to forward. vmap enables batch processing while taking advantage of JAX’s optimized execution: instead of using loops, it efficiently maps a function over an array along a pre-specified axis:

from jax import vmap

batch_forward = vmap(forward, in_axes=(None, 0))

in_axes controls which input array axes to vectorize over, and its length must equal the number of positional arguments of the original function. In our case, the first argument to forward is params, which stays the same within the context of the forward pass, so it receives None. The second argument corresponds to our input images, and the 0 indicates that vectorization should be applied along the 0th axis (the batch dimension).
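
As a quick sanity check, one way to confirm this is to compare the vectorized output against a plain Python loop over forward, a small sketch using the arrays and functions defined above:

import numpy as np

X, y = train_batches[0]

batched = batch_forward(params, X)                    # vectorized over the batch
looped = jnp.stack([forward(params, x) for x in X])   # one image at a time

print(np.allclose(batched, looped, atol=1e-6))  # True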

We can pass a batch of shape 64×784 into batch_forward, which returns an output of shape 64×10:

# Get first batch of flattened training images.
X, y = train_batches[0]

ypreds = batch_forward(params, X)

print(f"ypreds.shape: {ypreds.shape}")
print(f"ypreds:\n{ypreds}")
ypreds.shape: (64, 10)
ypreds:
[[-0.01938  0.04435  0.01545  0.01266  0.00116 -0.07045 -0.03737  0.00276  0.01255 -0.00883]
 [ 0.00129  0.0109   0.00602 -0.00877 -0.01916 -0.05659 -0.0483  -0.02185 -0.02209  0.01736]
 [-0.02005  0.00758  0.02579 -0.03333  0.00352 -0.0578  -0.04556 -0.0556  -0.04529  0.06416]
 [ 0.02695  0.00363  0.02104 -0.00848  0.03085 -0.06641 -0.03336 -0.01503 -0.04532  0.0081 ]
 [-0.02473 -0.01859  0.04075 -0.00455  0.01993 -0.04665 -0.03698 -0.0063  -0.03235  0.03909]
 [-0.02891  0.02183  0.01171 -0.00624  0.02155 -0.07441 -0.04383 -0.01707 -0.01268  0.02351]
 [ 0.01002 -0.01742  0.01145 -0.03244  0.01756 -0.04202 -0.04045 -0.03138 -0.01564  0.03871]
 [-0.01686  0.04742  0.04909  0.01424 -0.00733 -0.04773 -0.0763  -0.03617  0.01079  0.0016 ]
 [ 0.00961 -0.00337  0.01458 -0.01242  0.01239 -0.03938 -0.01167 -0.01025 -0.02025  0.03112]
 [-0.00639 -0.01172  0.04447 -0.0033   0.0276  -0.08083 -0.0409  -0.023   -0.03266  0.03201]
 [-0.01249  0.02364  0.03073 -0.0103   0.01524 -0.04002 -0.07046 -0.00143 -0.03037  0.03692]
 [-0.01155 -0.00816  0.01417  0.01124  0.01693 -0.03797 -0.00644 -0.00504 -0.04066  0.02742]
 [-0.03043  0.03531 -0.00233 -0.02903  0.02681 -0.07377 -0.09124 -0.0348  -0.02728  0.01781]
 [-0.03805  0.0015   0.02286 -0.02599  0.02493 -0.02707 -0.0489   0.00085 -0.03588  0.01173]
 [-0.00322 -0.00982  0.00318  0.00086  0.01338 -0.05927 -0.01745 -0.019   -0.02419  0.04078]
 [-0.00524 -0.03266  0.04316 -0.00133  0.02968 -0.06335 -0.01557 -0.02004 -0.05118  0.04478]
 [ 0.01196  0.00738  0.0139  -0.01251  0.02977 -0.0482  -0.04011 -0.04932 -0.01502  0.00449]
 [ 0.01775 -0.00792  0.02539  0.01883  0.01768 -0.06387 -0.0317  -0.0299  -0.01473  0.02359]
 [-0.01174  0.01263  0.01114 -0.02254  0.01541 -0.0507  -0.04069 -0.00546 -0.02738  0.03351]
 [ 0.00519 -0.00914  0.04497  0.01114  0.02668 -0.07109 -0.02906  0.00836 -0.03017  0.02743]
 [-0.01821 -0.00343  0.01697 -0.00879  0.03184 -0.06576 -0.04461 -0.02077 -0.0576   0.02691]
 [ 0.02505 -0.00406  0.01754 -0.01299 -0.03357 -0.0621  -0.05533 -0.01607 -0.05241  0.00787]
 [-0.00548 -0.0151   0.02098  0.00532 -0.0007  -0.06299 -0.02551  0.02606 -0.02592  0.02215]
 [ 0.02538 -0.00074  0.01906 -0.00907  0.03593 -0.0677  -0.02988 -0.01161 -0.04277  0.01133]
 [ 0.00403  0.00756  0.02052  0.02553  0.02037 -0.03283 -0.03985  0.01178 -0.02486 -0.00577]
 [-0.00577  0.01179 -0.00397 -0.01612  0.006   -0.05153 -0.08116 -0.0568  -0.02216  0.02107]
 [-0.01831 -0.00545  0.03149 -0.00173  0.02892 -0.01969 -0.03263 -0.01362 -0.03298  0.03155]
 [-0.02663  0.047    0.01241 -0.00937 -0.00653 -0.07457 -0.08388 -0.00652 -0.00656  0.03299]
 [-0.01582 -0.00341 -0.00549 -0.07723  0.03082 -0.08325 -0.02868 -0.03757 -0.03521  0.05039]
 [ 0.01516 -0.02163  0.03846 -0.00524  0.03453 -0.07424 -0.00601 -0.01762 -0.04946  0.02805]
 [ 0.02345  0.03865  0.0027  -0.03011  0.02129 -0.06911 -0.0642  -0.06447 -0.02161  0.0513 ]
 [-0.00958  0.00089  0.00206  0.00014 -0.03024 -0.05849 -0.03403 -0.04682 -0.02192  0.02192]
 [ 0.01296  0.01323 -0.00675 -0.02761  0.0001  -0.02599 -0.03824 -0.03056 -0.03029  0.00598]
 [ 0.015    0.00759  0.03926  0.00927  0.0302  -0.08289 -0.04693  0.00788 -0.02175  0.031  ]
 [-0.03891  0.03415  0.00166 -0.01509  0.0031  -0.05396 -0.07296 -0.00382 -0.02682  0.05255]
 [ 0.01267  0.00989  0.02874  0.01091  0.04832 -0.05603 -0.01568 -0.00115 -0.02438  0.00899]
 [-0.02444  0.03956 -0.00092 -0.01447 -0.00312 -0.06536 -0.05304 -0.00953 -0.0111   0.02981]
 [-0.0344   0.04178  0.01164 -0.03304  0.01373 -0.06036 -0.07474  0.01716 -0.03821  0.02719]
 [ 0.02056 -0.04235  0.02223 -0.03208  0.02435 -0.03469  0.0082  -0.01539 -0.02368  0.00596]
 [-0.01846  0.01328 -0.01612 -0.0376  -0.0026  -0.01918 -0.03758 -0.0154  -0.01747  0.02746]
 [-0.01283 -0.01919  0.0135   0.00131  0.01254 -0.04955 -0.02221 -0.01261 -0.02833  0.03626]
 [ 0.01384  0.017    0.02705 -0.01914  0.03964 -0.05375 -0.05693 -0.01484 -0.01374  0.03448]
 [-0.00959 -0.03199  0.0208   0.01219  0.01838 -0.04273  0.0039  -0.00971 -0.02165  0.02525]
 [-0.00754 -0.02387  0.03144  0.01683  0.01158 -0.04538 -0.02697  0.00096 -0.02717  0.0279 ]
 [-0.00149 -0.01607  0.03226  0.00932  0.03503 -0.04532 -0.02371 -0.04158 -0.02733  0.01295]
 [-0.00841 -0.02569  0.02426 -0.01079  0.04106 -0.04494 -0.02315 -0.01234 -0.03156  0.03288]
 [ 0.03858 -0.00225  0.03607  0.01419  0.05866 -0.04707 -0.0479  -0.0585  -0.03581  0.0068 ]
 [ 0.00396  0.00229  0.03064  0.00296  0.04075 -0.05858 -0.02379  0.01113 -0.01696  0.02085]
 [-0.0306   0.01159  0.01966 -0.0238   0.0085  -0.04504 -0.07059 -0.02362 -0.02067  0.04078]
 [-0.03142  0.03721  0.02696  0.00428  0.00507 -0.06188 -0.07031 -0.00547 -0.01252  0.0217 ]
 [-0.01006 -0.00553  0.009   -0.02372  0.00812 -0.04891 -0.06205 -0.01728 -0.02563  0.02564]
 [ 0.00656  0.03307  0.01764  0.00273 -0.00087 -0.07939 -0.07847 -0.05009 -0.03572  0.00226]
 [-0.01152  0.00209  0.01268 -0.02766  0.02629 -0.08758 -0.02191 -0.01157 -0.03755  0.07643]
 [ 0.00849 -0.00357  0.04292  0.00966  0.01071 -0.06685 -0.01274 -0.03098 -0.04872  0.04148]
 [-0.02366  0.01482  0.0214  -0.0017   0.03714 -0.06655 -0.00035 -0.03575 -0.00123  0.06236]
 [-0.00574 -0.01814  0.01388 -0.01289  0.03698 -0.05857 -0.04548 -0.01367 -0.03049  0.05411]
 [-0.00067  0.03152  0.01473 -0.00112  0.04264 -0.07165 -0.08533 -0.07986 -0.03663  0.02476]
 [ 0.00238 -0.01577  0.03842 -0.01902  0.03321 -0.05336 -0.03538  0.00514 -0.03226  0.02748]
 [-0.05187 -0.02097  0.04671 -0.02635  0.03248 -0.0399  -0.06063 -0.01354 -0.03793  0.04167]
 [ 0.02789 -0.00258  0.01699 -0.01496  0.02993 -0.06447 -0.03116 -0.00742 -0.03525  0.01357]
 [ 0.03633  0.0008   0.01882 -0.00777  0.00483 -0.0582  -0.05376 -0.08621 -0.0305   0.0121 ]
 [ 0.00676  0.01654  0.03401  0.00623  0.035   -0.06879 -0.0183  -0.00864 -0.03068  0.02204]
 [-0.04109  0.01182  0.00634 -0.02956 -0.00965 -0.064   -0.05708 -0.01416 -0.01804  0.03834]
 [-0.00258  0.03007  0.04005 -0.00407 -0.04672 -0.05536 -0.07473 -0.01553 -0.01838 -0.00027]]

To ensure each row sums to 1, we can apply softmax to ypreds:

from jax.nn import softmax

yprobs = softmax(ypreds, axis=1)

print(f"yprobs.shape: {yprobs.shape}")
print(f"yprobs:\n{yprobs}")
yprobs.shape: (64, 10)
yprobs:
[[0.0985  0.10498 0.10199 0.10171 0.10054 0.09359 0.09674 0.1007  0.10169 0.09954]
 [0.10153 0.10251 0.10201 0.10051 0.09947 0.09582 0.09661 0.0992  0.09918 0.10317]
 [0.09949 0.10228 0.10416 0.09818 0.10186 0.0958  0.09698 0.09601 0.09701 0.10823]
 [0.10349 0.1011  0.10288 0.09988 0.10389 0.09426 0.09743 0.09923 0.09627 0.10155]
 [0.0982  0.09881 0.10485 0.1002  0.10269 0.09607 0.09701 0.10003 0.09746 0.10468]
 [0.09813 0.10323 0.10219 0.10038 0.1032  0.09376 0.09667 0.09929 0.09973 0.10341]
 [0.102   0.09924 0.10215 0.09776 0.10277 0.09683 0.09698 0.09787 0.09942 0.10497]
 [0.09886 0.10543 0.1056  0.10198 0.09981 0.09586 0.09316 0.09697 0.10163 0.1007 ]
 [0.10125 0.09994 0.10175 0.09904 0.10153 0.09641 0.09911 0.09926 0.09827 0.10345]
 [0.10024 0.09971 0.10547 0.10055 0.10371 0.09305 0.09684 0.09859 0.09765 0.10417]
 [0.09929 0.10294 0.10367 0.0995  0.10208 0.09659 0.09369 0.10039 0.09753 0.10432]
 [0.09923 0.09956 0.10181 0.10151 0.10209 0.09664 0.09973 0.09987 0.09638 0.10317]
 [0.09897 0.1057  0.1018  0.09911 0.10481 0.09478 0.09314 0.09854 0.09929 0.10387]
 [0.09734 0.10127 0.10345 0.09852 0.10367 0.09841 0.09629 0.1012  0.09755 0.10231]
 [0.10039 0.09973 0.10104 0.10081 0.10208 0.09492 0.09898 0.09882 0.09831 0.10491]
 [0.10013 0.09742 0.1051  0.10052 0.10369 0.09448 0.0991  0.09866 0.09563 0.10527]
 [0.10216 0.10169 0.10236 0.09969 0.104   0.0962  0.09698 0.09609 0.09944 0.1014 ]
 [0.10221 0.09962 0.10299 0.10232 0.1022  0.0942  0.09728 0.09745 0.09894 0.10281]
 [0.09965 0.10211 0.10196 0.09858 0.10239 0.09584 0.09681 0.10028 0.09811 0.10426]
 [0.10063 0.09919 0.10471 0.10123 0.10281 0.09323 0.09724 0.10094 0.09713 0.10289]
 [0.09956 0.10104 0.10313 0.1005  0.10467 0.09494 0.09697 0.09931 0.09572 0.10416]
 [0.10442 0.10142 0.10364 0.10052 0.09847 0.0957  0.09635 0.10021 0.09663 0.10264]
 [0.10003 0.09907 0.10271 0.10112 0.10051 0.09444 0.09805 0.10324 0.09801 0.10283]
 [0.10324 0.10058 0.10259 0.09975 0.10434 0.09407 0.09769 0.09949 0.09644 0.1018 ]
 [0.10052 0.10087 0.10219 0.1027  0.10217 0.09688 0.0962  0.1013  0.09765 0.09953]
 [0.10137 0.10317 0.10155 0.10033 0.10257 0.09684 0.09401 0.09633 0.09972 0.10413]
 [0.09848 0.09975 0.1035  0.10012 0.10324 0.09834 0.09708 0.09894 0.09704 0.10351]
 [0.09849 0.10601 0.10241 0.1002  0.10049 0.09388 0.09301 0.10049 0.10049 0.10454]
 [0.10039 0.10165 0.10144 0.09441 0.10519 0.09385 0.09911 0.09823 0.09847 0.10727]
 [0.10206 0.09837 0.10446 0.1     0.10405 0.09333 0.09992 0.09876 0.09567 0.10338]
 [0.10343 0.10502 0.10131 0.09804 0.10321 0.09429 0.09475 0.09473 0.09887 0.10635]
 [0.10078 0.10184 0.10196 0.10176 0.09872 0.09597 0.09834 0.09709 0.09954 0.104  ]
 [0.10258 0.10261 0.10058 0.0985  0.10127 0.09866 0.09746 0.09821 0.09824 0.10187]
 [0.10156 0.10081 0.10405 0.10098 0.10312 0.09209 0.09546 0.10084 0.0979  0.1032 ]
 [0.09728 0.10466 0.10131 0.09963 0.10146 0.09583 0.09403 0.10076 0.09846 0.1066 ]
 [0.10101 0.10073 0.10265 0.10083 0.10468 0.0943  0.09819 0.09963 0.09734 0.10064]
 [0.09864 0.10516 0.10099 0.09963 0.10077 0.09469 0.09586 0.10013 0.09997 0.10414]
 [0.0978  0.10555 0.10241 0.09794 0.10263 0.0953  0.09394 0.10298 0.09743 0.10402]
 [0.10273 0.09647 0.1029  0.09746 0.10312 0.09721 0.10147 0.0991  0.09829 0.10124]
 [0.09937 0.10258 0.09961 0.09749 0.10096 0.0993  0.09749 0.09968 0.09947 0.10404]
 [0.0995  0.09887 0.10216 0.10092 0.10206 0.09591 0.09857 0.09952 0.09797 0.10451]
 [0.10161 0.10193 0.10296 0.09831 0.10426 0.09497 0.09466 0.09873 0.09884 0.10373]
 [0.09937 0.09717 0.10244 0.10156 0.10219 0.09613 0.10072 0.09936 0.09818 0.10289]
 [0.09964 0.09802 0.1036  0.1021  0.10156 0.09594 0.09772 0.10049 0.0977  0.10323]
 [0.10047 0.09902 0.10392 0.10157 0.10421 0.09617 0.09827 0.09653 0.09791 0.10194]
 [0.09971 0.098   0.10302 0.09947 0.10476 0.09613 0.09825 0.09932 0.09743 0.10391]
 [0.10424 0.10007 0.10398 0.10173 0.10635 0.09568 0.0956  0.0946  0.09677 0.10098]
 [0.10023 0.10006 0.10294 0.10013 0.10398 0.09415 0.09748 0.10095 0.09815 0.10193]
 [0.09824 0.10248 0.10331 0.09891 0.10216 0.09684 0.09439 0.09893 0.09922 0.10551]
 [0.09769 0.10463 0.10356 0.10124 0.10132 0.09476 0.09396 0.10026 0.09955 0.10302]
 [0.10047 0.10092 0.1024  0.0991  0.10231 0.09664 0.09538 0.09974 0.09892 0.10412]
 [0.10244 0.10519 0.10358 0.10205 0.10168 0.094   0.09409 0.09679 0.0982  0.102  ]
 [0.09957 0.10093 0.10201 0.09797 0.10341 0.09228 0.09854 0.09956 0.09701 0.10872]
 [0.1013  0.10008 0.10484 0.10141 0.10152 0.09394 0.09917 0.09738 0.09566 0.10469]
 [0.09754 0.10137 0.10203 0.0997  0.10365 0.09344 0.09984 0.09637 0.09975 0.1063 ]
 [0.10017 0.09894 0.10216 0.09946 0.10454 0.09502 0.09627 0.09938 0.09772 0.10635]
 [0.10145 0.10477 0.10303 0.10141 0.10594 0.0945  0.09322 0.09373 0.09787 0.10407]
 [0.10069 0.09888 0.10438 0.09856 0.10384 0.09523 0.09696 0.10097 0.09726 0.10325]
 [0.09612 0.09914 0.10608 0.09861 0.10458 0.09728 0.09528 0.09988 0.09747 0.10555]
 [0.10348 0.10038 0.10236 0.09914 0.10369 0.09435 0.09755 0.09989 0.09715 0.10201]
 [0.10534 0.10166 0.10351 0.10079 0.10207 0.09584 0.09626 0.09319 0.09853 0.10282]
 [0.10069 0.10168 0.10347 0.10064 0.10358 0.09336 0.0982  0.09915 0.09699 0.10224]
 [0.09764 0.10295 0.10239 0.09878 0.10076 0.09543 0.0961  0.10031 0.09992 0.10572]
 [0.10117 0.10452 0.10557 0.10102 0.0968  0.09596 0.09412 0.09986 0.09958 0.1014 ]]
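
Although the network is untrained, we can already sketch how class predictions and batch accuracy would be derived from these probabilities by taking the argmax over the class axis; with randomly initialized weights, accuracy should hover around 10%:

# Predicted class = index of the largest probability in each row.
yhat = jnp.argmax(yprobs, axis=1)

print(f"yhat[:10]: {yhat[:10]}")
print(f"y[:10]   : {y[:10]}")

# Fraction of correct predictions on this batch (roughly 0.10 for untrained weights).
accuracy = jnp.mean(yhat == jnp.asarray(y))
print(f"batch accuracy: {float(accuracy):.4f}")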

At this point, the outputs are essentially meaningless and are close to uniformly distributed over classes. This is because we haven’t yet computed the gradient of the loss function with respect to each weight, which is what allows the network to adjust its weights and biases to minimize prediction errors. In the next post, we’ll implement backpropagation entirely in JAX and walk through how to construct the training and validation loops.
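
As a brief preview, and only as a sketch (the loss below is a simple cross-entropy written for illustration, not necessarily the formulation the next post will use), JAX’s grad transform can differentiate a scalar loss built on batch_forward with respect to params:

from jax import grad
from jax.nn import log_softmax

def loss_fn(params, X, y):
    # Mean negative log-likelihood of the true class under the predicted distribution.
    logits = batch_forward(params, X)
    logprobs = log_softmax(logits, axis=1)
    return -jnp.mean(logprobs[jnp.arange(y.shape[0]), y])

grads = grad(loss_fn)(params, X, jnp.asarray(y))

# Gradients mirror the structure of params: the gradient w.r.t. W1 has shape (784, 128).
print(grads[0][0].shape)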

To leave a comment for the author, please follow the link and comment on their blog: The Pleasure of Finding Things Out: A blog by James Triveri.
