Commit bee7cec

Merge pull request #952 from pytorch/jlin27_tutorials_refresh
Add Captum Recipe into source and recipes_index.rst
2 parents 5d1149b + ce36aaf commit bee7cec

File tree

2 files changed: +262 −0 lines changed
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
"""
Model Interpretability using Captum
===================================

"""


######################################################################
# Captum helps you understand how the data features impact your model
# predictions or neuron activations, shedding light on how your model
# operates.
#
# Using Captum, you can apply a wide range of state-of-the-art feature
# attribution algorithms such as ``Guided GradCam`` and
# ``Integrated Gradients`` in a unified way.
#
# In this recipe you will learn how to use Captum to:
#
# - attribute the predictions of an image classifier to their
#   corresponding image features.
# - visualize the attribution results.
#

######################################################################
# Before you begin
# ----------------
#


######################################################################
# Make sure Captum is installed in your active Python environment. Captum
# is available on GitHub, as a ``pip`` package, and as a ``conda``
# package. For detailed instructions, consult the installation guide at
# https://captum.ai/
#
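
######################################################################
# For example, assuming a standard ``pip`` or ``conda`` environment (the
# installation guide above is authoritative), Captum can typically be
# installed with one of:
#
# ::
#
#    pip install captum
#    conda install captum -c pytorch
#
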
######################################################################
# For a model, we use a built-in image classifier in PyTorch. Captum can
# reveal which parts of a sample image support certain predictions made by
# the model.
#

import torchvision
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

model = torchvision.models.resnet18(pretrained=True).eval()

response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")
img = Image.open(BytesIO(response.content))

center_crop = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

normalize = transforms.Compose([
    transforms.ToTensor(),    # convert the image to a tensor with values between 0 and 1
    transforms.Normalize(     # normalize to the 0-centered ImageNet pixel RGB distribution
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
input_img = normalize(center_crop(img)).unsqueeze(0)


######################################################################
# Computing Attribution
# ---------------------
#


######################################################################
# Among the top-3 predictions of the model are classes 208 and 283, which
# correspond to dog and cat.
#
# Let us attribute each of these predictions to the corresponding part of
# the input, using Captum's ``Occlusion`` algorithm.
#

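
######################################################################
# One way to check these class indices is to look at the model's top-3
# softmax scores directly (a minimal sanity check; 208 and 283 are the
# ImageNet indices for Labrador retriever and Persian cat):

import torch

output = model(input_img)
probabilities = torch.nn.functional.softmax(output, dim=1)
top3_scores, top3_classes = torch.topk(probabilities, 3)
print(top3_classes, top3_scores)
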
from captum.attr import Occlusion

occlusion = Occlusion(model)

strides = (3, 9, 9)                   # smaller = more fine-grained attribution but slower
target = 208                          # Labrador index in ImageNet
sliding_window_shapes = (3, 45, 45)   # choose a size large enough to change the object's appearance
baselines = 0                         # values to occlude the image with; 0 corresponds to gray

attribution_dog = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=baselines)


target = 283                          # Persian cat index in ImageNet
attribution_cat = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=0)


######################################################################
# Besides ``Occlusion``, Captum features many algorithms such as
# ``Integrated Gradients``, ``Deconvolution``, ``GuidedBackprop``,
# ``Guided GradCam``, ``DeepLift``, and ``GradientShap``. All of these
# algorithms are subclasses of ``Attribution``, which expects your model
# as a callable ``forward_func`` upon initialization and has an
# ``attribute(...)`` method that returns the attribution result in a
# unified format.
#
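
######################################################################
# For example, ``Integrated Gradients`` can be swapped in through the
# same unified interface (a minimal sketch, keeping the default
# ``baselines`` and ``n_steps``):

from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)
attribution_ig = integrated_gradients.attribute(input_img, target=208)
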

######################################################################
# Let us visualize the computed attribution results in the case of images.
#

######################################################################
# Visualizing the Results
# -----------------------
#


######################################################################
# Captum's ``visualization`` utility provides out-of-the-box methods
# to visualize attribution results both for pictorial and for textual
# inputs.
#

import numpy as np
from captum.attr import visualization as viz

# Convert the computed attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1, 2, 0))

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"]  # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar=True
                                      )


attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1, 2, 0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      np.array(center_crop(img)),
                                      ["heat_map", "original_image"],
                                      ["all", "all"],  # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar=True
                                      )


######################################################################
# If your data is textual, ``visualization.visualize_text()`` offers a
# dedicated view to explore attribution on top of the input text. Find out
# more at http://captum.ai/tutorials/IMDB_TorchText_Interpret
#


######################################################################
# Final Notes
# -----------
#


######################################################################
# Captum can handle most model types in PyTorch across modalities
# including vision, text, and more. With Captum you can:
#
# - Attribute a specific output to the model input, as illustrated above.
# - Attribute a specific output to a hidden-layer neuron (see the Captum
#   API reference and the sketch below).
# - Attribute a hidden-layer neuron response to the model input (see the
#   Captum API reference).
#
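
######################################################################
# As a minimal sketch of layer attribution (assuming the ResNet-18 loaded
# above; ``model.layer4`` is its last convolutional block), ``LayerGradCam``
# attributes the dog prediction to that layer's activations:

from captum.attr import LayerGradCam

layer_gc = LayerGradCam(model, model.layer4)
attribution_layer = layer_gc.attribute(input_img, target=208)
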

######################################################################
# For the complete API of the supported methods and a list of tutorials,
# consult our website http://captum.ai
#
# Another useful post by Gilbert Tanner:
# https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum
#

recipes_source/recipes_index.rst

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
PyTorch Recipes
---------------------------------------------
Recipes are bite-sized, actionable examples of how to use specific PyTorch features, different from our full-length tutorials.

.. raw:: html

    </div>
    </div>

    <div id="tutorial-cards-container">

    <nav class="navbar navbar-expand-lg navbar-light tutorials-nav col-12">
    <div class="tutorial-tags-container">
    <div id="dropdown-filter-tags">
    <div class="tutorial-filter-menu">
    <div class="tutorial-filter filter-btn all-tag-selected" data-tag="all">All</div>
    </div>
    </div>
    </div>
    </nav>

    <hr class="tutorials-hr">

    <div class="row">

    <div id="tutorial-cards">
    <div class="list">

.. Add recipe cards below this line

.. Getting Started

.. customcarditem::
   :header: Writing Custom Datasets, DataLoaders and Transforms
   :card_description: Learn how to load and preprocess/augment data from a non-trivial dataset.
   :image: _static/img/thumbnails/pytorch-logo-flat.png
   :link: ../recipes/recipes/data_loading_tutorial.html
   :tags: Getting-Started

.. Interpretability

.. customcarditem::
   :header: Model Interpretability using Captum
   :card_description: Learn how to use Captum to attribute the predictions of an image classifier to their corresponding image features and to visualize the attribution results.
   :image: _static/img/thumbnails/pytorch-logo-flat.png
   :link: ../recipes/recipes/Captum_Recipe.html
   :tags: Interpretability

.. Production Development

.. customcarditem::
   :header: TorchScript for Deployment
   :card_description: Learn how to export your trained model in TorchScript format, and how to load it in C++ and run inference.
   :image: _static/img/thumbnails/pytorch-logo-flat.png
   :link: ../recipes/recipes/torchscript_inference.html
   :tags: TorchScript

.. End of recipe card section

.. raw:: html

    </div>

    </div>

    </div>

    </div>

.. .. galleryitem:: beginner/saving_loading_models.py
