PyTorch + SHAP = Explainable Convolutional Neural Networks

[This article was first published on python – Better Data Science, and kindly contributed to python-bloggers].
Want to share your content on python-bloggers? click here.

Black-box models are a thing of the past – even with deep learning. You can use SHAP to interpret the predictions of deep learning models, and it requires only a couple of lines of code. Today you’ll learn how, using the well-known MNIST dataset.

Convolutional neural networks can be tough to understand. A network learns the optimal feature extractors (kernels) from the images. The extracted features capture patterns that help the network classify images correctly.

Your brain isn’t that much different. It also uses a series of patterns to recognize objects in front of you. For example, what makes a zero a zero? It’s a round or oval outline with nothing inside. That’s the kind of general pattern the kernels in convolutional layers try to learn.
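To make the idea of a kernel concrete, here is a small illustrative sketch (not from the original notebook): a hand-written 3x3 vertical-edge kernel slid over a tiny synthetic image with torch.nn.functional.conv2d. A trained CNN learns the numbers in its kernels from data; the fixed values below are an assumption chosen only for readability.

```python
import torch
import torch.nn.functional as F

# A tiny 5x5 "image": a bright vertical stroke in column 2 on a dark background.
image = torch.zeros(1, 1, 5, 5)  # (batch, channels, height, width)
image[0, 0, :, 2] = 1.0

# Hand-written Sobel-like vertical-edge kernel, shaped (out_ch, in_ch, kH, kW).
# A convolutional layer learns kernels like this instead of having them given.
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-2.0, 0.0, 2.0],
                       [-1.0, 0.0, 1.0]]).view(1, 1, 3, 3)

# Slide the kernel over the image (stride 1, no padding) -> 3x3 feature map.
feature_map = F.conv2d(image, kernel)
print(feature_map[0, 0])  # every row is [4, 0, -4]
```

The kernel responds with +4 on the dark-to-bright edge left of the stroke, -4 on the bright-to-dark edge right of it, and 0 when it is centered on the stroke itself.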

If you want to represent your model’s interpretations visually, look no further than SHAP (SHapley Additive exPlanations), a game-theoretic approach to explaining the output of any machine learning model. You can refer to this article for a complete beginner’s guide.
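SHAP builds on Shapley values from cooperative game theory: a feature’s value is its marginal contribution to the model output, averaged over every coalition of the other features with specific weights. The brute-force sketch below is for illustration only. It assumes the common convention that a feature “absent” from a coalition is set to a baseline value, and toy_model is a hypothetical function. Because it enumerates all 2^n coalitions, it is only practical for a handful of features, which is why SHAP’s explainers for deep networks rely on approximations instead.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x.

    A feature outside a coalition takes its baseline value (an assumed
    convention). Cost grows as 2**n, so keep the number of features small.
    """
    n = len(x)

    def value(coalition):
        # Model output when only the features in `coalition` take their real values.
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):  # coalition sizes 0 .. n-1
            for subset in combinations(others, size):
                s = set(subset)
                # Shapley weight: |S|! * (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_model(z):
    # Hypothetical model: feature 0 acts alone, features 1 and 2 interact.
    return 3 * z[0] + 2 * z[1] * z[2]


x, baseline = [1, 1, 1], [0, 0, 0]
phi = shapley_values(toy_model, x, baseline)
print([round(p, 6) for p in phi])  # [3.0, 1.0, 1.0]
# Efficiency: contributions add up to toy_model(x) - toy_model(baseline) = 5.
print(round(sum(phi), 6))          # 5.0
```

The interaction term 2 * z[1] * z[2] is split evenly between features 1 and 2, and the contributions sum exactly to the difference between the prediction and the baseline output. That additivity is the “Additive” in SHAP.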

The article is structured as follows:

  • Defining the model architecture
  • Training the model
  • Interpreting the model
  • Conclusion

You can download the corresponding Notebook here.

Defining the model architecture

You’ll use PyTorch to train a simple handwritten digit classifier. It’s a go-to Python library for deep learning, both in research and in business. If you haven’t used PyTorch before but have some Python experience, it will feel natural.

Before defining the model architecture, you’ll have to import a couple of libraries. Most of them are PyTorch modules; numpy and shap are used later for the explanations.

The model architecture is simple and borrowed from the official documentation. Feel free to declare your own architecture, but this one is good enough for our needs.
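As a sketch of what such an architecture can look like, here is a small two-convolution network for 1x28x28 grayscale digits. The layer sizes are assumptions, not necessarily those of the original notebook. Two choices are deliberate. Every layer, including each ReLU, is its own nn.Module instance rather than a functional call, because SHAP’s PyTorch DeepExplainer works by hooking modules, so functional calls such as F.relu may not receive its special handling. The last layer is a Softmax, so the model outputs class probabilities.

```python
import torch
from torch import nn


class Net(nn.Module):
    """Small CNN for 1x28x28 digit images that outputs 10 class probabilities."""

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),   # 1x28x28 -> 10x24x24
            nn.MaxPool2d(2),                   # -> 10x12x12
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),  # -> 20x8x8
            nn.Dropout(),
            nn.MaxPool2d(2),                   # -> 20x4x4
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(20 * 4 * 4, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
            nn.Softmax(dim=1),                 # probabilities over the 10 digits
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 320)
        return self.fc_layers(x)
```

Because the model ends in Softmax, training should apply the negative log-likelihood loss to the log of its output, not cross-entropy (which expects raw scores) to the output directly.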

The following section shows you how to train the model.

Training the model

Let’s start by declaring a couple of variables:

  • batch_size – how many images are shown to the model at once
  • num_epochs – number of complete passes through the training dataset
  • device – specifies whether training runs on the CPU or a GPU. Replace cuda:0 with cpu if you don’t have a CUDA-compatible GPU
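The three variables might be declared as follows. The batch size and epoch count are illustrative assumptions rather than the article’s exact values, and instead of editing cuda:0 by hand, the sketch falls back to the CPU automatically when torch.cuda.is_available() returns False.

```python
import torch

batch_size = 128  # images per mini-batch (illustrative value)
num_epochs = 2    # complete passes over the training set (illustrative value)

# Use the first CUDA GPU if one is available; otherwise train on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"batch_size={batch_size}, num_epochs={num_epochs}, device={device}")
```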

Next, you’ll declare two functions, train() and test(). The first trains the model on the training set, the second evaluates it on the test set, and both print intermediate results.

Next, download the MNIST training and test sets with the torchvision.datasets module. The images are transformed (converted to tensors and normalized) and organized into batches by data loaders.

Now you have everything ready for model training. The next step is to instantiate the model, move it to the selected device, and train it for the previously declared number of epochs.

You’ll see the intermediate results printed out during the training phase. Here’s how they look on my machine:

Image 1 - Model training with PyTorch - intermediate results (image by author)

Keep in mind – the actual values may differ slightly on your machine, but you should land north of 95% accuracy on the test set.

Next step – interpretations with SHAP!

Interpreting the model

Prediction interpretation is now as simple as writing a couple of lines of code. The explanation code loads a batch of random images from the test set and interprets the model’s predictions for five of them.

After running the explanation code, you’ll see the following image:

Image 2 - SHAP explanations for handwritten digits classifier (image by author)

Input images are displayed in the leftmost column, and the explanations for each class (digits 0 through 9) in the columns to the right. Pixels colored red increase the model’s output for that class, pushing the prediction toward it, while pixels colored blue decrease it.
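The color coding follows from SHAP’s additivity property: for each class, an image’s SHAP values sum to the model’s output for that image minus its average output over a background (reference) sample of inputs. The sketch below checks this on a hypothetical four-pixel linear model, where, assuming independent features, the exact SHAP value of pixel i is w_i * (x_i - E[x_i]). DeepExplainer only approximates such values for deep networks. Positive values correspond to red pixels and negative ones to blue.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear "model" over 4 pixels: f(x) = w . x + b.
w = np.array([2.0, -1.0, 0.5, 0.0])
b = 0.1


def f(X):
    return X @ w + b


background = rng.random((100, 4))   # stand-in for background images
x = np.array([1.0, 1.0, 0.0, 0.7])  # the input being explained

# Exact SHAP values for a linear model with independent features.
phi = w * (x - background.mean(axis=0))

# Additivity: contributions sum to f(x) minus the mean background output.
print(phi.sum(), f(x) - f(background).mean())
# Positive values would be drawn red, negative ones blue; pixel 3 has w = 0.
print("red:", np.flatnonzero(phi > 0), "blue:", np.flatnonzero(phi < 0))
```

Pixel 0 is brighter than average and has a positive weight, so it pushes the output up (red); pixel 1 is also bright but its weight is negative, and pixel 2 is darker than average with a positive weight, so both push the output down (blue).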

That’s how SHAP explanations work with convolutional neural networks. Let’s wrap things up in the next section.

Conclusion

Today you’ve learned how to create a basic convolutional neural network model for classifying handwritten digits with PyTorch. You’ve also learned how to explain the predictions made by the model. 

Bringing this interpretation skillset to your domain is now as simple as changing the dataset and model architecture. The explanation code should stay the same or need only minimal changes to accommodate different datasets.

Thanks for reading.

The post PyTorch + SHAP = Explainable Convolutional Neural Networks appeared first on Better Data Science.

To leave a comment for the author, please follow the link and comment on their blog: python – Better Data Science.
