1+ """
2+ Model Interpretability using Captum
3+ ===================================
4+
5+ """
6+
7+
8+ ######################################################################
9+ # Captum helps you understand how the data features impact your model
10+ # predictions or neuron activations, shedding light on how your model
11+ # operates.
12+ #
13+ # Using Captum, you can apply a wide range of state-of-the-art feature
14+ # attribution algorithms such as \ ``Guided GradCam``\ and
15+ # \ ``Integrated Gradients``\ in a unified way.
16+ #
17+ # In this recipe you will learn how to use Captum to: \* attribute the
18+ # predictions of an image classifier to their corresponding image
19+ # features. \* visualize the attribution results.
20+ #
21+


######################################################################
# Before you begin
# ----------------
#


######################################################################
# Make sure Captum is installed in your active Python environment. Captum
# is available on GitHub, as a ``pip`` package (``pip install captum``), and
# as a ``conda`` package. For detailed instructions, consult the installation
# guide at https://captum.ai/
#

######################################################################
# For a model, we use a pretrained ResNet-18 image classifier that ships
# with ``torchvision``. Captum can reveal which parts of a sample image
# support certain predictions made by the model.
#

import torchvision
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

model = torchvision.models.resnet18(pretrained=True).eval()

response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")
img = Image.open(BytesIO(response.content))

center_crop = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

normalize = transforms.Compose([
    transforms.ToTensor(),                # converts the image to a tensor with values between 0 and 1
    transforms.Normalize(                 # normalize to follow 0-centered ImageNet pixel RGB distribution
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
input_img = normalize(center_crop(img)).unsqueeze(0)


######################################################################
# Computing Attribution
# ---------------------
#


######################################################################
# Among the top-3 predictions of the model are classes 208 and 283, which
# correspond to dog and cat; the quick check below confirms this.
#
# Let us attribute each of these predictions to the corresponding part of
# the input, using Captum’s \ ``Occlusion``\ algorithm.
#
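# The snippet below is a hedged sanity check, not part of the original
# recipe: it prints the model's top-3 class indices and probabilities for
# ``input_img``; classes 208 and 283 are expected to appear among them.
#

import torch
import torch.nn.functional as F

with torch.no_grad():
    scores = F.softmax(model(input_img), dim=1)   # probabilities over the 1000 ImageNet classes
top_scores, top_ids = torch.topk(scores, 3)       # top-3 probabilities and their class indices
print(top_ids, top_scores)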

from captum.attr import Occlusion

occlusion = Occlusion(model)

strides = (3, 9, 9)                      # smaller = more fine-grained attribution but slower
target = 208                             # Labrador index in ImageNet
sliding_window_shapes = (3, 45, 45)      # choose a size large enough to change object appearance
baselines = 0                            # values to occlude the image with. 0 corresponds to gray

attribution_dog = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=baselines)


target = 283                             # Persian cat index in ImageNet
attribution_cat = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=0)


######################################################################
# Besides ``Occlusion``, Captum features many algorithms such as
# \ ``Integrated Gradients``\ , \ ``Deconvolution``\ ,
# \ ``GuidedBackprop``\ , \ ``Guided GradCam``\ , \ ``DeepLift``\ , and
# \ ``GradientShap``\ . All of these algorithms are subclasses of
# ``Attribution``, which expects your model as a callable ``forward_func``
# upon initialization and has an ``attribute(...)`` method that returns
# the attribution result in a unified format.
#
# The short sketch below shows this shared calling pattern with a second
# algorithm; after that, let us visualize the computed attribution results
# for our image input.
#
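# This is a minimal, hedged sketch rather than part of the original recipe:
# ``Integrated Gradients`` is constructed from the same model and called
# through the same ``attribute(...)`` method; ``n_steps``, the number of
# integration steps, is the only algorithm-specific argument used here.
#

from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)
attribution_ig = integrated_gradients.attribute(input_img,
                                                target=208,   # Labrador index in ImageNet
                                                n_steps=50)   # number of integration steps (Captum's default)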


######################################################################
# Visualizing the Results
# -----------------------
#


######################################################################
# Captum’s \ ``visualization``\ utility provides out-of-the-box methods
# to visualize attribution results both for pictorial and for textual
# inputs.
#

import numpy as np
from captum.attr import visualization as viz

# Convert the computed attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1, 2, 0))

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"]   # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar=True
                                     )
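
# A hedged extra, not part of the original recipe: the same helper can draw
# the two signs separately, one heat map for supporting areas and one for
# distracting areas.
_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      ["heat_map", "heat_map"],
                                      ["positive", "negative"],
                                      ["areas supporting 'dog'", "areas distracting from 'dog'"],
                                      show_colorbar=True
                                     )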


attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1, 2, 0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      np.array(center_crop(img)),
                                      ["heat_map", "original_image"],
                                      ["all", "all"],   # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar=True
                                     )


######################################################################
# If your data is textual, ``visualization.visualize_text()`` offers a
# dedicated view to explore attribution on top of the input text. Find out
# more at http://captum.ai/tutorials/IMDB_TorchText_Interpret
#
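# The snippet below is a hedged sketch, not part of the original recipe: it
# only illustrates the input that ``visualize_text()`` expects, a list of
# ``VisualizationDataRecord`` objects, filled here with made-up token
# attributions instead of real model output.
#

record = viz.VisualizationDataRecord(
    torch.tensor([0.10, -0.30, 0.80, 0.05]),   # per-token attribution scores (toy values)
    0.92,                                      # predicted probability
    "positive",                                # predicted class
    "positive",                                # true class
    "positive",                                # class the attributions were computed for
    0.65,                                      # overall attribution score (sum of the toy values)
    ["a", "lovely", "little", "film"],         # the tokens themselves
    0.0,                                       # convergence delta (not meaningful for toy values)
)
_ = viz.visualize_text([record])   # renders an HTML view (requires IPython / a notebook)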


######################################################################
# Final Notes
# -----------
#

######################################################################
# Captum can handle most model types in PyTorch across modalities
# including vision, text, and more. With Captum you can:
#
# -  Attribute a specific output to the model input, as illustrated above.
# -  Attribute a specific output to a hidden-layer neuron (see the Captum
#    API reference and the sketch below).
# -  Attribute a hidden-layer neuron response to the model input (see the
#    Captum API reference).
#
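# The snippet below is a hedged illustration of the hidden-layer bullet, not
# part of the original recipe: layer attribution classes additionally take the
# layer of interest, and ``LayerConductance`` here attributes the dog
# prediction to each activation of ResNet-18's last residual block,
# ``model.layer4``.
#

from captum.attr import LayerConductance

layer_cond = LayerConductance(model, model.layer4)                # layer of interest: ResNet-18's last block
attribution_layer = layer_cond.attribute(input_img, target=208)   # one score per activation in model.layer4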


######################################################################
# For the complete API of the supported methods and a list of tutorials,
# consult the Captum website: http://captum.ai
#
# Another useful post by Gilbert Tanner:
# https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum
#