Evolution of Grad-CAM heat-maps along a ResNet-34

This exercise is a continuation of my last post, which explored generating class-discriminative localization maps for a convnet. In particular, I used the feature-map activations of the last convolutional layer (after BatchNorm), along with the gradients of a specific class score with respect to these activations, to create heat-maps that help visualize the parts of the input image that contribute most to the prediction.

I wanted to extend that approach to see how these heat-maps take shape as we move deeper into the network, starting with the very first convolutional layer. As in the last post, the inspiration comes from a fastai Deep Learning MOOC lecture, which is itself inspired by Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization by Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra.

Table of Contents

Setup

I'll breeze through the setup (dataset, model training) since it's the same as last time. The network architecture is based on ResNet-34, with 37 output classes.

%reload_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
from fastai import *
from fastai.vision import *
from fastai import version as fastai_version
# Record which fastai version this notebook ran against (the API used below is fastai v1).
print(f'fastai version -> {fastai_version.__version__}')
fastai version -> 1.0.32
bs = 32  #batch size
# Download (if needed) and extract the PETS dataset; the images live in its 'images' folder.
path = untar_data(URLs.PETS)/'images'
# Augmentation: rotation, zoom, lighting and warp; the affine and lighting transforms
# are each applied with probability 0.2.
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
                      p_affine=.2, p_lighting=.2)
# Hold out a random 20% of the files for validation; the fixed seed keeps the split
# reproducible so the saved weights are evaluated on the same validation set.
src = ImageItemList.from_folder(path).random_split_by_pct(0.2, seed=2)
# Regex that extracts the breed label from a pets file name,
# e.g. '.../great_pyrenees_12.jpg' -> 'great_pyrenees'. The dot before 'jpg' is
# escaped so it only matches a literal '.', not any character.
PET_LABEL_RE = r'([^/]+)_\d+\.jpg$'

def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized DataBunch from the global train/valid split ``src``.

    Args:
        size: target image size passed to the transforms.
        bs: batch size.
        padding_mode: padding used by the transforms (default 'reflection').

    Returns:
        A DataBunch labelled from file names via ``PET_LABEL_RE``, transformed with
        the global ``tfms`` and normalized with ImageNet statistics.
    """
    return (src.label_from_re(PET_LABEL_RE)
           .transform(tfms, size=size, padding_mode=padding_mode)
           .databunch(bs=bs).normalize(imagenet_stats))
# Build a 224px DataBunch to preview a few augmented training images.
data = get_data(224, bs)
data.show_batch(rows=3, figsize=(6,6))

img 16

# Rebuild the data at 352px with batch size 16 (presumably the resolution the '352'
# weights were trained at — confirm), then create a ResNet-34 learner and load them.
data = get_data(352,16)
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')

To keep things simple, I'm loading the weights trained in the last post. I'll index into learn.model to get the building blocks of the model.

# First module of the model body: the 7x7 stem convolution.
learn.model[0][0]
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# First residual block of ResNet "layer 1" (the layer groups start at index 4 of the body).
learn.model[0][4][0]
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

Forward pass heat-maps along the network

Let's work with this image of a miniature_pinscher.

# Pick validation image #4 and show it together with its label and class index.
idx=4
x,y = data.valid_ds[idx]
x.show()
print(f'class name: {y}\nclass_index: {y.data}')
class name: miniature_pinscher
class_index: 26

img 15

from fastai.callbacks.hooks import *
# Switch the model to inference mode; nn.Module.eval() returns the module itself,
# so `m` is the same object as `learn.model`.
m = learn.model.eval();
xb,_ = data.one_item(x)    #get tensor from Image x
# Denormalized copy of the input, kept as an Image for display under the heat-maps.
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()

non_class_discriminative_activations_multi returns a list of hooks holding the forward-pass activations at different points in the network, along with the corresponding layer names.

def non_class_discriminative_activations_multi(xb):
    """Run ``xb`` through the model and capture forward activations at several depths.

    Forward hooks go on ``m[0][3]`` (labelled 'first conv') and on every block inside
    ``m[0][4]`` .. ``m[0][7]``, i.e. the four ResNet "layers".

    Args:
        xb: input batch, already on the model's device.

    Returns:
        ``(hooks, layer_names)``: the hooks (already removed from the model, but their
        ``.stored`` still holds the captured activations) and matching labels.
    """
    hooks = []
    layer_names = []
    try:
        # NOTE(review): m[0][3] is the 4th module of the body and its output is 1/4 of
        # the input resolution, so this looks like the stem max-pool rather than the
        # conv itself — confirm before relying on the 'first conv' label.
        hooks.append(hook_output(m[0][3]))
        layer_names.append('first conv')

        # ResNet layer k (a group of residual blocks) is m[0][k+3]; each hook below
        # sits on one residual block, labelled conv-(block index).
        for ind in [4, 5, 6, 7]:
            for i, el in enumerate(m[0][ind]):
                hooks.append(hook_output(el))
                layer_names.append(f'layer-{ind-3} - conv-{i+1}')

        m(xb)  # only the hooks' side effect (storing activations) is needed
    finally:
        # Always detach the hooks, even on failure, so they don't fire on later passes.
        for hook in hooks:
            hook.remove()

    return hooks, layer_names
hooks,layer_names = non_class_discriminative_activations_multi(xb)

Printing the shapes of these activation tensors. Following PyTorch's implementation, I call a group of ResNet blocks a "layer". ResNet-34 has 4 such layers, with [3, 4, 6, 3] ResNet blocks respectively (the "conv-i" labels below refer to the i-th block of a layer).

# One line per hooked module: its label left-aligned in an 18-character column,
# followed by the shape of the activation stored for the single image in the batch.
for name, hook in zip(layer_names, hooks):
    print(f'{name:<18} -->   {hook.stored[0].shape}')
first conv         -->   torch.Size([64, 88, 88])
layer-1 - conv-1   -->   torch.Size([64, 88, 88])
layer-1 - conv-2   -->   torch.Size([64, 88, 88])
layer-1 - conv-3   -->   torch.Size([64, 88, 88])
layer-2 - conv-1   -->   torch.Size([128, 44, 44])
layer-2 - conv-2   -->   torch.Size([128, 44, 44])
layer-2 - conv-3   -->   torch.Size([128, 44, 44])
layer-2 - conv-4   -->   torch.Size([128, 44, 44])
layer-3 - conv-1   -->   torch.Size([256, 22, 22])
layer-3 - conv-2   -->   torch.Size([256, 22, 22])
layer-3 - conv-3   -->   torch.Size([256, 22, 22])
layer-3 - conv-4   -->   torch.Size([256, 22, 22])
layer-3 - conv-5   -->   torch.Size([256, 22, 22])
layer-3 - conv-6   -->   torch.Size([256, 22, 22])
layer-4 - conv-1   -->   torch.Size([512, 11, 11])
layer-4 - conv-2   -->   torch.Size([512, 11, 11])
layer-4 - conv-3   -->   torch.Size([512, 11, 11])
# Activations stored by the first hook ('first conv') for the single image in the
# batch, moved to the CPU. (`hook_1` was never defined; the hooks are in `hooks`.)
acts = hooks[0].stored[0].cpu()
acts.shape
torch.Size([64, 88, 88])

Averaging the values of these activations over the channel axis to get a 2 dimensional tensor.

# Average over the channel axis (dim 0) to collapse the stack of maps into one 2-D map.
avg_acts = acts.mean(0)
avg_acts.shape
torch.Size([88, 88])

Plotting these averaged activations.

# Display the channel-averaged activation map with the 'magma' colormap.
plt.imshow(avg_acts, cmap='magma');

img 14

from math import ceil

Plotting all of the stored activations.

def plot_forward_activations_multi(hooks, num_cols=4):
    """Show the channel-averaged forward activation of each hook in a grid.

    Args:
        hooks: hooks whose ``.stored[0]`` is a (channels, height, width) activation
            for one image, e.g. as returned by
            ``non_class_discriminative_activations_multi``.
        num_cols: number of subplot columns (default 4).
    """
    if not hooks:
        return  # nothing to plot; plt.subplots would reject zero rows

    num_rows = ceil(len(hooks) / num_cols)
    # squeeze=False keeps `ax` 2-D even with a single row, so it can always be flattened.
    fig, ax = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols*3, num_rows*3)
    axes = ax.ravel()

    for hook, axis in zip(hooks, axes):
        # Average over the channel axis to get a single 2-D map per hook.
        avg_acts = hook.stored[0].cpu().mean(0)
        axis.imshow(avg_acts, cmap='magma')

    # Hide the unused trailing panels of the last row.
    for axis in axes[len(hooks):]:
        axis.axis('off')

    plt.show()
plot_forward_activations_multi(hooks)

Plotting heat-maps based on these activations by stretching them (with bilinear interpolation) to the size of the input image and overlaying them on it.

def plot_non_class_discriminative_heatmaps_multi(x, num_cols=4):
    """Overlay channel-averaged forward activations at several depths on image ``x``.

    Args:
        x: a fastai Image; it goes through ``data.one_item`` so it is preprocessed
            like the validation set.
        num_cols: number of subplot columns (default 4).
    """
    xb, _ = data.one_item(x)
    xb_im = Image(data.denorm(xb)[0])
    # Stretch each low-resolution map over the whole input image. Take the size from
    # the batch instead of hard-coding 352, so other image sizes also line up.
    height, width = xb.shape[-2:]
    extent = (0, width, height, 0)
    xb = xb.cuda()

    hooks, _ = non_class_discriminative_activations_multi(xb)

    num_rows = ceil(len(hooks) / num_cols)
    # squeeze=False keeps `ax` 2-D even with a single row.
    fig, ax = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols*3, num_rows*3)
    axes = ax.ravel()

    for hook, axis in zip(hooks, axes):
        avg_acts = hook.stored[0].cpu().mean(0)
        xb_im.show(axis)
        axis.imshow(avg_acts, alpha=0.6, extent=extent,
                    interpolation='bilinear', cmap='magma')

    # Hide the unused trailing panels of the last row.
    for axis in axes[len(hooks):]:
        axis.axis('off')

    plt.show()
plot_non_class_discriminative_heatmaps_multi(x)

Testing it on the pug_maine image, as last time.

!wget https://i.pinimg.com/originals/ae/e4/a7/aee4a7df36c2e17f2490036d84f05d1f.jpg -O pug_maine.jpg
fn = 'pug_maine.jpg'
x_test = open_image(fn)
x_test.show(figsize=(x_test.size[0]/120,x_test.size[1]/120))

img 11

# Top predicted category for the test image (first element of learn.predict's result).
learn.predict(x_test)[0]
Category pug
# Channel-averaged forward activations at every hooked depth, overlaid on the test image.
plot_non_class_discriminative_heatmaps_multi(x_test)

Class-discriminative heat-maps along the network

Let's see how the Grad-CAM heat-maps look at different points in the network. Again, for more information on how Grad-CAM works, refer back to the previous post.

def class_discriminative_gradients_multi(xb, cat):
    """Capture activations and class-score gradients at several depths.

    Forward and backward hooks are attached to ``m[0][3]`` and to every block of
    ``m[0][4]`` .. ``m[0][7]``. The score ``preds[0, cat]`` is then back-propagated.

    Args:
        xb: input batch on the model's device; only item 0's score is back-propagated.
        cat: class index (anything ``int()`` accepts, e.g. a 0-d tensor or an int).

    Returns:
        ``(hooks_a, hooks_g, layer_names)``: activation hooks and gradient hooks
        (both already removed, with ``.stored`` still populated) and matching labels.
    """
    hooks_a, hooks_g, layer_names = [], [], []

    def _track(module, name):
        # Record both the forward activation and the gradient flowing back into it.
        hooks_a.append(hook_output(module))
        hooks_g.append(hook_output(module, grad=True))
        layer_names.append(name)

    try:
        # NOTE(review): m[0][3] looks like the stem max-pool rather than the conv
        # itself (its output is 1/4 of the input resolution) — confirm.
        _track(m[0][3], 'first conv')
        # ResNet layer k (a group of residual blocks) is m[0][k+3].
        for ind in [4, 5, 6, 7]:
            for i, el in enumerate(m[0][ind]):
                _track(el, f'layer-{ind-3} - conv-{i+1}')

        preds = m(xb)
        preds[0, int(cat)].backward()
    finally:
        # Remove every hook, even if the forward/backward pass fails.
        for hook in hooks_a + hooks_g:
            hook.remove()

    return hooks_a, hooks_g, layer_names
def plot_class_discriminative_heatmaps_multi(x, cat=None, relu=True, num_cols=4):
    """Overlay Grad-CAM heat-maps for class ``cat`` at several depths on image ``x``.

    Args:
        x: a fastai Image, preprocessed with ``data.one_item``.
        cat: class index to explain; defaults to the global ``y`` (the label of the
            validation image picked with ``data.valid_ds[idx]``).
        relu: if True, keep only positive evidence for the class.
        num_cols: number of subplot columns (default 4).
    """
    xb, _ = data.one_item(x)
    xb_im = Image(data.denorm(xb)[0])
    # Map each low-resolution heat-map onto the full input image. Take the size from
    # the batch rather than hard-coding 352.
    height, width = xb.shape[-2:]
    extent = (0, width, height, 0)
    xb = xb.cuda()

    target = y.data if cat is None else cat
    hooks_a, hooks_g, _ = class_discriminative_gradients_multi(xb, target)

    # Size the grid from this call's own hooks, not from the global `hooks`.
    num_rows = ceil(len(hooks_a) / num_cols)
    # squeeze=False keeps `ax` 2-D even with a single row.
    fig, ax = plt.subplots(num_rows, num_cols, squeeze=False)
    fig.set_size_inches(num_cols*3, num_rows*3)
    axes = ax.ravel()

    for hook_a, hook_g, axis in zip(hooks_a, hooks_g, axes):
        acts = hook_a.stored[0].cpu()
        # Gradient of the class score with respect to this activation, for image 0.
        grad = hook_g.stored[0][0].cpu()
        # Per-channel importance: the gradient averaged over both spatial axes.
        grad_chan = grad.mean(1).mean(1)
        # Weight each channel map by its importance, then average over channels.
        mult = (acts * grad_chan[..., None, None]).mean(0)
        if relu:
            mult = F.relu(mult)

        xb_im.show(axis)
        axis.imshow(mult, alpha=0.6, extent=extent,
                    interpolation='bilinear', cmap='magma')

    # Hide the unused trailing panels of the last row.
    for axis in axes[len(hooks_a):]:
        axis.axis('off')

    plt.show()
plot_class_discriminative_heatmaps_multi(x)

As expected, the class-discriminative heat-maps are more concentrated, especially in the last few layers of the network.

# Lookup from lower-cased class name to its index in data.classes (a later duplicate
# name would overwrite an earlier one).
class_dict = {name.lower(): idx for idx, name in enumerate(data.classes)}
plot_class_discriminative_heatmaps_multi(x_test, class_dict['pug'])

# Grad-CAM maps for the other animal in the same image.
plot_class_discriminative_heatmaps_multi(x_test,class_dict['maine_coon'])

Testing it on one more image containing objects belonging to 2 classes.

!wget https://i.ytimg.com/vi/SEAayLjjVOg/maxresdefault.jpg -O beagle_shiba.jpg
fn = 'beagle_shiba.jpg'
x_test_2 = open_image(fn)
x_test_2.show(figsize=(x_test_2.size[0]/100,x_test_2.size[1]/100))

img 6

# Top predicted category for the beagle/shiba image.
learn.predict(x_test_2)[0]
Category shiba_inu

The model classifies this image as shiba_inu.

# Grad-CAM maps for the predicted class (shiba_inu)...
plot_class_discriminative_heatmaps_multi(x_test_2,class_dict['shiba_inu'])

# ...and for the other dog in the image.
plot_class_discriminative_heatmaps_multi(x_test_2,class_dict['beagle'])

!wget https://previews.123rf.com/images/cynoclub/cynoclub1712/cynoclub171200076/91122699-puppy-saint-bernard-and-chihuahua-in-front-of-white-background.jpg -O saint_bernard_chihuahua.jpg
fn = 'saint_bernard_chihuahua.jpg'
x_test_3 = open_image(fn)
x_test_3.show(figsize=(x_test_3.size[0]/120,x_test_3.size[1]/120))

img 3

# Top predicted category for the saint bernard/chihuahua image.
learn.predict(x_test_3)[0]
Category saint_bernard
# Grad-CAM maps for the predicted class (saint_bernard)...
plot_class_discriminative_heatmaps_multi(x_test_3,cat=class_dict['saint_bernard'])

# ...and for the other dog in the image.
plot_class_discriminative_heatmaps_multi(x_test_3,class_dict['chihuahua'])

One point to note here is that the heat-maps get properly concentrated on the correct pixels only in the last one or two layers of the network. I wanted to find more general patterns in how these heat-maps take shape along the layers of the network, but couldn't find a common trait in these tests. I'll test on a larger set of images and link the results here.

Let's generate animations from these heat-maps.

def get_class_discriminative_heatmap_animation(x, cat=None, relu=True):
    """Build an animation that steps through Grad-CAM maps from shallow to deep layers.

    The left panel shows the denormalized input image. The right panel shows one
    heat-map per frame, in the order the hooks were registered.

    Args:
        x: a fastai Image, preprocessed with ``data.one_item``.
        cat: class index to explain; defaults to the global ``y``.
        relu: if True, keep only positive evidence for the class.

    Returns:
        A ``matplotlib.animation.ArtistAnimation`` (e.g. render with ``to_html5_video``).
    """
    fig, ax = plt.subplots(1, 2)
    fig.tight_layout()
    # Close the figure, presumably so the notebook doesn't also show it as a static
    # image; the animation still draws into `fig`.
    plt.close(fig)

    xb, _ = data.one_item(x)
    xb_im = Image(data.denorm(xb)[0])
    # Stretch each heat-map over the full input; size taken from the batch, not hard-coded.
    height, width = xb.shape[-2:]
    extent = (0, width, height, 0)
    xb = xb.cuda()

    target = y.data if cat is None else cat
    hooks_a, hooks_g, _ = class_discriminative_gradients_multi(xb, target)

    xb_im.show(ax[0])

    frames = []
    for hook_a, hook_g in zip(hooks_a, hooks_g):
        acts = hook_a.stored[0].cpu()
        grad = hook_g.stored[0][0].cpu()
        # Channel weights = spatially averaged gradients; weighted channel average.
        grad_chan = grad.mean(1).mean(1)
        mult = (acts * grad_chan[..., None, None]).mean(0)
        if relu:
            mult = F.relu(mult)

        im = ax[1].imshow(mult, alpha=1, extent=extent, interpolation='bilinear',
                          cmap='magma', animated=True)
        frames.append([im])  # one artist per frame

    return animation.ArtistAnimation(fig, frames, interval=400, blit=True,
                                     repeat_delay=1000, repeat=False)
# Re-import: redundant with the setup cell, but harmless.
from IPython.display import HTML
ani = get_class_discriminative_heatmap_animation(x)
# Alternative to the HTML5 video below: render the animation as an inline JS widget.
# rc('animation', html='jshtml')
# ani
HTML(ani.to_html5_video())
# Animations for each of the two dogs in the beagle/shiba image.
ani = get_class_discriminative_heatmap_animation(x_test_2,class_dict['shiba_inu'])
HTML(ani.to_html5_video())
ani = get_class_discriminative_heatmap_animation(x_test_2,class_dict['beagle'])
HTML(ani.to_html5_video())
# Animations for the pug and the Maine coon in the pug_maine image.
ani = get_class_discriminative_heatmap_animation(x_test,class_dict['pug'])
HTML(ani.to_html5_video())
ani = get_class_discriminative_heatmap_animation(x_test,class_dict['maine_coon'])
HTML(ani.to_html5_video())
# Animations for the saint bernard and the chihuahua.
ani = get_class_discriminative_heatmap_animation(x_test_3,class_dict['saint_bernard'])
HTML(ani.to_html5_video())
ani = get_class_discriminative_heatmap_animation(x_test_3,class_dict['chihuahua'])
HTML(ani.to_html5_video())

References