Merged
Changes from all commits
21 commits
b7333ec
Add main Activation Atlas classes & functions
ProGamerGov Sep 30, 2021
dae9576
Add main Activation Atlas tutorial notebook
ProGamerGov Oct 3, 2021
f35b844
Remove unused import
ProGamerGov Oct 10, 2021
340d3b8
Changes based on feedback
ProGamerGov Oct 21, 2021
86eb581
Revert ufmt change as it causes isort to fail
ProGamerGov Oct 21, 2021
f52ae2e
Improve documentation of `AngledNeuronDirection` & atlas functions
ProGamerGov Oct 27, 2021
d51ef0d
Improve atlas related documentation
ProGamerGov Oct 31, 2021
6ac980c
Move atlas.py to _utils/image & improve atlas docs
ProGamerGov Oct 31, 2021
18b017f
Add sum_loss_list() function & correct target type hints
ProGamerGov Nov 19, 2021
0233680
RandomRotation JIT support & other improvements
ProGamerGov Dec 12, 2021
f8aa611
Better way to handle torch version check with JIT
ProGamerGov Dec 12, 2021
5b52643
Use better scale type hint in RandomRotation init function
ProGamerGov Dec 13, 2021
e3c3457
Add torch.distributions support to RandomRotation
ProGamerGov Dec 20, 2021
dcf3b99
Add assert & more tests for RandomRotation
ProGamerGov Dec 24, 2021
2e4f4b0
Adding SkipTest to RandomRotation reflection
ProGamerGov Dec 24, 2021
d1e22ae
Fix formatting error
ProGamerGov Dec 24, 2021
271d845
Changes to main atlas tutorial notebook based on feedback
ProGamerGov Jan 24, 2022
ad03e3b
Remove unused type hint
ProGamerGov Jan 24, 2022
9b87f23
Improve whitening description in main activation atlas tutorial
ProGamerGov Jan 25, 2022
4a16d1a
Spelling & grammar fixes
ProGamerGov Jan 26, 2022
29d0283
Improve `calc_grid_indices` documentation
ProGamerGov Jan 27, 2022
2 changes: 2 additions & 0 deletions captum/optim/__init__.py
@@ -6,6 +6,7 @@
from captum.optim._param.image import images, transforms # noqa: F401
from captum.optim._param.image.images import ImageTensor # noqa: F401
from captum.optim._utils import circuits, reducer # noqa: F401
from captum.optim._utils.image import atlas # noqa: F401
from captum.optim._utils.image.common import ( # noqa: F401
nchannels_to_rgb,
save_tensor_as_image,
@@ -23,6 +24,7 @@
"circuits",
"models",
"reducer",
"atlas",
"nchannels_to_rgb",
"save_tensor_as_image",
"show",
146 changes: 141 additions & 5 deletions captum/optim/_core/loss.py
@@ -1,7 +1,7 @@
import functools
import operator
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -27,7 +27,7 @@ def __init__(self) -> None:
super(Loss, self).__init__()

@abstractproperty
def target(self) -> nn.Module:
def target(self) -> Union[nn.Module, List[nn.Module]]:
pass

@abstractmethod
@@ -140,7 +140,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:

class BaseLoss(Loss):
def __init__(
self, target: nn.Module = [], batch_index: Optional[int] = None
self,
target: Union[nn.Module, List[nn.Module]] = [],
batch_index: Optional[int] = None,
) -> None:
super(BaseLoss, self).__init__()
self._target = target
@@ -150,7 +152,7 @@ def __init__(
self._batch_index = (batch_index, batch_index + 1)

@property
def target(self) -> nn.Module:
def target(self) -> Union[nn.Module, List[nn.Module]]:
return self._target

@property
@@ -160,7 +162,10 @@ def batch_index(self) -> Tuple:

class CompositeLoss(BaseLoss):
def __init__(
self, loss_fn: Callable, name: str = "", target: nn.Module = []
self,
loss_fn: Callable,
name: str = "",
target: Union[nn.Module, List[nn.Module]] = [],
) -> None:
super(CompositeLoss, self).__init__(target)
self.__name__ = name
@@ -499,6 +504,94 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)


@loss_wrapper
class AngledNeuronDirection(BaseLoss):
"""
Visualize a direction vector with an optional whitened activation vector to
unstretch the activation space. Compared to the traditional Direction objectives,
this objective places more emphasis on angle by optionally multiplying the dot
product by the cosine similarity.
Contributor:

Since cosine similarity can also be seen as a normalized dot product, it is a bit unclear which cosine similarity is being multiplied by which dot product.

ProGamerGov (Contributor Author), Oct 20, 2021:

So, this is what AngledNeuronDirection does:

        if self.cossim_pow == 0:
            return activations * vec
        dot = torch.mean(activations * vec)
        cossims = dot / (self.eps + torch.sqrt(torch.sum(activations ** 2)))
        return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow

And this is what _dot_cossim does:

    dot = torch.sum(x * y, dim)
    if cossim_pow == 0:
        return dot
    return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow

Contributor:

Thank you, @ProGamerGov! I saw that. I think the description of multiplying the dot product by itself and then normalizing it is a bit confusing. Is there any theoretical explanation for this?

ProGamerGov (Contributor Author), Oct 27, 2021:

@NarineK From the Activation Atlas paper, I think this might answer your question:

We find it helpful to use an objective that emphasizes angle more heavily by multiplying the dot product by cosine similarity.... A reference implementation of this can be seen in the attached notebooks, and more general discussion can be found in this github issue.

The Lucid Github issue lists a bunch of different feature visualization objective algorithms, including the one we are using (I copied the relevant part from the issue below):

Dot x Cosine Similarity

  • Multiplying dot product by cosine similarity (possibly raised to a power) can be a useful way to get a dot-product like objective that cares more about angle, but still maximizes how far it can get in a certain direction. We've had quite a bit of success with this.
  • One important implementation detail: you want to use something like dot(x,y) * ceil(0.1, cossim(x,y))^n to avoid multiplying the dot product by 0 or a negative cosine similarity. Otherwise, you could end up in a situation where you maximize the opposite direction (because both dot and cossim are negative, and multiply to be positive) or get stuck because both are zero.
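
To make the clamping rationale above concrete, here is a minimal, self-contained sketch of the dot-times-clamped-cossim idea (the tensors, function name, and cossim_pow default below are illustrative, not taken from this PR):

    import torch

    def dot_cossim_sketch(x, y, cossim_pow=4.0, eps=1e-4):
        # The dot product rewards moving far along the direction y.
        dot = torch.sum(x * y)
        # Cosine similarity rewards pointing in the same direction as y.
        cossim = dot / (eps + x.norm() * y.norm())
        # Clamping at 0.1 prevents a negative dot and a negative cossim from
        # multiplying into a positive score (which would reward the opposite
        # direction), and prevents stalling when both terms are zero.
        return dot * torch.clamp(cossim, min=0.1) ** cossim_pow

    x = torch.tensor([-1.0, -2.0])
    y = torch.tensor([1.0, 2.0])
    # Unclamped with cossim_pow=1 this would be (-5.0) * (-1.0) = +5.0, a
    # positive score for the exactly opposite direction; clamped, it stays
    # negative (~ -5e-4).
    print(dot_cossim_sketch(x, y))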

ProGamerGov (Contributor Author):
Okay, I've added a link to the Lucid GitHub issue detailing the reasoning behind the loss algorithm!


When cossim_pow is equal to 0, this objective works as a Euclidean
neuron objective. When cossim_pow is greater than 0, this objective works as a
cosine similarity objective. An additional whitened neuron direction vector
can optionally be supplied to improve visualization quality for some models.

More information on the algorithm this objective uses can be found here:
https://github.com/tensorflow/lucid/issues/116

The Lucid equivalents of this loss function can be found here:
https://github.com/tensorflow/lucid/blob/master/notebooks/
activation-atlas/activation-atlas-simple.ipynb
https://github.com/tensorflow/lucid/blob/master/notebooks/
activation-atlas/class-activation-atlas.ipynb

Like the Lucid equivalents, our implementation differs slightly from the
associated research paper.

Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/
"""

def __init__(
self,
target: torch.nn.Module,
vec: torch.Tensor,
vec_whitened: Optional[torch.Tensor] = None,
cossim_pow: float = 4.0,
x: Optional[int] = None,
y: Optional[int] = None,
eps: float = 1.0e-4,
batch_index: Optional[int] = None,
) -> None:
"""
Args:
target (nn.Module): A target layer instance.
vec (torch.Tensor): A neuron direction vector to use.
vec_whitened (torch.Tensor, optional): A whitened neuron direction vector.
cossim_pow (float, optional): The desired cosine similarity power to use.
x (int, optional): Optionally provide a specific x position for the target
neuron.
y (int, optional): Optionally provide a specific y position for the target
neuron.
eps (float, optional): If cossim_pow is greater than zero, the desired
epsilon value to use for cosine similarity calculations.
batch_index (int, optional): Optionally provide the index of a specific
image in the batch to optimize.
"""
BaseLoss.__init__(self, target, batch_index)
self.vec = vec.unsqueeze(0) if vec.dim() == 1 else vec
self.vec_whitened = vec_whitened
self.cossim_pow = cossim_pow
self.eps = eps
self.x = x
self.y = y
if self.vec_whitened is not None:
assert self.vec_whitened.dim() == 2
assert self.vec.dim() == 2

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = activations[self.batch_index[0] : self.batch_index[1]]
assert activations.dim() == 4 or activations.dim() == 2
assert activations.shape[1] == self.vec.shape[1]
if activations.dim() == 4:
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
activations = activations[..., _x, _y]

vec = (
torch.matmul(self.vec, self.vec_whitened)[0]
if self.vec_whitened is not None
else self.vec
)
if self.cossim_pow == 0:
return activations * vec

dot = torch.mean(activations * vec)
cossims = dot / (self.eps + torch.sqrt(torch.sum(activations ** 2)))
Contributor:

[Screenshot: the cosine similarity formula from the paper]

Shouldn't we divide by the L2 norm of vec as well?

ProGamerGov (Contributor Author), Oct 20, 2021:
The code I've written follows the implementation in the notebooks associated with the Activation Atlas paper:

They set up the objective calculations like this:

@objectives.wrap_objective
def direction_neuron_S(layer_name, vec, batch=None, x=None, y=None, S=None):
  def inner(T):
    layer = T(layer_name)
    shape = tf.shape(layer)
    x_ = shape[1] // 2 if x is None else x
    y_ = shape[2] // 2 if y is None else y
    if batch is None:
      raise RuntimeError("requires batch")

    acts = layer[batch, x_, y_]
    vec_ = vec
    if S is not None: vec_ = tf.matmul(vec_[None], S)[0]
    # mag = tf.sqrt(tf.reduce_sum(acts**2))
    dot = tf.reduce_mean(acts * vec_)
    # cossim = dot/(1e-4 + mag)
    return dot
  return inner


@objectives.wrap_objective
def direction_neuron_cossim_S(layer_name, vec, batch=None, x=None, y=None, cossim_pow=1, S=None):
  def inner(T):
    layer = T(layer_name)
    shape = tf.shape(layer)
    x_ = shape[1] // 2 if x is None else x
    y_ = shape[2] // 2 if y is None else y
    if batch is None:
      raise RuntimeError("requires batch")

    acts = layer[batch, x_, y_]
    vec_ = vec
    if S is not None: vec_ = tf.matmul(vec_[None], S)[0]
    mag = tf.sqrt(tf.reduce_sum(acts**2))
    dot = tf.reduce_mean(acts * vec_)
    cossim = dot/(1e-4 + mag)
    cossim = tf.maximum(0.1, cossim)
    return dot * cossim ** cossim_pow
  return inner

The PyTorch objective merges these two objectives and follows what they do.

Contributor:
Thank you for the reference, @ProGamerGov! I think it would be good to cite those notebooks in the code, because it is confusing in general when the formal definitions vary from the implementation. To be honest, I don't quite understand why they defined and implemented it that way.

ProGamerGov (Contributor Author):

@NarineK I'll add the references, and a note about how the reference code differs slightly from the implementation described in the paper!

ProGamerGov (Contributor Author), Oct 27, 2021:

I also have yet to figure out why their implementation is slightly different from the paper's, but I figured that I should follow the working reference implementation rather than the paper's equations for now!
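
As a compact restatement of the whitening step that both implementations share (the shapes and the stand-in matrix S below are illustrative):

    import torch

    vec = torch.randn(528)       # raw direction vector with C = 528 channels
    S = torch.randn(528, 528)    # stand-in for the whitened correlation matrix
    # Mirrors tf.matmul(vec_[None], S)[0] in Lucid and
    # torch.matmul(self.vec, self.vec_whitened)[0] in this PR.
    vec_whitened = torch.matmul(vec[None], S)[0]  # shape (528,)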

return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow


@loss_wrapper
class TensorDirection(BaseLoss):
"""
@@ -590,6 +683,47 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
return activations


def sum_loss_list(
loss_list: List,
to_scalar_fn: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> CompositeLoss:
"""
Sum a large list of loss objectives without recursion errors. Combining 300+
loss functions for a single optimization task with the built-in sum() function
can exceed Python's default maximum recursion depth, because each addition
nests the losses one CompositeLoss level deeper. This function avoids that
limit by evaluating every loss in a single flat loop.

This function works similarly to Lucid's optvis.objectives.Objective.sum() function.

Args:

loss_list (list): A list of loss function objectives.
to_scalar_fn (Callable): A function for converting loss function outputs to
scalar values, in order to prevent size mismatches.
Default: torch.mean

Returns:
loss_fn (CompositeLoss): A composite loss function containing all the loss
functions from `loss_list`.
"""

def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return sum([to_scalar_fn(loss(module)) for loss in loss_list])

name = "Sum(" + ", ".join([loss.__name__ for loss in loss_list]) + ")"
# Collect targets from losses
target = [
target
for targets in [
[loss.target] if not hasattr(loss.target, "__iter__") else loss.target
for loss in loss_list
]
for target in targets
]
return CompositeLoss(loss_fn, name=name, target=target)


def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
"""
Helper function to summarize tensor outputs from loss functions.
@@ -617,7 +751,9 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
"Alignment",
"Direction",
"NeuronDirection",
"AngledNeuronDirection",
"TensorDirection",
"ActivationWeights",
"sum_loss_list",
"default_loss_summarize",
]
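
As a usage sketch for the two loss.py additions above, the following hedged example combines many AngledNeuronDirection objectives with sum_loss_list(); the toy model, layer choice, and direction vectors are illustrative stand-ins, not part of this PR:

    import torch
    import torch.nn as nn
    from captum.optim._core.loss import AngledNeuronDirection, sum_loss_list

    # Toy stand-in model; in practice this would be a layer of a real network.
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
    target_layer = model[1]

    # One hypothetical 8-channel direction vector per batch image.
    directions = [torch.randn(8) for _ in range(400)]

    # batch_index ties each objective to one image in the optimized batch.
    losses = [
        AngledNeuronDirection(target_layer, vec=vec, batch_index=i)
        for i, vec in enumerate(directions)
    ]

    # Summing 400 losses with the built-in sum() nests CompositeLoss 400
    # levels deep and can hit Python's recursion limit; sum_loss_list
    # evaluates them in a single flat loop instead.
    loss_fn = sum_loss_list(losses)
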
147 changes: 147 additions & 0 deletions captum/optim/_param/image/transforms.py
@@ -384,6 +384,152 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.translate_tensor(input, insets)


class RandomRotation(nn.Module):
"""
Apply random rotation transforms to an NCHW tensor, using a sequence of degree
values or a torch.distributions instance.
"""

__constants__ = [
"degrees",
"mode",
"padding_mode",
"align_corners",
"_has_align_corners",
"_is_distribution",
]

def __init__(
self,
degrees: NumSeqOrTensorType,
mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = False,
) -> None:
"""
Args:

degrees (sequence, torch.Tensor, or torch.distributions instance): Sequence of
degree values to randomly select from, or a torch.distributions instance.
mode (str, optional): Interpolation mode to use. See documentation of
F.grid_sample for more details. One of: "bilinear", "nearest", or
"bicubic".
Default: "bilinear"
padding_mode (str, optional): Padding mode for values that fall outside of
the grid. See documentation of F.grid_sample for more details. One of:
"zeros", "border", or "reflection".
Default: "zeros"
align_corners (bool, optional): Whether or not to align corners. See
documentation of F.affine_grid & F.grid_sample for more details.
Default: False
"""
super().__init__()
if isinstance(degrees, torch.distributions.distribution.Distribution):
# Distributions are not supported by TorchScript / JIT yet
assert degrees.batch_shape == torch.Size([])
self.degrees_distribution = degrees
self._is_distribution = True
self.degrees = []
else:
assert hasattr(degrees, "__iter__")
if torch.is_tensor(degrees):
assert cast(torch.Tensor, degrees).dim() == 1
degrees = degrees.tolist()
assert len(degrees) > 0
self.degrees = [float(d) for d in degrees]
self._is_distribution = False

self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self._has_align_corners = torch.__version__ >= "1.3.0"

def _get_rot_mat(
self,
theta: float,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create a rotation matrix tensor.

Args:

theta (float): The rotation value in degrees.
device (torch.device): The device to create the rotation matrix on.
dtype (torch.dtype): The desired dtype of the rotation matrix.

Returns:
**rot_mat** (torch.Tensor): A rotation matrix.
"""
theta = theta * math.pi / 180.0
rot_mat = torch.tensor(
[
[math.cos(theta), -math.sin(theta), 0.0],
[math.sin(theta), math.cos(theta), 0.0],
],
device=device,
dtype=dtype,
)
return rot_mat

def _rotate_tensor(self, x: torch.Tensor, theta: float) -> torch.Tensor:
"""
Rotate an NCHW image tensor based on a specified degree value.

Args:

x (torch.Tensor): The NCHW image tensor to rotate.
theta (float): The amount to rotate the NCHW image, in degrees.

Returns:
**x** (torch.Tensor): A rotated NCHW image tensor.
"""
rot_matrix = self._get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat(
x.shape[0], 1, 1
)
if self._has_align_corners:
# Pass align_corners explicitly for torch >= 1.3.0
grid = F.affine_grid(rot_matrix, x.size(), align_corners=self.align_corners)
x = F.grid_sample(
x,
grid,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)
else:
grid = F.affine_grid(rot_matrix, x.size())
x = F.grid_sample(x, grid, mode=self.mode, padding_mode=self.padding_mode)
return x

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Randomly rotate an NCHW image tensor.

Args:

x (torch.Tensor): NCHW image tensor to randomly rotate.

Returns:
**x** (torch.Tensor): A randomly rotated NCHW image tensor.
"""
assert x.dim() == 4
if self._is_distribution:
rotate_angle = float(self.degrees_distribution.sample().item())
else:
n = int(
torch.randint(
low=0,
high=len(self.degrees),
size=[1],
dtype=torch.int64,
layout=torch.strided,
device=x.device,
).item()
)
rotate_angle = self.degrees[n]
return self._rotate_tensor(x, rotate_angle)


class ScaleInputRange(nn.Module):
"""
Multiplies the input by a specified multiplier for models with input ranges other
Expand Down Expand Up @@ -673,6 +819,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"center_crop",
"RandomScale",
"RandomSpatialJitter",
"RandomRotation",
"ScaleInputRange",
"RGBToBGR",
"GaussianSmoothing",
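
Finally, a short usage sketch of the new RandomRotation transform, showing both construction paths from its __init__ (the input tensor and degree ranges below are illustrative):

    import torch
    from captum.optim._param.image.transforms import RandomRotation

    x = torch.randn(1, 3, 64, 64)  # NCHW image tensor

    # Path 1: randomly pick an angle from an explicit sequence of degrees.
    rotate_seq = RandomRotation(degrees=list(range(-25, 26)))
    out1 = rotate_seq(x)

    # Path 2: sample the angle from a torch.distributions instance
    # (per the PR, distributions are not yet supported by TorchScript/JIT).
    rotate_dist = RandomRotation(degrees=torch.distributions.Uniform(-25.0, 25.0))
    out2 = rotate_dist(x)

    assert out1.shape == x.shape and out2.shape == x.shape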