Merged
64 commits
2b644af
Remove loss objectives
ProGamerGov Nov 13, 2020
2fe480c
Loss objectives moved to their own file
ProGamerGov Nov 13, 2020
fd19f3d
Fix some class names
ProGamerGov Nov 13, 2020
1471942
Update tutorials for loss changes
ProGamerGov Nov 13, 2020
ed40559
Add loss path
ProGamerGov Nov 13, 2020
7c28ba1
Oops
ProGamerGov Nov 13, 2020
f9af2bf
Merge pull request #11 from ProGamerGov/loss-class
ProGamerGov Nov 13, 2020
df979a1
Add function type hints
ProGamerGov Nov 14, 2020
dbeb667
Add init param to NaturalImage & other changes
ProGamerGov Nov 14, 2020
4f54d5a
Remove cuda call from LaplacianImage
ProGamerGov Nov 14, 2020
abdf2ba
Non abstract init
ProGamerGov Nov 14, 2020
e25b61e
Non abstract __init__ func
ProGamerGov Nov 14, 2020
de92824
Size var fixes
ProGamerGov Nov 14, 2020
5edd341
Fix error
ProGamerGov Nov 14, 2020
6a36378
Hopefully fix LaplacianImage
ProGamerGov Nov 14, 2020
4e8f079
Merge pull request #12 from ProGamerGov/init-param
ProGamerGov Nov 14, 2020
8012d7b
Linting
ProGamerGov Nov 14, 2020
1041392
Get user int images working
ProGamerGov Nov 15, 2020
540d3a1
Update images.py
ProGamerGov Nov 15, 2020
c99bdd6
Add init torch.inverse() func
ProGamerGov Nov 15, 2020
864ca58
Remove old decorrelate init func
ProGamerGov Nov 15, 2020
cbe6545
Fix func call
ProGamerGov Nov 15, 2020
804609c
Re-add H & W names
ProGamerGov Nov 15, 2020
2de31a7
Fix decorrelate init
ProGamerGov Nov 15, 2020
21fa5bf
Add type hints
ProGamerGov Nov 15, 2020
4428fd3
Merge pull request #13 from ProGamerGov/init-image
ProGamerGov Nov 15, 2020
d172187
Fix LaplacianImage init size issue
ProGamerGov Nov 15, 2020
aeefa27
Lint fix
ProGamerGov Nov 15, 2020
d6cbe49
Add functional import
ProGamerGov Nov 15, 2020
9ba374a
Add Direction objectives (#14)
ProGamerGov Nov 16, 2020
3c0697a
Comment clarification
ProGamerGov Nov 16, 2020
8d8a9a0
Comment fix
ProGamerGov Nov 16, 2020
f778b2e
Add Activation Interpolation
ProGamerGov Nov 16, 2020
8c0f376
Pip in test failed & lint fix
ProGamerGov Nov 16, 2020
1cf9bdb
Add layer objective
ProGamerGov Nov 16, 2020
5043673
Comment change
ProGamerGov Nov 17, 2020
eeb3939
Better func params
ProGamerGov Nov 17, 2020
ff101ef
Fix major bug
ProGamerGov Nov 17, 2020
b216517
Better error handling
ProGamerGov Nov 17, 2020
f38dc86
Add Alignment Objective
ProGamerGov Nov 17, 2020
c869dab
Add weight objective and dim checks
ProGamerGov Nov 18, 2020
818c57c
Lint fix
ProGamerGov Nov 18, 2020
3035a20
Add Batch Sizes To Image Parameterization
ProGamerGov Nov 19, 2020
4b0df62
Add missing type hint to batch setup func
ProGamerGov Nov 20, 2020
1af18d9
Add self to size values
ProGamerGov Nov 20, 2020
c246313
Remove redundant float()
ProGamerGov Nov 20, 2020
78319dc
Bug fix & ToRGB improvements
ProGamerGov Nov 20, 2020
66e5742
Full batch support for LaplacianImage
ProGamerGov Nov 20, 2020
759a27b
Properly handle LaplacianImage init batch
ProGamerGov Nov 20, 2020
67a948d
Type hints & new NaturalImage params
ProGamerGov Nov 20, 2020
184f21a
Add type hints to init & forward functions
ProGamerGov Nov 20, 2020
de0baeb
Lint fix
ProGamerGov Nov 20, 2020
db5a04e
More type hints
ProGamerGov Nov 21, 2020
159cf9f
Add more type hints
ProGamerGov Nov 21, 2020
807ea81
Merge conflict fixes
ProGamerGov Nov 21, 2020
116bce1
Fix mistake
ProGamerGov Nov 21, 2020
d2e35f8
Add missing type hint
ProGamerGov Nov 22, 2020
74b4b01
Add missing type hints
ProGamerGov Nov 22, 2020
b3cffe1
Add size func to ImageTensor
ProGamerGov Nov 22, 2020
d49757b
Add type hints to __torch_function__ functions
ProGamerGov Nov 23, 2020
958c8b3
Change things based on feedback
ProGamerGov Nov 24, 2020
e8eaf0e
Lint fix
ProGamerGov Nov 24, 2020
36a67e7
Ensure that direction and activation channels are the same
ProGamerGov Nov 24, 2020
95dbca3
Fix TotalVariation citation
ProGamerGov Nov 24, 2020
1 change: 1 addition & 0 deletions captum/optim/__init__.py
@@ -1,5 +1,6 @@
"""optim submodule."""

from captum.optim._core import loss # noqa: F401
from captum.optim._core import objectives # noqa: F401
from captum.optim._core.objectives import InputOptimization # noqa: F401
from captum.optim._param.image import images # noqa: F401
371 changes: 371 additions & 0 deletions captum/optim/_core/loss.py
@@ -0,0 +1,371 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from captum.optim._utils.images import get_neuron_pos
from captum.optim._utils.typing import ModuleOutputMapping


class Loss(ABC):
"""
Abstract Class to describe loss.
"""

def __init__(self, target: nn.Module) -> None:
super(Loss, self).__init__()
self.target = target

@abstractmethod
def __call__(self, targets_to_values: ModuleOutputMapping):
pass


class LayerActivation(Loss):
"""
Maximize activations at the target layer.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
return targets_to_values[self.target]


class ChannelActivation(Loss):
"""
Maximize activations at the target layer and target channel.
"""

def __init__(self, target: nn.Module, channel_index: int) -> None:
super(Loss, self).__init__()
self.target = target
self.channel_index = channel_index

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
assert activations is not None
# ensure channel_index is valid
assert self.channel_index < activations.shape[1]
# assume NCHW
# NOTE: not necessarily true e.g. for Linear layers
# assert len(activations.shape) == 4
return activations[:, self.channel_index, ...]


class NeuronActivation(Loss):
def __init__(
self,
target: nn.Module,
channel_index: int,
x: Optional[int] = None,
y: Optional[int] = None,
) -> None:
super(Loss, self).__init__()
self.target = target
self.channel_index = channel_index
self.x = x
self.y = y

# ensure channel_index will be valid
assert self.channel_index < self.target.out_channels

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
assert activations is not None
assert len(activations.shape) == 4 # assume NCHW
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)

return activations[:, self.channel_index, _x, _y]


class DeepDream(Loss):
"""
Maximize 'interestingness' at the target layer.
Mordvintsev et al., 2015.
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
return activations ** 2


class TotalVariation(Loss):
"""
Total variation denoising penalty for activations.
See Mahendran & Vedaldi, 2014. Understanding Deep Image Representations by Inverting Them.
https://arxiv.org/abs/1412.0035
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
x_diff = activations[..., 1:, :] - activations[..., :-1, :]
y_diff = activations[..., :, 1:] - activations[..., :, :-1]
return torch.sum(torch.abs(x_diff)) + torch.sum(torch.abs(y_diff))
Contributor:
Is this the L1-norm w.r.t. total variation of height and width?

Contributor Author (@ProGamerGov, Nov 24, 2020):
It's the sum of the absolute differences for neighboring values in the activations or image. TensorFlow's version (which Lucid uses) links to this Wikipedia page: https://en.wikipedia.org/wiki/Total_variation_denoising

My neural-style-pt project uses basically the same algorithm, but its origins trace back to this research article: https://arxiv.org/abs/1412.0035

Contributor Author (@ProGamerGov):
Chris also says that TensorFlow's total variation algorithm comes from the Understanding Deep Image Representations by Inverting Them paper: https://arxiv.org/abs/1412.0035
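
As a minimal illustration of that sum-of-absolute-differences computation (using an arbitrary toy tensor, not anything from this PR), mirroring the TotalVariation.__call__ logic above:

import torch

# Toy NCHW activations; values chosen arbitrarily for illustration.
activations = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)

# Absolute differences between neighboring values along height and width,
# summed into a single scalar penalty (anisotropic total variation).
h_diff = activations[..., 1:, :] - activations[..., :-1, :]
w_diff = activations[..., :, 1:] - activations[..., :, :-1]
tv_penalty = torch.sum(torch.abs(h_diff)) + torch.sum(torch.abs(w_diff))
print(tv_penalty)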



class L1(Loss):
"""
L1 norm of the target layer, generally used as a penalty.
"""

def __init__(self, target: nn.Module, constant: float = 0.0) -> None:
super(Loss, self).__init__()
self.target = target
self.constant = constant

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
return torch.abs(activations - self.constant).sum()


class L2(Loss):
"""
L2 norm of the target layer, generally used as a penalty.
"""

def __init__(
self, target: nn.Module, constant: float = 0.0, epsilon: float = 1e-6
) -> None:
self.target = target
self.constant = constant
self.epsilon = epsilon

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = ((activations - self.constant) ** 2).sum()
return torch.sqrt(self.epsilon + activations)


class Diversity(Loss):
"""
Use a cosine similarity penalty to extract features from a polysemantic neuron.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#diversity
"""

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
return -sum(
[
sum(
[
(
torch.cosine_similarity(
activations[j].view(1, -1), activations[i].view(1, -1)
)
).sum()
for i in range(activations.size(0))
if i != j
]
)
for j in range(activations.size(0))
]
) / activations.size(0)


class ActivationInterpolation(Loss):
"""
Interpolate between two different layers & channels.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
"""

def __init__(
self,
target1: nn.Module = None,
channel_index1: int = -1,
target2: nn.Module = None,
channel_index2: int = -1,
) -> None:
super(Loss, self).__init__()
self.target_one = target1
self.channel_index_one = channel_index1
self.target_two = target2
self.channel_index_two = channel_index2

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations_one = targets_to_values[self.target_one]
activations_two = targets_to_values[self.target_two]

assert activations_one is not None and activations_two is not None
# ensure channel indices are valid
assert (
self.channel_index_one < activations_one.shape[1]
and self.channel_index_two < activations_two.shape[1]
)
assert activations_one.size(0) == activations_two.size(0)

if self.channel_index_one > -1:
activations_one = activations_one[:, self.channel_index_one]
if self.channel_index_two > -1:
activations_two = activations_two[:, self.channel_index_two]
B = activations_one.size(0)

batch_weights = torch.arange(B, device=activations_one.device) / (B - 1)
sum_tensor = torch.zeros(1, device=activations_one.device)
for n in range(B):
sum_tensor = (
sum_tensor + ((1 - batch_weights[n]) * activations_one[n]).mean()
)
sum_tensor = sum_tensor + (batch_weights[n] * activations_two[n]).mean()
return sum_tensor


class Alignment(Loss):
"""
Penalize the L2 distance between tensors in the batch to encourage visual
similarity between them.
Olah, Mordvintsev & Schubert, 2017.
https://distill.pub/2017/feature-visualization/#Interaction-between-Neurons
"""

def __init__(self, target: nn.Module, decay_ratio: float = 2.0) -> None:
super(Loss, self).__init__()
self.target = target
self.decay_ratio = decay_ratio

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
B = activations.size(0)

sum_tensor = torch.zeros(1, device=activations.device)
for d in [1, 2, 3, 4]:
for i in range(B - d):
a, b = i, i + d
activ_a, activ_b = activations[a], activations[b]
sum_tensor = sum_tensor + (
(activ_a - activ_b) ** 2
).mean() / self.decay_ratio ** float(d)

return sum_tensor
Contributor:
nit: Since we are optimizing a negative loss, do we want to return sum_tensor? I'm asking because the losses return the opposite sign for the other objectives here:
https://github.com/greentfrapp/lucent/blob/master/lucent/optvis/objectives.py

Contributor Author (@ProGamerGov):
Currently the InputOptimization class doesn't support losses like Alignment, so I've left them as they are for @greentfrapp to change, as he's currently working on the optimization system.
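
A minimal sketch of how that sign convention might be handled once such penalties are supported: the stand-in nn.Conv2d layer, the 0.1 weight, and the combined_objective wrapper below are assumptions for illustration, not part of this PR.

import torch.nn as nn

from captum.optim._core.loss import Alignment, ChannelActivation

# Stand-in layer purely for illustration; in practice this would be a layer
# from the model being visualized.
target_layer = nn.Conv2d(3, 8, kernel_size=3)

channel_obj = ChannelActivation(target_layer, channel_index=0)
alignment_penalty = Alignment(target_layer)

def combined_objective(targets_to_values):
    # Subtract the penalty so that maximizing the objective also minimizes it,
    # mirroring how Lucent negates penalty terms; the 0.1 weight is arbitrary.
    return channel_obj(targets_to_values).mean() - 0.1 * alignment_penalty(targets_to_values)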



class Direction(Loss):
"""
Visualize a general direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
"""

def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
super(Loss, self).__init__()
self.target = target
self.direction = vec.reshape((1, -1, 1, 1))

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
Contributor:
nit: I think it would be good to assert the sizes of self.direction and activations before computing cosine similarity on them.

Contributor Author (@ProGamerGov):
I added an assert for the channel size of self.direction & activations. @greentfrapp may be able to add more assertion tests in his upcoming PR.
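
For reference, a hedged usage sketch of that channel-size requirement; the stand-in layer and its 480-channel width are assumptions for illustration, not part of this PR.

import torch
import torch.nn as nn

from captum.optim._core.loss import Direction

# Stand-in layer; in practice this would be a layer from the model being
# visualized, and vec would be a meaningful direction in its channel space.
target_layer = nn.Conv2d(3, 480, kernel_size=3)
vec = torch.randn(480)  # must match the layer's channel count, or the assert below fires

direction_loss = Direction(target=target_layer, vec=vec)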

assert activations.size(1) == self.direction.size(1)
return torch.cosine_similarity(self.direction, activations)


class DirectionNeuron(Loss):
"""
Visualize a single (x, y) position for a direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
"""

def __init__(
self,
target: nn.Module,
vec: torch.Tensor,
channel_index: int,
x: Optional[int] = None,
y: Optional[int] = None,
) -> None:
super(Loss, self).__init__()
self.target = target
self.direction = vec.reshape((1, -1, 1, 1))
self.channel_index = channel_index
self.x = x
self.y = y

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]

assert activations.dim() == 4

_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
activations = activations[:, self.channel_index, _x, _y]
return torch.cosine_similarity(self.direction, activations[None, None, None])


class TensorDirection(Loss):
"""
Visualize a tensor direction vector.
Carter, et al., "Activation Atlas", Distill, 2019.
https://distill.pub/2019/activation-atlas/#Aggregating-Multiple-Images
"""

def __init__(self, target: nn.Module, vec: torch.Tensor) -> None:
super(Loss, self).__init__()
self.target = target
self.direction = vec

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]

assert activations.dim() == 4

H_direction, W_direction = self.direction.size(2), self.direction.size(3)
H_activ, W_activ = activations.size(2), activations.size(3)

H = (H_activ - H_direction) // 2
W = (W_activ - W_direction) // 2

activations = activations[:, :, H : H + H_direction, W : W + W_direction]
return torch.cosine_similarity(self.direction, activations)


class ActivationWeights(Loss):
"""
Apply weights to channels, neurons, or spots in the target.
"""

def __init__(
self,
target: nn.Module,
weights: torch.Tensor = None,
neuron: bool = False,
x: Optional[int] = None,
y: Optional[int] = None,
wx: Optional[int] = None,
wy: Optional[int] = None,
) -> None:
super(Loss, self).__init__()
self.target = target
self.x = x
self.y = y
self.wx = wx
self.wy = wy
self.weights = weights
self.neuron = x is not None or y is not None or neuron
assert (
wx is None
and wy is None
or wx is not None
and wy is not None
and x is not None
and y is not None
)

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
if self.neuron:
assert activations.dim() == 4
if self.wx is None and self.wy is None:
_x, _y = get_neuron_pos(
activations.size(2), activations.size(3), self.x, self.y
)
activations = activations[..., _x, _y].squeeze() * self.weights
else:
activations = activations[
..., self.y : self.y + self.wy, self.x : self.x + self.wx
] * self.weights.view(1, -1, 1, 1)
else:
activations = activations * self.weights.view(1, -1, 1, 1)
return activations