-
Notifications
You must be signed in to change notification settings - Fork 541
Optim-wip: Add descriptions and argument documentation to losses #831
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -189,6 +189,14 @@ def wrapper(*args, **kwargs) -> object: | |
class LayerActivation(BaseLoss): | ||
""" | ||
Maximize activations at the target layer. | ||
This is the most basic loss available and it simply returns the activations in | ||
their original form. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | ||
|
@@ -201,6 +209,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | |
class ChannelActivation(BaseLoss): | ||
""" | ||
Maximize activations at the target layer and target channel. | ||
This loss maximizes the activations of a target channel in a specified target | ||
layer, and can be useful to determine what features the channel is excited by. | ||
|
||
Args: | ||
target (nn.Module): The layer to containing the channel to optimize for. | ||
channel_index (int): The index of the channel to optimize for. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
Comment on lines
+218
to
+220
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if we should be calling the activations There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's also a grammatical error in the batch_index docs: "optimize if we optimizing a batch of", should be "optimize if we are optimizing a batch of". |
||
""" | ||
|
||
def __init__( | ||
|
@@ -224,6 +241,26 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | |
|
||
@loss_wrapper | ||
class NeuronActivation(BaseLoss): | ||
""" | ||
This loss maximizes the activations of a target neuron in the specified channel | ||
from the specified layer. This loss is useful for determining the type of features | ||
that excite a neuron, and thus is often used for circuits and neuron related | ||
research. | ||
|
||
Args: | ||
target (nn.Module): The layer to containing the channel to optimize for. | ||
channel_index (int): The index of the channel to optimize for. | ||
x (int, optional): The x coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit left of center for even | ||
lengths. | ||
y (int, optional): The y coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit up of center for even | ||
heights. | ||
Comment on lines
+253
to
+258
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These docs are also missing |
||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
target: nn.Module, | ||
|
@@ -258,6 +295,16 @@ class DeepDream(BaseLoss): | |
""" | ||
Maximize 'interestingness' at the target layer. | ||
Mordvintsev et al., 2015. | ||
https://github.com/google/deepdream | ||
This loss returns the squared layer activations. When combined with a negative | ||
mean loss summarization, this loss will create hallucinogenic visuals commonly | ||
referred to as 'Deep Dream'. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | ||
|
@@ -272,6 +319,15 @@ class TotalVariation(BaseLoss): | |
Total variation denoising penalty for activations. | ||
See Mahendran, V. 2014. Understanding Deep Image Representations by Inverting Them. | ||
https://arxiv.org/abs/1412.0035 | ||
This loss attempts to smooth / denoise the target by performing total variance | ||
denoising. The target is most often the image that’s being optimized. This loss is | ||
often used to remove unwanted visual artifacts. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | ||
|
@@ -286,6 +342,14 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | |
class L1(BaseLoss): | ||
""" | ||
L1 norm of the target layer, generally used as a penalty. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
constant (float): Constant threshold to deduct from the activations. | ||
Defaults to 0. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -307,6 +371,15 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | |
class L2(BaseLoss): | ||
""" | ||
L2 norm of the target layer, generally used as a penalty. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
constant (float): Constant threshold to deduct from the activations. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The doc should be listed as optional, as there is a default value: |
||
Defaults to 0. | ||
epsilon (float): Small value to add to L2 prior to sqrt. Defaults to 1e-6. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the documentation format requires There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The epsilon doc is also missing the |
||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -334,6 +407,14 @@ class Diversity(BaseLoss): | |
Use a cosine similarity penalty to extract features from a polysemantic neuron. | ||
Olah, Mordvintsev & Schubert, 2017. | ||
https://distill.pub/2017/feature-visualization/#diversity | ||
This loss helps break up polysemantic layers, channels, and neurons by encouraging | ||
diversity across the different batches. This loss is to be used along with a main | ||
loss. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
batch_index (int, optional): Unused here since we are optimizing for diversity | ||
across the batch. | ||
""" | ||
|
||
def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | ||
|
@@ -359,6 +440,16 @@ class ActivationInterpolation(BaseLoss): | |
Interpolate between two different layers & channels. | ||
Olah, Mordvintsev & Schubert, 2017. | ||
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons | ||
This loss helps to interpolate or mix visualizations from two activations (layer or | ||
channel) by interpolating a linear sum between the two activations. | ||
|
||
Args: | ||
target1 (nn.Module): The first layer to optimize for. | ||
channel_index1 (int): Index of channel in first layer to optimize. Defaults to | ||
all channels. | ||
target2 (nn.Module): The first layer to optimize for. | ||
channel_index2 (int): Index of channel in first layer to optimize. Defaults to | ||
all channels. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -410,6 +501,14 @@ class Alignment(BaseLoss): | |
similarity between them. | ||
Olah, Mordvintsev & Schubert, 2017. | ||
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons | ||
When interpolating between activations, it may be desirable to keep image landmarks | ||
in the same position for visual comparison. This loss helps to minimize L2 distance | ||
between neighbouring images. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
decay_ratio (float): How much to decay penalty as images move apart in batch. | ||
Defaults to 2. | ||
""" | ||
|
||
def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None: | ||
|
@@ -438,6 +537,18 @@ class Direction(BaseLoss): | |
Visualize a general direction vector. | ||
Carter, et al., "Activation Atlas", Distill, 2019. | ||
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images | ||
This loss helps to visualize a specific vector direction in a layer, by maximizing | ||
the alignment between the input vector and the layer’s activation vector. The | ||
dimensionality of the vector should correspond to the number of channels in the | ||
layer. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
vec (torch.Tensor): Vector representing direction to align to. | ||
cossim_pow (float, optional): The desired cosine similarity power to use. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -464,6 +575,23 @@ class NeuronDirection(BaseLoss): | |
Visualize a single (x, y) position for a direction vector. | ||
Carter, et al., "Activation Atlas", Distill, 2019. | ||
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images | ||
Extends Direction loss by focusing on visualizing a single neuron within the | ||
kernel. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
vec (torch.Tensor): Vector representing direction to align to. | ||
x (int, optional): The x coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit left of center for even | ||
lengths. | ||
y (int, optional): The y coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit up of center for even | ||
heights. | ||
Comment on lines
+584
to
+589
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These docs are missing |
||
channel_index (int): The index of the channel to optimize for. | ||
cossim_pow (float, optional): The desired cosine similarity power to use. | ||
Comment on lines
+590
to
+591
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These docs are missing their |
||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -505,6 +633,15 @@ class TensorDirection(BaseLoss): | |
Visualize a tensor direction vector. | ||
Carter, et al., "Activation Atlas", Distill, 2019. | ||
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images | ||
Extends Direction loss by allowing batch-wise direction visualization. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
vec (torch.Tensor): Vector representing direction to align to. | ||
cossim_pow (float, optional): The desired cosine similarity power to use. | ||
batch_index (int, optional): The index of the image to optimize if we | ||
optimizing a batch of images. If unspecified, defaults to all images | ||
in the batch. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -542,6 +679,23 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor: | |
class ActivationWeights(BaseLoss): | ||
""" | ||
Apply weights to channels, neurons, or spots in the target. | ||
This loss weighs specific channels or neurons in a given layer, via a weight | ||
vector. | ||
|
||
Args: | ||
target (nn.Module): The layer to optimize for. | ||
weights (torch.Tensor): Weights to apply to targets. | ||
neuron (bool): Whether target is a neuron. Defaults to False. | ||
x (int, optional): The x coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit left of center for even | ||
lengths. | ||
y (int, optional): The y coordinate of the neuron to optimize for. If | ||
unspecified, defaults to center, or one unit up of center for even | ||
heights. | ||
wx (int, optional): Length of neurons to apply the weights to, along the | ||
x-axis. | ||
wy (int, optional): Length of neurons to apply the weights to, along the | ||
y-axis. | ||
Comment on lines
+689
to
+698
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These docs are all missing |
||
""" | ||
|
||
def __init__( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
batch_index
documentation is missingDefault: None
below the description for all of the batch_index docs I think.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that the batch index docs should look like this: