Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions captum/optim/_core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def wrapper(*args, **kwargs) -> object:
class LayerActivation(BaseLoss):
"""
Maximize activations at the target layer.
This is the most basic loss available and it simply returns the activations in
their original form.

Args:
target (nn.Module): The layer to optimize for.
batch_index (int, optional): The index of the image to optimize if we
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The batch_index documentation is missing Default: None below the description for all of the batch_index docs I think.

Copy link
Contributor

@ProGamerGov ProGamerGov Mar 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the batch index docs should look like this:

            batch_index (int, optional): The index of activations to optimize if
                optimizing a batch of activations. If set to None, defaults to all
                activations in the batch.
                Default: None

optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
Expand All @@ -201,6 +209,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
class ChannelActivation(BaseLoss):
"""
Maximize activations at the target layer and target channel.
This loss maximizes the activations of a target channel in a specified target
layer, and can be useful to determine what features the channel is excited by.

Args:
target (nn.Module): The layer to containing the channel to optimize for.
channel_index (int): The index of the channel to optimize for.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
Comment on lines +218 to +220
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we should be calling the activations images in the batch index docs, as the activations can be images, or 2D / 4D activations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also a grammatical error in the batch_index docs: "optimize if we optimizing a batch of", should be "optimize if we are optimizing a batch of".

"""

def __init__(
Expand All @@ -224,6 +241,26 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:

@loss_wrapper
class NeuronActivation(BaseLoss):
"""
This loss maximizes the activations of a target neuron in the specified channel
from the specified layer. This loss is useful for determining the type of features
that excite a neuron, and thus is often used for circuits and neuron related
research.

Args:
target (nn.Module): The layer to containing the channel to optimize for.
channel_index (int): The index of the channel to optimize for.
x (int, optional): The x coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit left of center for even
lengths.
y (int, optional): The y coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit up of center for even
heights.
Comment on lines +253 to +258
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docs are also missing Default: None.

batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
self,
target: nn.Module,
Expand Down Expand Up @@ -258,6 +295,16 @@ class DeepDream(BaseLoss):
"""
Maximize 'interestingness' at the target layer.
Mordvintsev et al., 2015.
https://github.com/google/deepdream
This loss returns the squared layer activations. When combined with a negative
mean loss summarization, this loss will create hallucinogenic visuals commonly
referred to as 'Deep Dream'.

Args:
target (nn.Module): The layer to optimize for.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
Expand All @@ -272,6 +319,15 @@ class TotalVariation(BaseLoss):
Total variation denoising penalty for activations.
See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them.
https://arxiv.org/abs/1412.0035
This loss attempts to smooth / denoise the target by performing total variance
denoising. The target is most often the image that’s being optimized. This loss is
often used to remove unwanted visual artifacts.

Args:
target (nn.Module): The layer to optimize for.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
Expand All @@ -286,6 +342,14 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
class L1(BaseLoss):
"""
L1 norm of the target layer, generally used as a penalty.

Args:
target (nn.Module): The layer to optimize for.
constant (float): Constant threshold to deduct from the activations.
Defaults to 0.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
Expand All @@ -307,6 +371,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
class L2(BaseLoss):
"""
L2 norm of the target layer, generally used as a penalty.

Args:
target (nn.Module): The layer to optimize for.
constant (float): Constant threshold to deduct from the activations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The doc should be listed as optional, as there is a default value: constant (float, optional)

Defaults to 0.
epsilon (float): Small value to add to L2 prior to sqrt. Defaults to 1e-6.
Copy link
Contributor

@ProGamerGov ProGamerGov Mar 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the documentation format requires Default: <default_value> on a separate line. Defaults can still however be mentioned in the description.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The epsilon doc is also missing the optional word

batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
Expand Down Expand Up @@ -334,6 +407,14 @@ class Diversity(BaseLoss):
Use a cosine similarity penalty to extract features from a polysemantic neuron.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#diversity
This loss helps break up polysemantic layers, channels, and neurons by encouraging
diversity across the different batches. This loss is to be used along with a main
loss.

Args:
target (nn.Module): The layer to optimize for.
batch_index (int, optional): Unused here since we are optimizing for diversity
across the batch.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
Expand All @@ -359,6 +440,16 @@ class ActivationInterpolation(BaseLoss):
Interpolate between two different layers & channels.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
This loss helps to interpolate or mix visualizations from two activations (layer or
channel) by interpolating a linear sum between the two activations.

Args:
target1 (nn.Module): The first layer to optimize for.
channel_index1 (int): Index of channel in first layer to optimize. Defaults to
all channels.
target2 (nn.Module): The first layer to optimize for.
channel_index2 (int): Index of channel in first layer to optimize. Defaults to
all channels.
"""

def __init__(
Expand Down Expand Up @@ -410,6 +501,14 @@ class Alignment(BaseLoss):
similarity between them.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
When interpolating between activations, it may be desirable to keep image landmarks
in the same position for visual comparison. This loss helps to minimize L2 distance
between neighbouring images.

Args:
target (nn.Module): The layer to optimize for.
decay_ratio (float): How much to decay penalty as images move apart in batch.
Defaults to 2.
"""

def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
Expand Down Expand Up @@ -438,6 +537,18 @@ class Direction(BaseLoss):
Visualize a general direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
This loss helps to visualize a specific vector direction in a layer, by maximizing
the alignment between the input vector and the layer’s activation vector. The
dimensionality of the vector should correspond to the number of channels in the
layer.

Args:
target (nn.Module): The layer to optimize for.
vec (torch.Tensor): Vector representing direction to align to.
cossim_pow (float, optional): The desired cosine similarity power to use.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
Expand All @@ -464,6 +575,23 @@ class NeuronDirection(BaseLoss):
Visualize a single (x, y) position for a direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
Extends Direction loss by focusing on visualizing a single neuron within the
kernel.

Args:
target (nn.Module): The layer to optimize for.
vec (torch.Tensor): Vector representing direction to align to.
x (int, optional): The x coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit left of center for even
lengths.
y (int, optional): The y coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit up of center for even
heights.
Comment on lines +584 to +589
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docs are missing Default: None.

channel_index (int): The index of the channel to optimize for.
cossim_pow (float, optional): The desired cosine similarity power to use.
Comment on lines +590 to +591
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docs are missing their Default settings.

batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
Expand Down Expand Up @@ -505,6 +633,15 @@ class TensorDirection(BaseLoss):
Visualize a tensor direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
Extends Direction loss by allowing batch-wise direction visualization.

Args:
target (nn.Module): The layer to optimize for.
vec (torch.Tensor): Vector representing direction to align to.
cossim_pow (float, optional): The desired cosine similarity power to use.
batch_index (int, optional): The index of the image to optimize if we
optimizing a batch of images. If unspecified, defaults to all images
in the batch.
"""

def __init__(
Expand Down Expand Up @@ -542,6 +679,23 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
class ActivationWeights(BaseLoss):
"""
Apply weights to channels, neurons, or spots in the target.
This loss weighs specific channels or neurons in a given layer, via a weight
vector.

Args:
target (nn.Module): The layer to optimize for.
weights (torch.Tensor): Weights to apply to targets.
neuron (bool): Whether target is a neuron. Defaults to False.
x (int, optional): The x coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit left of center for even
lengths.
y (int, optional): The y coordinate of the neuron to optimize for. If
unspecified, defaults to center, or one unit up of center for even
heights.
wx (int, optional): Length of neurons to apply the weights to, along the
x-axis.
wy (int, optional): Length of neurons to apply the weights to, along the
y-axis.
Comment on lines +689 to +698
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docs are all missing Default: None

"""

def __init__(
Expand Down