Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
740fcde
Add Activation Atlas tutorial & functions
ProGamerGov Jan 6, 2021
1127d31
Add tests for atlas functions & random rotation transform
ProGamerGov Jan 6, 2021
e19fe65
Only test atlas functions with >= torch 1.7.0
ProGamerGov Jan 6, 2021
dcc7743
Added citations, better atlas docs, & asserts
ProGamerGov Jan 6, 2021
afde029
Improve atlas docs and variables
ProGamerGov Jan 6, 2021
50d2ddf
Update tutorial variable & fix tutorial viz cell
ProGamerGov Jan 6, 2021
beeb2af
Improve capture_activation_samples
ProGamerGov Jan 7, 2021
a5a03e8
Improve description
ProGamerGov Jan 7, 2021
fd45b60
Don't collect samples from edges
ProGamerGov Jan 9, 2021
c3f0cd4
Add missing whitespace to arithmetic operators
ProGamerGov Jan 9, 2021
997d8d2
Merge branch 'optim-wip' of https://github.com/pytorch/captum into op…
Jan 11, 2021
6ead702
Update atlas tutorial with corrected colorspace
ProGamerGov Jan 11, 2021
1618c4e
Fix tutorial model transform
ProGamerGov Jan 12, 2021
c04621f
Code improvements
ProGamerGov Jan 12, 2021
ff50ed6
New samples per image parameter, speed improvements & more
ProGamerGov Jan 14, 2021
81a2c5f
Delete labels for now
ProGamerGov Jan 14, 2021
5f51cec
Fix variable
ProGamerGov Jan 14, 2021
bd34b72
Move WhitenedNeuronDirection to core/loss
ProGamerGov Jan 15, 2021
11fdd19
Activation atlas tutorial improvements
ProGamerGov Jan 17, 2021
84b9f69
Implement improved sample collection
ProGamerGov Jan 24, 2021
34853d7
Fix lint & tests
ProGamerGov Jan 24, 2021
d557791
Fix lint errors
ProGamerGov Jan 24, 2021
4819259
Add List type hint
ProGamerGov Jan 24, 2021
ffac297
Improve activation atlas tutorial
ProGamerGov Jan 24, 2021
b5fbd7d
Improve activation atlas tutorial
ProGamerGov Jan 28, 2021
e7a0741
Merge branch 'optim-wip' of https://github.com/pytorch/captum into op…
Jan 28, 2021
765c352
Improvements to atlas docs, tests, tutorial, & code
ProGamerGov Jan 31, 2021
d5c2cda
Move to [x,y] graph format & improvements
ProGamerGov Feb 1, 2021
37490c6
Fix broken atlas tests
ProGamerGov Feb 1, 2021
a39167f
Minor improvements
ProGamerGov Feb 3, 2021
ce0a87c
General improvements
ProGamerGov Feb 10, 2021
5c22fd9
Add missing 'cast' import
ProGamerGov Feb 10, 2021
ade0b88
Support for class activation atlases & improvements
ProGamerGov Mar 14, 2021
3d94549
Fix Distributed Data Parallel tests failing in torch 1.8
ProGamerGov Mar 14, 2021
bd3e26e
Fix EOF error
ProGamerGov Mar 14, 2021
93fa856
Add the second part of the class activation atlas tutorial
ProGamerGov Mar 15, 2021
5a29c94
Small improvements to class activation atlas tutorial
ProGamerGov Mar 15, 2021
b93d1da
Modify atlas tutorials for multiple attempts & class vis
ProGamerGov Mar 19, 2021
05bce14
Merge branch 'optim-wip' of https://github.com/pytorch/captum into op…
Mar 21, 2021
26d646c
Fix part 2 of class activation atlas tutorial notebook
ProGamerGov Mar 21, 2021
c4483db
Add tests for MaxPool2dRelaxed
ProGamerGov Mar 21, 2021
0927f6c
Add MaxPool2dRelaxed
ProGamerGov Mar 21, 2021
d8e1d0c
Add MaxPool2dRelaxed with tests
ProGamerGov Mar 21, 2021
22062ae
Improve atlas sample collection tutorial part
ProGamerGov Mar 21, 2021
dee2b17
Show progress when concatenating samples
ProGamerGov Mar 22, 2021
9de3924
Improve WhitenedNeuronDirection documentation
ProGamerGov Mar 28, 2021
c37c830
WhitenedNeuronDirection -> AngledNeuronDirection & better description…
ProGamerGov Mar 28, 2021
a479fa9
Revamp the Activation Atlas tutorials
ProGamerGov Apr 4, 2021
fc6584d
Fix variable names, grammar, and spelling in atlas tutorials
ProGamerGov Apr 4, 2021
ffd7277
Improve the first steps of the activation atlas tutorials
ProGamerGov Apr 5, 2021
ad9c97e
Fix sentence & add device to weights_to_heatmap_2d
ProGamerGov Apr 7, 2021
1c89756
Add __all__ to atlas.py & fix UserWarning in rotation transform
ProGamerGov Apr 7, 2021
5c43669
Atlas tutorial improvements & changes for #552
ProGamerGov Apr 19, 2021
4801481
Minor improvements to atlas tutorials
ProGamerGov Apr 20, 2021
1339c65
Various minor improvements
ProGamerGov Apr 21, 2021
2fac2b9
Minor fixes & Improvements to atlas code & tutorials
ProGamerGov Apr 22, 2021
a6a3492
Add CUDA tests for atlas functions
ProGamerGov Apr 22, 2021
78b14d9
Refine some atlas function type hints
ProGamerGov Apr 22, 2021
5f0fe16
Fix atlas function docs
ProGamerGov Apr 22, 2021
afa0229
Improve atlas related docs & add more tests
ProGamerGov Apr 23, 2021
f40822b
Remove CenterCrop transform from Class Atlas tutorial
ProGamerGov Apr 23, 2021
36904fb
Improve AngledNeuronDirection
ProGamerGov Apr 25, 2021
c748358
Merge remote-tracking branch 'origin/patch-6' into optim-wip-activati…
ProGamerGov Apr 26, 2021
51e8775
Fix flake8 errors & update atlas tutorials
ProGamerGov Apr 26, 2021
5f5d6a6
Add AngledNeuronDirection to __all__
ProGamerGov Apr 26, 2021
5d75cfa
Fix crashing bug in atlas notebooks
ProGamerGov Apr 27, 2021
03e6d86
Activation atlas tutorial improvements
ProGamerGov Apr 27, 2021
dccb245
Add AngledNeuronDirection tests written by @greentfrapp
ProGamerGov Apr 29, 2021
3bf7554
Fix AngledNeuronDirection tests
ProGamerGov Apr 29, 2021
a29ec65
Use batch targeting in atlas tutorials
ProGamerGov Apr 30, 2021
10dd42d
Improve class atlas rendering parameters
ProGamerGov Apr 30, 2021
2b785b1
Minor improvements to the main activation atlas tutorial
ProGamerGov May 1, 2021
fe60b4b
Improve main activation atlas tutorial visualizations
ProGamerGov May 1, 2021
00ace09
Add asserts to direction objectives
ProGamerGov May 2, 2021
cc8ba5e
General improvements to both atlas tutorials
ProGamerGov May 2, 2021
a2370fa
Clarify where to use t-SNE vs UMAP in atlas tutorials
ProGamerGov May 3, 2021
022f2d3
Improve weights_to_heatmap_2d & nchannels_to_rgb tests
ProGamerGov May 3, 2021
be4ca32
Improvements to both atlas tutorials
ProGamerGov May 4, 2021
1de9e09
Move atlas sample collection to it's own notebook
ProGamerGov May 4, 2021
c6b6c22
Add adversarial example to class atlas tutorial
ProGamerGov May 6, 2021
1f94574
Minor improvements
ProGamerGov May 8, 2021
7e9e41f
Minor improvements & fixes for both atlas tutorials
ProGamerGov May 9, 2021
492561a
Add comments to the batch targeting section of both atlas tutorials
ProGamerGov May 10, 2021
569e6af
Minor correction to atlas docs
ProGamerGov May 10, 2021
48f5b63
More minor improvements to atlas tutorials
ProGamerGov May 12, 2021
294c1c7
More tutorial improvements
ProGamerGov May 15, 2021
965a44a
Fix adversarial example image urls
ProGamerGov May 18, 2021
ee6992f
Speed up cov matrix calculations & support any number of channels
ProGamerGov Jun 6, 2021
0261aab
Merge branch 'optim-wip' into optim-wip-activation-atlas
ProGamerGov Jun 7, 2021
c4b19b0
Reorganize part of the Activation Atlas PR
ProGamerGov Jul 31, 2021
1fc4a43
Remove unused import
ProGamerGov Jul 31, 2021
5cdf9d7
Merge remote-tracking branch 'upstream/optim-wip' into optim-wip-acti…
ProGamerGov Aug 4, 2021
d5fbdaf
Only commit main atlas changes
ProGamerGov Aug 4, 2021
95bb3ef
Improve doc formatting
ProGamerGov Aug 5, 2021
6093063
Improve tutorial function docs
ProGamerGov Aug 8, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion captum/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from captum.optim._core.optimization import InputOptimization # noqa: F401
from captum.optim._param.image import images, transforms # noqa: F401
from captum.optim._param.image.images import ImageTensor # noqa: F401
from captum.optim._utils import circuits, reducer # noqa: F401
from captum.optim._utils import atlas, circuits, reducer # noqa: F401
from captum.optim._utils.image.common import ( # noqa: F401
nchannels_to_rgb,
save_tensor_as_image,
Expand Down
96 changes: 87 additions & 9 deletions captum/optim/_core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,15 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec.reshape((1, -1, 1, 1))
assert vec.dim() == 2 or vec.dim() == 4
self.vec = vec.reshape((vec.size(0), -1, 1, 1)) if vec.dim() == 2 else vec
self.cossim_pow = cossim_pow

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
assert activations.size(1) == self.direction.size(1)
assert activations.size(1) == self.vec.size(1)
activations = activations[self.batch_index[0] : self.batch_index[1]]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
Expand All @@ -477,7 +478,8 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec.reshape((1, -1, 1, 1))
assert vec.dim() == 2 or vec.dim() == 4
self.vec = vec.reshape((vec.size(0), -1, 1, 1)) if vec.dim() == 2 else vec
self.x = x
self.y = y
self.channel_index = channel_index
Expand All @@ -496,7 +498,81 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
]
if self.channel_index is not None:
activations = activations[:, self.channel_index, ...][:, None, ...]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
class AngledNeuronDirection(BaseLoss):
"""
Visualize a direction vector with an optional whitened activation vector to
unstretch the activation space. Compared to the traditional Direction objectives,
this objective places more emphasis on angle by optionally multiplying the dot
product by the cosine similarity.

When cossim_pow is equal to 0, this objective works as a euclidean
neuron objective. When cossim_pow is greater than 0, this objective works as a
cosine similarity objective. An additional whitened neuron direction vector
can optionally be supplied to improve visualization quality for some models.

Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/
Args:
target (nn.Module): A target layer instance.
vec (torch.Tensor): A neuron direction vector to use.
vec_whitened (torch.Tensor, optional): A whitened neuron direction vector.
cossim_pow (float, optional): The desired cosine similarity power to use.
x (int, optional): Optionally provide a specific x position for the target
neuron.
y (int, optional): Optionally provide a specific y position for the target
neuron.
eps (float, optional): If cossim_pow is greater than zero, the desired
epsilon value to use for cosine similarity calculations.
"""

def __init__(
self,
target: torch.nn.Module,
vec: torch.Tensor,
vec_whitened: Optional[torch.Tensor] = None,
cossim_pow: float = 4.0,
x: Optional[int] = None,
y: Optional[int] = None,
eps: float = 1.0e-4,
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.vec = vec.unsqueeze(0) if vec.dim() == 1 else vec
self.vec_whitened = vec_whitened
self.cossim_pow = cossim_pow
self.eps = eps
self.x = x
self.y = y
if self.vec_whitened is not None:
assert self.vec_whitened.dim() == 2
assert self.vec.dim() == 2

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = activations[self.batch_index[0] : self.batch_index[1]]
assert activations.dim() == 4 or activations.dim() == 2
assert activations.shape[1] == self.vec.shape[1]
if activations.dim() == 4:
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
activations = activations[..., _x, _y]

vec = (
torch.matmul(self.vec, self.vec_whitened)[0]
if self.vec_whitened is not None
else self.vec
)
if self.cossim_pow == 0:
return activations * vec

dot = torch.mean(activations * vec)
cossims = dot / (self.eps + torch.sqrt(torch.sum(activations ** 2)))
return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow


@loss_wrapper
Expand All @@ -515,16 +591,17 @@ def __init__(
batch_index: Optional[int] = None,
) -> None:
BaseLoss.__init__(self, target, batch_index)
self.direction = vec
assert vec.dim() == 4
self.vec = vec
self.cossim_pow = cossim_pow

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]

assert activations.dim() == 4

H_direction, W_direction = self.direction.size(2), self.direction.size(3)
H_activ, W_activ = activations.size(2), activations.size(3)
H_direction, W_direction = self.vec.shape[2:]
H_activ, W_activ = activations.shape[2:]

H = (H_activ - H_direction) // 2
W = (W_activ - W_direction) // 2
Expand All @@ -535,7 +612,7 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
H : H + H_direction,
W : W + W_direction,
]
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)
return _dot_cossim(self.vec, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
Expand Down Expand Up @@ -617,6 +694,7 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
"Alignment",
"Direction",
"NeuronDirection",
"AngledNeuronDirection",
"TensorDirection",
"ActivationWeights",
"default_loss_summarize",
Expand Down
73 changes: 73 additions & 0 deletions captum/optim/_param/image/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ def __init__(self, scale: NumSeqOrTensorType) -> None:
scale (float, sequence): Tuple of rescaling values to randomly select from.
"""
super().__init__()
assert hasattr(scale, "__iter__")
if torch.is_tensor(scale):
assert cast(torch.Tensor, scale).dim() == 1
assert len(scale) > 0
self.scale = scale

def get_scale_mat(
Expand Down Expand Up @@ -384,6 +388,75 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.translate_tensor(input, insets)


class RandomRotation(nn.Module):
"""
Apply random rotation transforms on a NCHW tensor, using a sequence of degrees.
"""

def __init__(
self, degrees: Union[List[float], Tuple[float, ...], torch.Tensor]
) -> None:
"""
Args:

degrees (float, sequence): Tuple, List, or Tensor of degrees to randomly
select from.
"""
super().__init__()
assert hasattr(degrees, "__iter__")
if torch.is_tensor(degrees):
assert cast(torch.Tensor, degrees).dim() == 1
assert len(degrees) > 0
self.degrees = degrees

def get_rot_mat(
self,
theta: Union[int, float, torch.Tensor],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
theta = torch.tensor(theta, device=device, dtype=dtype)
rot_mat = torch.tensor(
[
[torch.cos(theta), -torch.sin(theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
],
device=device,
dtype=dtype,
)
return rot_mat

def rotate_tensor(
self, x: torch.Tensor, theta: Union[int, float, torch.Tensor]
) -> torch.Tensor:
theta = theta * math.pi / 180
rot_matrix = self.get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat(
x.shape[0], 1, 1
)
if torch.__version__ >= "1.3.0":
# Pass align_corners explicitly for torch >= 1.3.0
grid = F.affine_grid(rot_matrix, x.size(), align_corners=False)
x = F.grid_sample(x, grid, align_corners=False)
else:
grid = F.affine_grid(rot_matrix, x.size())
x = F.grid_sample(x, grid)
return x

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Randomly rotate an input tensor.

Args:

input (torch.Tensor): Input to randomly rotate.

Returns:
**tensor** (torch.Tensor): A randomly rotated *tensor*.
"""
rotate_angle = _rand_select(self.degrees)
return self.rotate_tensor(x, rotate_angle)


class ScaleInputRange(nn.Module):
"""
Multiplies the input by a specified multiplier for models with input ranges other
Expand Down
Loading