
Commit 740fcde

Add Activation Atlas tutorial & functions
1 parent d09a953 commit 740fcde

File tree

6 files changed: +669 −1 lines changed

captum/optim/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -6,6 +6,7 @@
 from captum.optim._param.image import images  # noqa: F401
 from captum.optim._param.image import transform  # noqa: F401
 from captum.optim._param.image.images import ImageTensor  # noqa: F401
-from captum.optim._utils import circuits, models, reducer  # noqa: F401
+from captum.optim._utils import atlas, circuits, models, reducer  # noqa: F401
+from captum.optim._utils.image import dataset  # noqa: F401
 from captum.optim._utils.image.common import nchannels_to_rgb  # noqa: F401
 from captum.optim._utils.image.common import weights_to_heatmap_2d  # noqa: F401
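
For orientation, a minimal sketch of the new import surface this hunk exposes, assuming this commit is installed and that no other name in the package shadows the new modules:

import captum.optim as opt

# `atlas` is re-exported from captum.optim._utils.atlas and `dataset` from
# captum.optim._utils.image.dataset, per the two added imports above.
print(opt.atlas.create_atlas_vectors)
print(opt.dataset.capture_activation_samples)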

captum/optim/_param/image/transform.py

Lines changed: 47 additions & 0 deletions

@@ -267,6 +267,53 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return self.translate_tensor(input, insets)


+class RandomRotation(nn.Module):
+    """
+    Apply random rotation transforms on a NCHW tensor.
+    Arguments:
+        degrees (float, sequence): Tuple of degrees to randomly select from.
+    """
+
+    def __init__(
+        self, degrees: Union[List[float], Tuple[float, ...], torch.Tensor]
+    ) -> None:
+        super().__init__()
+        assert hasattr(degrees, "__iter__")
+        self.degrees = degrees
+
+    def get_rot_mat(
+        self,
+        theta: Union[int, float, torch.Tensor],
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        theta = torch.tensor(theta, device=device, dtype=dtype)
+        rot_mat = torch.tensor(
+            [
+                [torch.cos(theta), -torch.sin(theta), 0],
+                [torch.sin(theta), torch.cos(theta), 0],
+            ],
+            device=device,
+            dtype=dtype,
+        )
+        return rot_mat
+
+    def rotate_tensor(
+        self, x: torch.Tensor, theta: Union[int, float, torch.Tensor]
+    ) -> torch.Tensor:
+        theta = theta * 3.141592653589793 / 180
+        rot_matrix = self.get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat(
+            x.shape[0], 1, 1
+        )
+        grid = F.affine_grid(rot_matrix, x.size())
+        x = F.grid_sample(x, grid)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        rotate_angle = rand_select(self.degrees)
+        return self.rotate_tensor(x, rotate_angle)
+
+
 class ScaleInputRange(nn.Module):
     """
     Multiplies the input by a specified multiplier for models with input ranges other
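
A short usage sketch of the new transform, assuming the module's existing top-of-file imports of nn, F, Union, List, Tuple, and rand_select, which this hunk relies on but does not show:

import torch
from captum.optim._param.image.transform import RandomRotation

# Rotate each NCHW batch by an angle drawn at random from the given degrees.
rotate = RandomRotation(list(range(-10, 11)))
x = torch.randn(4, 3, 224, 224)
out = rotate(x)
print(out.shape)  # torch.Size([4, 3, 224, 224]); shape is preserved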

captum/optim/_utils/atlas.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
+from typing import List, Tuple
+
+import torch
+
+
+def grid_indices(
+    tensor: torch.Tensor,
+    size: Tuple[int, int] = (8, 8),
+    x_extent: Tuple[float, float] = (0.0, 1.0),
+    y_extent: Tuple[float, float] = (0.0, 1.0),
+) -> List[List[torch.Tensor]]:
+    """
+    Create grid cells of a specified size for an irregular grid.
+    """
+
+    assert tensor.dim() == 2 and tensor.size(1) == 2
+    x_coords = ((tensor[:, 0] - x_extent[0]) / (x_extent[1] - x_extent[0])) * size[1]
+    y_coords = ((tensor[:, 1] - y_extent[0]) / (y_extent[1] - y_extent[0])) * size[0]
+
+    x_list = []
+    for x in range(size[1]):
+        y_list = []
+        for y in range(size[0]):
+            in_bounds_x = torch.logical_and(x <= x_coords, x_coords <= x + 1)
+            in_bounds_y = torch.logical_and(y <= y_coords, y_coords <= y + 1)
+            in_bounds_indices = torch.where(
+                torch.logical_and(in_bounds_x, in_bounds_y)
+            )[0]
+            y_list.append(in_bounds_indices)
+        x_list.append(y_list)
+    return x_list
+
+
+def normalize_grid(
+    x: torch.Tensor,
+    min_percentile: float = 0.01,
+    max_percentile: float = 0.99,
+    relative_margin: float = 0.1,
+) -> torch.Tensor:
+    """
+    Remove outliers and rescale grid to [0,1].
+    """
+
+    assert x.dim() == 2 and x.size(1) == 2
+    mins = torch.quantile(x, min_percentile, dim=0)
+    maxs = torch.quantile(x, max_percentile, dim=0)
+
+    # add margins
+    mins = mins - relative_margin * (maxs - mins)
+    maxs = maxs + relative_margin * (maxs - mins)
+
+    clipped = torch.max(torch.min(x, maxs), mins)
+    clipped = clipped - clipped.min(0)[0]
+    return clipped / clipped.max(0)[0]
+
+
+def extract_grid_vectors(
+    grid: List[List[torch.Tensor]],
+    activations: torch.Tensor,
+    size: Tuple[int, int] = (8, 8),
+    min_density: int = 8,
+) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
+    """
+    Create direction vectors.
+    """
+
+    cell_coords = []
+    average_activations = []
+    for x in range(size[1]):
+        for y in range(size[0]):
+            indices = grid[x][y]
+            if len(indices) >= min_density:
+                average_activations.append(torch.mean(activations[indices], 0))
+                cell_coords.append((x, y))
+    return torch.stack(average_activations), cell_coords
+
+
+def create_atlas_vectors(
+    tensor: torch.Tensor,
+    activations: torch.Tensor,
+    size: Tuple[int, int] = (8, 8),
+    min_density: int = 8,
+    normalize: bool = True,
+) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
+    """
+    Create direction vectors by splitting an irregular grid into cells.
+    """
+
+    assert tensor.dim() == 2 and tensor.size(1) == 2
+    if normalize:
+        tensor = normalize_grid(tensor)
+    indices = grid_indices(tensor, size)
+    grid_vecs, vec_coords = extract_grid_vectors(
+        indices, activations, size, min_density
+    )
+    return grid_vecs, vec_coords
+
+
+def create_atlas(
+    cells: List[torch.Tensor],
+    coords: List[List[torch.Tensor]],
+    grid_size: Tuple[int, int] = (8, 8),
+) -> torch.Tensor:
+    cell_h, cell_w = cells[0].shape[2:]
+    canvas = torch.ones(1, 3, cell_h * grid_size[0], cell_w * grid_size[1])
+    for i, img in enumerate(cells):
+        y = int(coords[i][0])
+        x = int(coords[i][1])
+        canvas[
+            ...,
+            (grid_size[0] - x - 1) * cell_h : (grid_size[0] - x) * cell_h,
+            y * cell_w : (y + 1) * cell_w,
+        ] = img
+    return canvas
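
A rough end-to-end sketch of how these helpers compose: 2D layout coordinates (in practice a UMAP or t-SNE projection of activation samples) are normalized, bucketed into grid cells, and averaged into per-cell direction vectors. The random tensors below are stand-ins for real data, and the printed shapes are illustrative:

import torch
from captum.optim._utils.atlas import create_atlas_vectors, create_atlas

# Stand-in data: 1000 activation samples of 512 channels, each with a 2D layout coordinate.
activations = torch.randn(1000, 512)
layout = torch.rand(1000, 2)

# Average the activations falling into each cell of an 8x8 grid, keeping
# only cells that contain at least 8 samples.
grid_vecs, cell_coords = create_atlas_vectors(
    layout, activations, size=(8, 8), min_density=8, normalize=True
)
print(grid_vecs.shape)   # (num_kept_cells, 512)
print(cell_coords[:3])   # e.g. [(0, 0), (0, 1), (0, 2)]

# After rendering one image per direction vector (not shown here), the rendered
# NCHW images can be pasted back onto a single canvas:
# atlas_canvas = create_atlas(rendered_images, cell_coords, grid_size=(8, 8))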

captum/optim/_utils/image/dataset.py

Lines changed: 69 additions & 0 deletions

@@ -1,5 +1,10 @@
+from typing import Dict, List, Optional
+
 import torch

+from captum.optim._utils.models import collect_activations
+from captum.optim._utils.typing import ModuleOutputMapping
+

 def image_cov(tensor: torch.Tensor) -> torch.Tensor:
     """
@@ -51,3 +56,67 @@ def dataset_klt_matrix(

     cov_mtx = dataset_cov_matrix(loader)
     return cov_matrix_to_klt(cov_mtx, normalize)
+
+
+def capture_activation_samples(
+    loader: torch.utils.data.DataLoader,
+    model,
+    targets: List[torch.nn.Module],
+    target_names: List[str],
+    num_samples: Optional[int] = None,
+    input_device: torch.device = torch.device("cpu"),
+) -> ModuleOutputMapping:
+    """
+    Create a dict of randomly sampled activations for an image dataset.
+
+    Args:
+        loader (torch.utils.data.DataLoader): A torch.utils.data.DataLoader
+            instance for an image dataset.
+        model (nn.Module): A PyTorch model instance.
+        targets (list of nn.Module): A list of layers to sample activations
+            from.
+        target_names (list of str): A list of names to use for the layers
+            to targets in the output dict.
+        num_samples (int): How many samples to collect. Default is to collect
+            all samples.
+        input_device (torch.device): The device to use for model inputs.
+    Returns:
+        activation_dict (dict of tensor): A dictionary containing the sampled
+            dataset activations, with the target_names as the keys.
+    """
+
+    def random_sample(activations: torch.Tensor) -> torch.Tensor:
+        """
+        Randomly sample H & W dimensions of activations with 4 dimensions.
+        """
+
+        rnd_samples = []
+        for b in range(activations.size(0)):
+            if activations.dim() == 4:
+                h, w = activations.shape[2:]
+                y = torch.randint(low=1, high=h, size=[1])
+                x = torch.randint(low=1, high=w, size=[1])
+                activ = activations[b, :, y, x]
+            elif activations.dim() == 2:
+                activ = activations[b].unsqueeze(1)
+            rnd_samples.append(activ)
+        return torch.cat(rnd_samples, 1).permute(1, 0)
+
+    assert len(target_names) == len(targets)
+    activation_dict: Dict = {k: [] for k in dict.fromkeys(target_names).keys()}
+
+    sample_count = 0
+    with torch.no_grad():
+        for inputs, _ in loader:
+            inputs = inputs.to(input_device)
+            target_activ_dict = collect_activations(model, targets, inputs)
+            for t in target_activ_dict.keys():
+                target_activ_dict[t] = [random_sample(target_activ_dict[t])]
+            activation_dict = {
+                k: activation_dict[k] + target_activ_dict[k] for k in activation_dict
+            }
+            sample_count += inputs.size(0)
+            if num_samples is not None:
+                if sample_count > num_samples:
+                    return {k: torch.cat(activation_dict[k]) for k in activation_dict}
+    return {k: torch.cat(activation_dict[k]) for k in activation_dict}
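
A rough usage sketch mirroring the new unit test below, with a plain TensorDataset standing in for the test helper dataset; the layer choice and output shapes follow the test rather than being guaranteed here:

import torch
from torch.utils.data import DataLoader, TensorDataset
from captum.optim._models.inception_v1 import googlenet
from captum.optim._utils.image.dataset import capture_activation_samples

# Stand-in data: 32 random 224x224 RGB images with dummy labels.
images = torch.rand(32, 3, 224, 224)
labels = torch.zeros(32, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=8, shuffle=False)

model = googlenet(pretrained=True).eval()

# One spatial position per image is sampled from the chosen layer's NCHW activations.
samples = capture_activation_samples(
    loader, model, targets=[model.mixed4c], target_names=["mixed4c"]
)
print(samples["mixed4c"].shape)  # expected per the test below: (32, 512)

A 2D layout of these samples (e.g. from UMAP or t-SNE) can then be passed, together with the samples themselves, to create_atlas_vectors from captum/optim/_utils/atlas.py above.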

tests/optim/utils/image/dataset.py

Lines changed: 24 additions & 0 deletions

@@ -4,6 +4,7 @@
 import torch

 import captum.optim._utils.image.dataset as dataset_utils
+from captum.optim._models.inception_v1 import googlenet
 from tests.helpers.basic import (
     BaseTest,
     assertArraysAlmostEqual,
@@ -105,5 +106,28 @@ def create_tensor() -> torch.Tensor:
         assertTensorAlmostEqual(self, klt_transform, expected_mtx)


+class TestCaptureActivationSamples(BaseTest):
+    def test_capture_activation_samples(self) -> None:
+        if torch.__version__ == "1.2.0":
+            raise unittest.SkipTest(
+                "Skipping capture_activation_samples test due to"
+                + "insufficient Torch version."
+            )
+
+        num_tensors = 10
+        dataset_tensors = [torch.ones(3, 224, 224) for x in range(num_tensors)]
+        test_dataset = dataset_helpers.ImageTestDataset(dataset_tensors)
+        dataset_loader = torch.utils.data.DataLoader(
+            test_dataset, batch_size=10, num_workers=0, shuffle=False
+        )
+        model = googlenet(pretrained=True)
+        targets = [model.mixed4c]
+        target_names = ["mixed4c"]
+        activation_dict = dataset_utils.capture_activation_samples(
+            dataset_loader, model, targets, target_names
+        )
+        self.assertEqual(list(activation_dict["mixed4c"].shape), [num_tensors, 512])
+
+
 if __name__ == "__main__":
     unittest.main()
