Optim wip - Move & restructure loss objectives #527
@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from captum.optim._utils.images import get_neuron_pos
from captum.optim._utils.typing import ModuleOutputMapping
class Loss(ABC):
    """
    Abstract base class to describe a loss objective.
    """

    def __init__(self, target: nn.Module) -> None:
        super(Loss, self).__init__()
        self.target = target

    @abstractmethod
    def __call__(self, targets_to_values: ModuleOutputMapping):
        pass
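For orientation, a new objective only needs to subclass Loss and implement __call__ over the module-to-activations mapping. A minimal hypothetical example (not part of this PR) that penalizes a layer's mean activation:

class SuppressLayer(Loss):
    """
    Hypothetical objective: negate the target layer's mean activation so the
    optimizer drives it toward zero.
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        return -targets_to_values[self.target].mean()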
class LayerActivation(Loss):
    """
    Maximize activations at the target layer.
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        return targets_to_values[self.target]
class ChannelActivation(Loss):
    """
    Maximize activations at the target layer and target channel.
    """

    def __init__(self, target: nn.Module, channel_index: int) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.channel_index = channel_index

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        assert activations is not None
        # ensure channel_index is valid
        assert self.channel_index < activations.shape[1]
        # assume NCHW
        # NOTE: not necessarily true e.g. for Linear layers
        # assert len(activations.shape) == 4
        return activations[:, self.channel_index, ...]
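As a concrete (and purely illustrative) way to exercise these objectives outside the optimization loop, the module-to-activations mapping can be built with a forward hook; the tiny model and layer below are stand-ins, not the PR's InputOptimization machinery:

import torch
import torch.nn as nn

# A small conv stack stands in for a real model; a forward hook fills the
# module -> activations mapping that the loss objectives expect.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
target_layer = model[2]

captured = {}

def save_output(module, inputs, output):
    captured[module] = output  # mimics a ModuleOutputMapping

handle = target_layer.register_forward_hook(save_output)
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
handle.remove()

loss_fn = ChannelActivation(target_layer, channel_index=7)
print(loss_fn(captured).shape)  # torch.Size([1, 28, 28]): the selected channel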
class NeuronActivation(Loss):
    def __init__(
        self,
        target: nn.Module,
        channel_index: int,
        x: Optional[int] = None,
        y: Optional[int] = None,
    ) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.channel_index = channel_index
        self.x = x
        self.y = y

        # ensure channel_index will be valid
        assert self.channel_index < self.target.out_channels

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        assert activations is not None
        assert len(activations.shape) == 4  # assume NCHW
        _x, _y = get_neuron_pos(
            activations.size(2), activations.size(3), self.x, self.y
        )

        return activations[:, self.channel_index, _x, _y]
class DeepDream(Loss):
    """
    Maximize 'interestingness' at the target layer.
    Mordvintsev et al., 2015.
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        return activations ** 2
class TotalVariation(Loss):
    """
    Total variation denoising penalty for activations.
    See Mahendran & Vedaldi, 2014. Understanding Deep Image Representations
    by Inverting Them. https://arxiv.org/abs/1412.0035
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        x_diff = activations[..., 1:, :] - activations[..., :-1, :]
        y_diff = activations[..., :, 1:] - activations[..., :, :-1]
        return torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff))
class L1(Loss):
    """
    L1 norm of the target layer, generally used as a penalty.
    """

    def __init__(self, target: nn.Module, constant: float = 0.0) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.constant = constant

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        return torch.abs(activations - self.constant).sum()
class L2(Loss):
    """
    L2 norm of the target layer, generally used as a penalty.
    """

    def __init__(
        self, target: nn.Module, constant: float = 0.0, epsilon: float = 1e-6
    ) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.constant = constant
        self.epsilon = epsilon

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        # square before summing so the sqrt argument is non-negative (L2 norm)
        activations = ((activations - self.constant) ** 2).sum()
        return torch.sqrt(self.epsilon + activations)
class Diversity(Loss):
    """
    Use a cosine similarity penalty to extract features from a polysemantic neuron.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#diversity
    """

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        return -sum(
            [
                sum(
                    [
                        (
                            torch.cosine_similarity(
                                activations[j].view(1, -1), activations[i].view(1, -1)
                            )
                        ).sum()
                        for i in range(activations.size(0))
                        if i != j
                    ]
                )
                for j in range(activations.size(0))
            ]
        ) / activations.size(0)
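The nested comprehensions above accumulate cosine similarities over all ordered pairs of batch elements. As a sanity check, an equivalent vectorized formulation (illustrative sketch, not part of the PR):

import torch

def diversity_reference(activations: torch.Tensor) -> torch.Tensor:
    # Flatten each batch element, normalize, and build the BxB cosine
    # similarity matrix; the objective is minus the off-diagonal sum over B.
    B = activations.size(0)
    flat = activations.view(B, -1)
    flat = flat / flat.norm(dim=1, keepdim=True).clamp_min(1e-8)
    sim = flat @ flat.t()
    sim = sim - torch.diag(torch.diag(sim))  # drop the i == j terms
    return -sim.sum() / B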
class ActivationInterpolation(Loss):
    """
    Interpolate between two different layers & channels.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
    """

    def __init__(
        self,
        target1: Optional[nn.Module] = None,
        channel_index1: int = -1,
        target2: Optional[nn.Module] = None,
        channel_index2: int = -1,
    ) -> None:
        super(Loss, self).__init__()
        self.target_one = target1
        self.channel_index_one = channel_index1
        self.target_two = target2
        self.channel_index_two = channel_index2

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations_one = targets_to_values[self.target_one]
        activations_two = targets_to_values[self.target_two]

        assert activations_one is not None and activations_two is not None
        # ensure channel indices are valid
        assert (
            self.channel_index_one < activations_one.shape[1]
            and self.channel_index_two < activations_two.shape[1]
        )
        assert activations_one.size(0) == activations_two.size(0)

        if self.channel_index_one > -1:
            activations_one = activations_one[:, self.channel_index_one]
        if self.channel_index_two > -1:
            activations_two = activations_two[:, self.channel_index_two]
        B = activations_one.size(0)

        batch_weights = torch.arange(B, device=activations_one.device) / (B - 1)
        sum_tensor = torch.zeros(1, device=activations_one.device)
        for n in range(B):
            sum_tensor = (
                sum_tensor + ((1 - batch_weights[n]) * activations_one[n]).mean()
            )
            sum_tensor = sum_tensor + (batch_weights[n] * activations_two[n]).mean()
        return sum_tensor
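The interpolation weights above are simply a linear ramp across the batch: the first sample sees only the first objective and the last sample only the second. A quick illustrative check (not part of the PR):

import torch

B = 5
batch_weights = torch.arange(B) / (B - 1)
print(batch_weights)  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
# sample n mixes (1 - w[n]) of target1's channel with w[n] of target2's channel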
class Alignment(Loss):
    """
    Penalize the L2 distance between tensors in the batch to encourage visual
    similarity between them.
    Olah, Mordvintsev & Schubert, 2017.
    https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
    """

    def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.decay_ratio = decay_ratio

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        B = activations.size(0)

        sum_tensor = torch.zeros(1, device=activations.device)
        for d in [1, 2, 3, 4]:
            for i in range(B - d):
                a, b = i, i + d
                activ_a, activ_b = activations[a], activations[b]
                sum_tensor = sum_tensor + (
                    (activ_a - activ_b) ** 2
                ).mean() / self.decay_ratio ** float(d)

        return sum_tensor
Review comment: nit: because we are optimizing negative loss do we want to return …
Reply: Currently the InputOptimization class doesn't support losses like Alignment. So, I've left them as they are for @greentfrapp to change, as he's currently working on the optimization system.
class Direction(Loss):
    """
    Visualize a general direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
    """

    def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.direction = vec.reshape((1, -1, 1, 1))

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
Review comment: nit: I think that it would be good to assert the sizes of self.direction and activations before computing cosine similarity on them.
Reply: I added an assert for the channel size of self.direction & activations. @greentfrapp may be able to add more assertion tests in his upcoming PR.
        assert activations.size(1) == self.direction.size(1)
        return torch.cosine_similarity(self.direction, activations)
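For intuition about what Direction returns, a self-contained sketch (the shapes and the one-hot vector are assumptions, not from the PR): the loss is the per-position cosine similarity between the direction vector and the channel vector at each spatial location.

import torch

N, C, H, W = 1, 16, 8, 8
activations = torch.randn(N, C, H, W)
vec = torch.zeros(C)
vec[3] = 1.0                            # one-hot direction toward channel 3
direction = vec.reshape((1, -1, 1, 1))
# cosine similarity reduces the channel dimension (dim=1)
sim = torch.cosine_similarity(direction, activations)
print(sim.shape)  # torch.Size([1, 8, 8])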
class DirectionNeuron(Loss):
    """
    Visualize a single (x, y) position for a direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
    """

    def __init__(
        self,
        target: nn.Module,
        vec: torch.Tensor,
        channel_index: int,
        x: Optional[int] = None,
        y: Optional[int] = None,
    ) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.direction = vec.reshape((1, -1, 1, 1))
        self.channel_index = channel_index
        self.x = x
        self.y = y

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]

        assert activations.dim() == 4

        _x, _y = get_neuron_pos(
            activations.size(2), activations.size(3), self.x, self.y
        )
        activations = activations[:, self.channel_index, _x, _y]
        return torch.cosine_similarity(self.direction, activations[None, None, None])
class TensorDirection(Loss):
    """
    Visualize a tensor direction vector.
    Carter, et al., "Activation Atlas", Distill, 2019.
    https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
    """

    def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.direction = vec

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]

        assert activations.dim() == 4

        H_direction, W_direction = self.direction.size(2), self.direction.size(3)
        H_activ, W_activ = activations.size(2), activations.size(3)

        H = (H_activ - H_direction) // 2
        W = (W_activ - W_direction) // 2

        activations = activations[:, :, H : H + H_direction, W : W + W_direction]
        return torch.cosine_similarity(self.direction, activations)
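For TensorDirection, the direction tensor is compared against a center crop of the activations; the crop arithmetic is easy to check by hand (illustrative values, not from the PR):

# A 4x6 direction tensor is matched against the central 4x6 window of a
# 14x14 activation map.
H_activ, W_activ = 14, 14
H_direction, W_direction = 4, 6
H = (H_activ - H_direction) // 2  # 5
W = (W_activ - W_direction) // 2  # 4
# cropped window: rows 5:9, cols 4:10, same spatial size as the direction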
class ActivationWeights(Loss):
    """
    Apply weights to channels, neurons, or spots in the target.
    """

    def __init__(
        self,
        target: nn.Module,
        weights: Optional[torch.Tensor] = None,
        neuron: bool = False,
        x: Optional[int] = None,
        y: Optional[int] = None,
        wx: Optional[int] = None,
        wy: Optional[int] = None,
    ) -> None:
        super(Loss, self).__init__()
        self.target = target
        self.x = x
        self.y = y
        self.wx = wx
        self.wy = wy
        self.weights = weights
        self.neuron = x is not None or y is not None or neuron
        # either no window is given, or a full (x, y, wx, wy) window is given
        assert (wx is None and wy is None) or (
            wx is not None and wy is not None and x is not None and y is not None
        )

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        if self.neuron:
            assert activations.dim() == 4
            if self.wx is None and self.wy is None:
                _x, _y = get_neuron_pos(
                    activations.size(2), activations.size(3), self.x, self.y
                )
                activations = activations[..., _x, _y].squeeze() * self.weights
            else:
                activations = activations[
                    ..., self.y : self.y + self.wy, self.x : self.x + self.wx
                ] * self.weights.view(1, -1, 1, 1)
        else:
            activations = activations * self.weights.view(1, -1, 1, 1)
        return activations
Review comment (on TotalVariation): is this L1 norm w.r.t. the total variation of height and width?
Reply: It's the sum of the absolute differences for neighboring values in the activations or image. TensorFlow's version (which Lucid uses) links to this Wikipedia page: https://en.wikipedia.org/wiki/Total_variation_denoising
My neural-style-pt project uses basically the same algorithm, but its origins trace back to this research article: https://arxiv.org/abs/1412.0035
Reply: Chris also says that TensorFlow's total variation algorithm comes from the Understanding Deep Image Representations by Inverting Them paper: https://arxiv.org/abs/1412.0035
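To make the reviewers' point concrete, a small hand-checkable example (illustrative only): the penalty is the summed absolute differences between vertically and horizontally adjacent values.

import torch

activations = torch.tensor([[[[0.0, 1.0],
                              [2.0, 5.0]]]])  # shape (1, 1, 2, 2)
x_diff = activations[..., 1:, :] - activations[..., :-1, :]  # vertical neighbors
y_diff = activations[..., :, 1:] - activations[..., :, :-1]  # horizontal neighbors
tv = torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff))
print(tv)  # tensor(10.) = (|2| + |4|) + (|1| + |3|)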