Closed
Description
📚 Documentation
This Jupyter notebook example code from the docs:
https://pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html
throws the following error in the 'Visualizing the Results' cell:
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-4-43fe1d15438b> in <module>
15 vis_signs,
16 ["attribution for dog", "image"],
---> 17 show_colorbar = True
18 )
19
~\Anaconda3\lib\site-packages\captum\attr\_utils\visualization.py in visualize_image_attr_multiple(attr, original_image, methods, signs, titles, fig_size, use_pyplot, **kwargs)
387 >>> ["all", "positive"], attribution, orig_image)
388 """
--> 389 assert len(methods) == len(signs), "Methods and signs array lengths must match."
390 if titles is not None:
391 assert len(methods) == len(titles), (
AssertionError: Methods and signs array lengths must match.
```
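This is caused by the trailing comma in `vis_signs`: it turns the value into a one-element tuple wrapping the list, so `len(signs)` is 1 while `len(methods)` is 2: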
```python
import numpy as np
from captum.attr import visualization as viz

# Convert the computed attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1,2,0))

vis_types = ["heat_map", "original_image"]
# NOTE: the trailing comma on the next line wraps the list in a 1-tuple,
# which is what triggers the AssertionError above
vis_signs = ["all", "all"], # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      center_crop(img),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar = True
                                     )

attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1,2,0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      center_crop(img),
                                      ["heat_map", "original_image"],
                                      ["all", "all"], # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar = True
                                     )
```
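The failure can be reproduced without captum at all, since it is plain Python behaviour: a trailing comma after a list literal builds a one-element tuple. The sketch below uses a hypothetical `signs_match` helper that only mirrors the `len(methods) == len(signs)` assertion shown in the traceback; it is an illustration, not part of captum's API.

```python
# Same method list as in the tutorial cell.
vis_types = ["heat_map", "original_image"]

# With the trailing comma, Python builds a 1-element tuple wrapping the list.
vis_signs_buggy = ["all", "all"],
# Without it, we get the intended 2-element list.
vis_signs_fixed = ["all", "all"]


def signs_match(methods, signs):
    """Hypothetical helper mirroring the length assertion from the traceback."""
    return len(methods) == len(signs)


print(type(vis_signs_buggy).__name__, len(vis_signs_buggy))  # tuple 1
print(type(vis_signs_fixed).__name__, len(vis_signs_fixed))  # list 2
print(signs_match(vis_types, vis_signs_buggy))  # False
print(signs_match(vis_types, vis_signs_fixed))  # True
```

So dropping the trailing comma on the `vis_signs` line should be enough to make the cell run. The cat call is unaffected because there the comma after `["all", "all"]` sits inside the argument list and simply separates arguments.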
Relevant packages in my env:
torch==1.6.0
torchvision==0.7.0
captum==0.2.0
Python 3.7.7
bilalsal