How to Visualize PyTorch Neural Networks – 3 Examples in Python

This article was first published on Tag: python - Appsilon | Enterprise R Shiny Dashboards, and kindly contributed to python-bloggers.
Want to share your content on python-bloggers? click here.

If you truly want to wrap your head around a deep learning model, visualizing it might be a good idea. These networks typically have dozens of layers, and figuring out what's going on from the summary alone won't get you far. That's why today we'll show you 3 ways to visualize PyTorch neural networks.

We’ll first build a simple feed-forward neural network model for the well-known Iris dataset. You’ll see that visualizing models/model architectures isn’t complicated at all, and will take you only a couple of lines of code.



Getting Started with PyTorch: Let’s Build a Neural Network

Building a neural network model from scratch in PyTorch is easier than it sounds. Previous experience with the library is desirable, but not required – you’ll have no trouble following if you prefer some other deep learning package.

We’ll build a model around the Iris dataset for two reasons:

  1. No data preparation is needed – the dataset is simple to understand, clean, and ready for supervised machine learning classification.
  2. You don’t need a huge network to get accurate results – which makes visualizing the network easier.

The code snippet below imports all Python libraries we’ll need for now and loads in the dataset:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

# Load the Iris CSV (the path/URL is not given in the original post)
iris = pd.read_csv("")
iris.head()
```

Image 1 – Head of the Iris dataset

PyTorch can't work with pandas DataFrames directly, so we'll have to convert the dataset into tensors.

The features can be passed to the torch.tensor() function as a NumPy array (via .values), while the target variable needs to be encoded from strings to integers:

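```python
# Features: every column except the target, as 32-bit floats
X = torch.tensor(iris.drop("variety", axis=1).values, dtype=torch.float)
# Target: encode species names as integer class indices 0, 1, 2
y = torch.tensor(
    [0 if vty == "Setosa" else 1 if vty == "Versicolor" else 2 for vty in iris["variety"]],
    dtype=torch.long
)
```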


Image 2 – Contents of the feature and target tensors
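The nested conditional above works for three classes, but a lookup table is easier to read and fails loudly on unexpected labels instead of silently mapping them to class 2. A minimal sketch, assuming the third species is spelled "Virginica" in the CSV; the short `varieties` list is hypothetical sample data standing in for `iris["variety"]`:

```python
import torch

# Hypothetical stand-in for iris["variety"]; replace with the real column.
varieties = ["Setosa", "Versicolor", "Virginica", "Setosa"]

# Explicit label -> class index mapping (assumes these exact spellings).
# An unknown label raises KeyError instead of falling through to class 2.
label_map = {"Setosa": 0, "Versicolor": 1, "Virginica": 2}

# Integer class indices of dtype torch.long, as classification losses expect.
y = torch.tensor([label_map[v] for v in varieties], dtype=torch.long)
print(y)  # tensor([0, 1, 2, 0])
```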

And that's it. The dataset is ready to be passed into a PyTorch neural network model, so let's build one next. It will have an input layer going from 4 features to 16 nodes, one hidden layer with 16 nodes, and an output layer going from 16 nodes to 3 raw class scores (logits), one per species:

```python
class Net(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering any layers
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)  # raw logits, one per class


model = Net()
print(model)
```

Image 3 – Summary of a neural network model
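The layer sizes fully determine the parameter count: each nn.Linear layer holds in_features × out_features weights plus out_features biases, so this model has (4·16 + 16) + (16·16 + 16) + (16·3 + 3) = 80 + 272 + 51 = 403 trainable parameters. The sketch below re-creates the same Net to check that number and to show that the output layer returns unnormalized scores, which F.softmax converts into class probabilities. The random `X_demo` tensor is a hypothetical stand-in for the real feature tensor X.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=4, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=3)

    def forward(self, x):
        x = F.relu(self.input(x))
        x = F.relu(self.hidden_1(x))
        return self.output(x)


model = Net()

# Per-layer parameter shapes, then the total number of trainable values.
for name, p in model.named_parameters():
    print(f"{name:16s} {tuple(p.shape)}")
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total:", n_params)  # Total: 403

# Hypothetical batch with the shape of the Iris features: 150 rows, 4 columns.
X_demo = torch.randn(150, 4)
logits = model(X_demo)            # shape (150, 3), unnormalized scores
probs = F.softmax(logits, dim=1)  # each row now sums to 1
print(logits.shape, probs.sum(dim=1)[:3])
```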

It's easy to look at the summary of this model since there are only a couple of layers, but imagine you had a deep network with dozens of layers: all of a sudden, the summary would be too large to fit on the screen.

In the following section, we’ll explore the first way to visualize PyTorch neural networks, and that is with the Torchviz library.

Torchviz: Visualize PyTorch Neural Networks With a Single Function Call

Torchviz is a Python package used to create visualizations of PyTorch execution graphs and traces. It depends on Graphviz, which you'll have to install system-wide (the Mac example below uses Homebrew). Once Graphviz is installed, you can install Torchviz with pip:

```shell
brew install graphviz
pip install torchviz
```

To use Torchviz in Python, import the make_dot() function, create an instance of your neural network class, and run a forward pass on the entire training set or on a batch of samples. Since the Iris dataset is small, we'll compute outputs for all flower instances:

```python
from torchviz import make_dot

model = Net()
y = model(X)  # forward pass over all samples (reuses the name y)
```

That's all you need to visualize the network. Simply pass a scalar computed from the output tensor (here, its mean) alongside the model's named parameters to the make_dot() function, which uses the names to label the parameter nodes:

```python
make_dot(y.mean(), params=dict(model.named_parameters()))
```

Image 4 – Visualizing model with torchviz (1)
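The graph make_dot() draws is the autograd graph PyTorch records during the forward pass: each node is a backward function reachable from the output's grad_fn through its next_functions. The sketch below walks that graph with plain PyTorch to list the recorded operations, which can help when Graphviz isn't available. It is a simplified illustration of the idea, not Torchviz's own code, and the nn.Sequential model is a hypothetical stand-in with the same layer sizes as Net.

```python
from collections import deque

import torch
import torch.nn as nn

# Hypothetical stand-in with the same layer sizes as Net.
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 3),
)
out = model(torch.randn(8, 4)).mean()  # scalar root, like y.mean() above


def autograd_nodes(root):
    """Breadth-first walk over the autograd graph, returning node type names."""
    seen, order, queue = set(), [], deque([root.grad_fn])
    while queue:
        fn = queue.popleft()
        if fn is None or fn in seen:  # None marks inputs that need no gradient
            continue
        seen.add(fn)
        order.append(type(fn).__name__)
        # next_functions holds (node, input_index) pairs for this op's inputs.
        queue.extend(next_fn for next_fn, _ in fn.next_functions)
    return order


names = autograd_nodes(out)
print(names)
```

Leaf parameters show up as AccumulateGrad nodes; those are the nodes make_dot() labels with the names from the params dictionary.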

You can also see what autograd saves for the backward pass by specifying two additional parameters: show_attrs=True and show_saved=True:

```python
make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
```

Image 5 – Visualizing model with torchviz (2)

The graph is more detailed, but that may be exactly what you're after.

Next, we'll explore a desktop app used to visualize any ONNX model.

Netron: Desktop App for Visualizing ONNX Models

Netron is a desktop and web application for visualizing neural network models from different libraries, including PyTorch. It works best if you export the model to the ONNX (Open Neural Network Exchange) format, which takes a single function call in PyTorch.

You can download the standalone desktop application or use the web version linked in the documentation. There are also Python server options, but we haven't explored them.

To get started, specify names for the inputs and outputs as lists of strings. Feel free to name these however you want. Once done, call the torch.onnx.export() function to export the model to a file:

```python
input_names = ["Iris"]
output_names = ["Iris Species Prediction"]

# X serves as the example input that the exporter runs through the model
torch.onnx.export(model, X, "model.onnx", input_names=input_names, output_names=output_names)
```

The model is now saved to the model.onnx file, which you can load into Netron. Here's what it looks like:


Image 6 – Visualizing model with Netron

Let's explore another way to visualize PyTorch neural networks which TensorFlow users will find familiar.

TensorBoard: Visualize Machine Learning Workflows and Graphs

TensorBoard is a visualization toolkit for machine learning experimentation. It has many features useful to deep learning researchers and practitioners, one of them being visualizing the model graph.

That’s exactly the feature we’ll explore today. But first, make sure to install TensorBoard through pip:

```shell
pip install tensorboard
```

So, how can you connect the PyTorch model with TensorBoard? You’ll need to take advantage of the SummaryWriter class from PyTorch, and add a network graph to a log directory. In our example, the logs will be saved to the torchlogs/ folder:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("torchlogs/")
model = Net()
writer.add_graph(model, X)
writer.close()  # flush pending events to disk
```
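SummaryWriter buffers events and writes them to disk periodically, so it's worth closing it once the graph is logged. The class also works as a context manager that closes the writer on exit. A minimal sketch, assuming one subdirectory per run; the torchlogs/iris_net path, the stand-in model, and the random input are illustrative choices, not part of the original setup:

```python
import os

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Hypothetical stand-in with the same layer sizes as Net.
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 3),
)
X_demo = torch.randn(150, 4)  # same shape as the Iris feature tensor

log_dir = os.path.join("torchlogs", "iris_net")
# Leaving the with-block calls writer.close(), flushing the graph to disk.
with SummaryWriter(log_dir) as writer:
    writer.add_graph(model, X_demo)

print(sorted(os.listdir(log_dir)))
```

Pointing tensorboard --logdir at torchlogs/ then lists each subdirectory as a separate run.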

Once the network graph is saved, navigate to the log directory from the shell and launch TensorBoard:

```shell
cd <path-to-logs-dir>
tensorboard --logdir=./
```

Image 7 – Starting TensorBoard from the shell

You’ll be able to see the model graph on http://localhost:6006. You can click on any graph element and TensorBoard will expand it for you, as shown in the figure below:


Image 8 – Visualizing model with TensorBoard

And that’s it for the ways to visualize PyTorch neural networks. Let’s make a short recap next.

Summing up How to Visualize PyTorch Neural Networks

If you want to understand what's going on in a neural network model, visualizing the network graph is the way to go. Sure, you need to actually understand why the network is constructed the way it is, but that's fundamental deep learning knowledge we assume you have.


We've explored three ways to visualize neural network models from PyTorch: with Torchviz, Netron, and TensorBoard. All three are excellent, and there's no clear winner. Let us know which one you prefer.

Do you use some other tool to visualize neural network model graphs? Please let us know in the comment section below. Also, don’t hesitate to move the discussion to Twitter – @appsilon. We’d love to hear from you.


The post How to Visualize PyTorch Neural Networks – 3 Examples in Python appeared first on Appsilon | Enterprise R Shiny Dashboards.

