Interpreting Tensorflow models with tf-explain

by Gilbert Tanner on Oct 14, 2019 · 4 min read

Understanding your machine learning and deep learning models is crucial if you want to use them in production, which is why I have already written a series of articles on the topic. In the first article, I went over what model interpretation is; in the other two, I went deeper into local and global model interpretation.

In this article, I will show you how to interpret your Tensorflow CNN models with tf-explain. For more information on the library, check out the official documentation. All the code in this article is taken from the examples provided with the library.

tf-explain implements interpretability methods for Tensorflow 1.x and 2.x. It offers two APIs: the Core API, which lets you interpret a model after it has been trained, and the Callback API, which lets you monitor the model with callbacks while it trains.

Available Methods

As of now, tf-explain offers 6 different methods for interpreting neural networks:

  1. Activations Visualization
  2. Vanilla Gradients
  3. Occlusion Sensitivity
  4. Grad CAM (Class Activation Maps)
  5. SmoothGrad
  6. Integrated Gradients
Figure 1: From Left to Right: Input Image, Activations Visualizations, Occlusion Sensitivity, Grad CAM, SmoothGrad on VGG16

Each of these methods can be used either through the Core API or as a callback.
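Occlusion Sensitivity is the easiest of these methods to reproduce by hand, so it is a good way to see what an "explanation" actually is. The following is a minimal NumPy sketch, not tf-explain's implementation: it assumes a hypothetical `predict` function that maps one grayscale image of shape (H, W) to a vector of class scores, hides one square patch at a time behind a constant value, and records how much the score of `class_index` drops.

```python
import numpy as np


def occlusion_sensitivity(predict, image, class_index, patch_size=4, fill_value=0.5):
    """Map each pixel to the score drop caused by hiding the patch that contains it.

    predict: hypothetical callable, (H, W) array -> 1-D array of class scores.
    """
    height, width = image.shape
    baseline_score = predict(image)[class_index]
    sensitivity = np.zeros((height, width), dtype=float)

    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            occluded = image.astype(float)  # astype returns a fresh copy
            # Hide one patch (slicing clips it at the image border)
            occluded[top:top + patch_size, left:left + patch_size] = fill_value
            drop = baseline_score - predict(occluded)[class_index]
            sensitivity[top:top + patch_size, left:left + patch_size] = drop

    return sensitivity


# Toy "model": the score of class 0 is the mean brightness of the top-left 4x4 corner
def toy_predict(img):
    corner = img[:4, :4].mean()
    return np.array([corner, 1.0 - corner])


heatmap = occlusion_sensitivity(toy_predict, np.ones((8, 8)), class_index=0)
print(heatmap[0, 0], heatmap[7, 7])
```

Only the top-left patch matters to the toy model, so only that region of the heatmap is non-zero. tf-explain's Occlusion Sensitivity method follows the same principle on a real Keras model.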

Core API Example

The Core API is located under tf_explain.core. Every explainer implemented in tf-explain exposes the same two methods:

  • an explain method, which outputs the explanation (for instance, a heatmap)
  • a save method compatible with its output

All of the methods follow the same pattern, shown here with Grad CAM:

# Import explainer
from tf_explain.core.grad_cam import GradCAM

# Instantiation of the explainer
explainer = GradCAM()

# Call to explain() method
output = explainer.explain(<data>, <model>, <layer_name>, <class_index>, <colormap(optional)>)

# Save output
explainer.save(output, output_dir, output_name)
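The `<layer_name>` argument matters because Grad CAM works on the activations of that convolutional layer. In the original Grad-CAM formulation, the gradients of the class score with respect to each channel of the layer are averaged into one weight per channel, and the heatmap is the ReLU of the weighted sum of the feature maps. The NumPy sketch below shows only that final combination step, assuming you have already extracted `activations` and `gradients` of shape (H, W, K) for one image; tf-explain's own implementation may differ in details such as guided gradients or normalization.

```python
import numpy as np


def grad_cam_heatmap(activations, gradients):
    """Combine a conv layer's activations and gradients into a Grad-CAM heatmap.

    activations, gradients: arrays of shape (H, W, K) for one image and one class.
    Returns an (H, W) heatmap scaled to [0, 1].
    """
    # One importance weight per channel: global-average-pool the gradients
    weights = gradients.mean(axis=(0, 1))                       # shape (K,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(activations, weights, axes=([2], [0]))  # shape (H, W)
    # Keep only regions with a positive influence on the class score
    cam = np.maximum(cam, 0.0)
    # Normalize for display; an all-zero map stays all zero
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Two 2x2 feature maps: channel 0 pushes the class score up, channel 1 pushes it down
acts = np.zeros((2, 2, 2))
acts[0, 0, 0] = 1.0
acts[1, 1, 1] = 1.0
grads = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)
print(grad_cam_heatmap(acts, grads))
```

Only the top-left cell, where the positively weighted channel fires, survives the ReLU.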

You can use the Core API to interpret a trained model. The following code uses SmoothGrad to visualize the gradients of a pretrained VGG16 model when it is given an image of a cat.

import tensorflow as tf

from tf_explain.core.smoothgrad import SmoothGrad

IMAGE_PATH = './cat.jpg'

if __name__ == '__main__':
    model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)

    img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
    img = tf.keras.preprocessing.image.img_to_array(img)

    data = ([img], None)

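    tabby_cat_class_index = 281  # ImageNet class index of "tabby cat"
    explainer = SmoothGrad()
    # Compute SmoothGrad on VGG16
    # Arguments: data, model, class index, number of noisy samples, noise level
    grid = explainer.explain(data, model, tabby_cat_class_index, 20, 1.)
    # Save the result as ./smoothgrad.png
    explainer.save(grid, '.', 'smoothgrad.png')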
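In the call above, `20` is the number of noisy samples and `1.` is the noise level. SmoothGrad averages the gradient of the class score over several noisy copies of the input, which reduces the speckle of a plain (vanilla) gradient map. Here is a minimal NumPy sketch of that averaging, assuming a hypothetical `grad_fn` that returns the input gradient; tf-explain computes that gradient with TensorFlow on the model itself.

```python
import numpy as np


def smoothgrad(grad_fn, image, num_samples=20, noise=1.0, seed=0):
    """Average the input gradient over `num_samples` noisy copies of `image`.

    grad_fn: hypothetical callable, input array -> d(class score)/d(input), same shape.
    noise: standard deviation of the Gaussian noise added to each copy.
    """
    rng = np.random.default_rng(seed)  # seeded so the result is reproducible
    total = np.zeros_like(image, dtype=float)
    for _ in range(num_samples):
        noisy_image = image + rng.normal(0.0, noise, size=image.shape)
        total += grad_fn(noisy_image)
    return total / num_samples


image = np.array([1.0, -2.0, 0.5])
# Toy class score f(x) = sum(x ** 2), whose vanilla gradient is 2 * x
smoothed = smoothgrad(lambda x: 2 * x, image, num_samples=20, noise=0.1)
print(smoothed)  # close to the vanilla gradient [2, -4, 1]
```

For a toy score this smooth, the average barely moves away from the vanilla gradient; on a real network, whose gradients fluctuate sharply between nearby inputs, the averaging is what removes the visual noise.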

Callbacks Example

With the Callback API, you can monitor your model during training. Observing how the model behaves while it trains, and judging whether that behavior is reasonable, can save you hours spent training a misbehaving model.
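Keras calls each callback's `on_epoch_end(epoch, logs)` hook after every epoch, and a monitoring callback can use that hook to recompute an explanation for a fixed validation batch. The pure-Python sketch below imitates that pattern without TensorFlow; `ExplanationRecorder`, `train` and the toy model are hypothetical stand-ins, and tf-explain's real callbacks write their output for TensorBoard rather than to a list.

```python
class ExplanationRecorder:
    """Hypothetical monitoring callback that explains a fixed batch after each epoch."""

    def __init__(self, validation_inputs, explain_fn):
        self.validation_inputs = validation_inputs
        self.explain_fn = explain_fn  # (model, inputs) -> explanation
        self.model = None             # attached by the training loop
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        # Same hook name and signature as tf.keras.callbacks.Callback.on_epoch_end
        explanation = self.explain_fn(self.model, self.validation_inputs)
        self.history.append((epoch, explanation))


def train(model, epochs, callbacks):
    """Toy training loop: a simplified version of how Keras' fit() drives callbacks."""
    for callback in callbacks:
        callback.model = model  # Keras does this through Callback.set_model()
    for epoch in range(epochs):
        model["weight"] += 1.0  # stand-in for one epoch of weight updates
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs={})


# Toy linear "model" whose explanation is simply weight * input
model = {"weight": 0.0}
recorder = ExplanationRecorder([1.0, 2.0], lambda m, xs: [m["weight"] * x for x in xs])
train(model, epochs=3, callbacks=[recorder])
for epoch, explanation in recorder.history:
    print(epoch, explanation)
```

Because the explanation is recomputed after every epoch, you can watch it evolve as the weights change.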

tf-explain's callbacks can be used like any other Keras callback. The following code trains a small CNN on Fashion-MNIST with several of the callbacks available in tf-explain.

import numpy as np
import tensorflow as tf
import tf_explain

INPUT_SHAPE = (28, 28, 1)
NUM_CLASSES = 10

AVAILABLE_DATASETS = {
    'mnist': tf.keras.datasets.mnist,
    'fashion_mnist': tf.keras.datasets.fashion_mnist,
}
DATASET_NAME = 'fashion_mnist'  # Choose between "mnist" and "fashion_mnist"

# Load dataset
dataset = AVAILABLE_DATASETS[DATASET_NAME]
(train_images, train_labels), (test_images, test_labels) = dataset.load_data()

# Convert from (28, 28) images to (28, 28, 1)
train_images = train_images[..., tf.newaxis].astype('float32')
test_images = test_images[..., tf.newaxis].astype('float32')

# One-hot encode labels 0, 1, .., 9 into vectors such as [0, 0, .., 1, 0, 0]
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

# Create model
img_input = tf.keras.Input(INPUT_SHAPE)

x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(img_input)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer')(x)
x = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(x)

x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)

x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)

x = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)

model = tf.keras.Model(img_input, x)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Select subsets of the validation data to examine: the first 5 test images of class 0 and of class 1
validation_class_zero = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.all(np.argmax(label) == 0)
][0:5]), None)

validation_class_one = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.all(np.argmax(label) == 1)
][0:5]), None)

# Instantiate callbacks
# class_index value should match the validation_data selected above
callbacks = [
    tf_explain.callbacks.GradCAMCallback(validation_class_zero, layer_name='target_layer', class_index=0),
    tf_explain.callbacks.GradCAMCallback(validation_class_one, layer_name='target_layer', class_index=1),
    tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
    tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
    tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
    tf_explain.callbacks.VanillaGradientsCallback(validation_class_zero, class_index=0),
]

# Start training
model.fit(train_images, train_labels, epochs=5, callbacks=callbacks)
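The `n_steps=10` argument of `IntegratedGradientsCallback` sets how many points are used to approximate an integral. Integrated Gradients averages the gradient of the class score along a straight line from a baseline (commonly an all-black image) to the input, then multiplies by the input-minus-baseline difference. Here is a minimal NumPy sketch using a midpoint Riemann sum, assuming a hypothetical `grad_fn` and an all-zero default baseline; tf-explain's own interpolation scheme may differ.

```python
import numpy as np


def integrated_gradients(grad_fn, image, baseline=None, n_steps=10):
    """Approximate Integrated Gradients with a midpoint Riemann sum of `n_steps` terms.

    grad_fn: hypothetical callable, input array -> d(class score)/d(input).
    baseline: reference input; defaults to all zeros (e.g. a black image).
    """
    if baseline is None:
        baseline = np.zeros_like(image, dtype=float)
    # Midpoints of n_steps equal sub-intervals of [0, 1]
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    avg_grad = np.zeros_like(image, dtype=float)
    for alpha in alphas:
        avg_grad += grad_fn(baseline + alpha * (image - baseline))
    avg_grad /= n_steps
    # Scale the averaged gradient by how far each feature moved from the baseline
    return (image - baseline) * avg_grad


x = np.array([1.0, -2.0, 3.0])
# Toy class score f(x) = sum(x ** 2), whose gradient is 2 * x
attributions = integrated_gradients(lambda v: 2 * v, x, n_steps=10)
print(attributions, attributions.sum())
```

Summing the attributions gives f(x) - f(baseline), here 14 - 0; this completeness property is what makes Integrated Gradients attractive. With a real network the sum only approximates that difference, and more steps tighten the approximation.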

In this article, I went over the basics of tf-explain, a simple yet capable library for interpreting your Tensorflow models. For more information about the library, check out the official documentation.

That’s all from this article. If you have any questions or just want to chat with me, feel free to leave a comment below or contact me on social media. If you want to get continuous updates about my blog, make sure to follow me on Medium and join my newsletter.
