Skip to content

Commit e40da13

Browse files
authored
Add tests for atlas functions & random rotation transform
1 parent 86cbfed commit e40da13

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

tests/optim/param/test_transforms.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,48 @@ def test_random_scale_matrix(self) -> None:
7272
)
7373

7474

75+
class TestRandomRotation(BaseTest):
76+
def test_random_rotation_degrees(self) -> None:
77+
test_degrees = [0.0, 1.0, 2.0, 3.0, 4.0]
78+
rot_mod = RandomRotation(test_degrees)
79+
degrees = rot_mod.degrees
80+
self.assertTrue(hasattr(degrees, "__iter__"))
81+
self.assertEqual(degrees, test_degrees)
82+
83+
def test_random_rotation_matrix(self) -> None:
84+
theta = 25.1
85+
theta = theta * 3.141592653589793 / 180
86+
rot_mod = RandomRotation([25.1])
87+
rot_matrix = rot_mod.get_rot_mat(
88+
theta, device=torch.device("cpu"), dtype=torch.float32
89+
)
90+
expected_matrix = torch.tensor(
91+
[[0.9056, -0.4242, 0.0000], [0.4242, 0.9056, 0.0000]]
92+
)
93+
94+
assertTensorAlmostEqual(self, rot_matrix, expected_matrix)
95+
96+
def test_random_rotation_rotate_tensor(self) -> None:
97+
rot_mod = RandomRotation([25.0])
98+
99+
test_input = torch.eye(4, 4).repeat(3, 1, 1).unsqueeze(0)
100+
test_output = rot_mod.rotate_tensor(test_input, 25.0)
101+
102+
expected_output = (
103+
torch.tensor(
104+
[
105+
[0.1143, 0.0000, 0.0000, 0.0000],
106+
[0.5258, 0.6198, 0.2157, 0.0000],
107+
[0.0000, 0.2157, 0.6198, 0.5258],
108+
[0.0000, 0.0000, 0.0000, 0.1143],
109+
]
110+
)
111+
.repeat(3, 1, 1)
112+
.unsqueeze(0)
113+
)
114+
assertTensorAlmostEqual(self, test_output, expected_output)
115+
116+
75117
class TestRandomSpatialJitter(BaseTest):
76118
def test_random_spatial_jitter_hw(self) -> None:
77119
translate_vals = [4, 4]

tests/optim/utils/test_atlas.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python3
2+
import unittest
3+
4+
import torch
5+
6+
import captum.optim._utils.atlas as atlas
7+
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
8+
9+
10+
class TestNormalizeGrid(BaseTest):
11+
def test_normalize_grid(self) -> None:
12+
x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float()
13+
14+
x_out = atlas.normalize_grid(x)
15+
16+
x_expected = torch.tensor(
17+
[
18+
[0.0000, 0.0000],
19+
[0.1250, 0.1250],
20+
[0.2500, 0.2500],
21+
[0.3750, 0.3750],
22+
[0.5000, 0.5000],
23+
[0.6250, 0.6250],
24+
[0.7500, 0.7500],
25+
[0.8750, 0.8750],
26+
[1.0000, 1.0000],
27+
]
28+
)
29+
30+
assertTensorAlmostEqual(self, x_out, x_expected)
31+
32+
33+
class TestGridIndices(BaseTest):
34+
def test_grid_indices(self) -> None:
35+
x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float()
36+
x = atlas.normalize_grid(x)
37+
x_indices = atlas.grid_indices(x, size=(2, 2))
38+
39+
expected_indices = [
40+
[torch.tensor([0, 1, 2, 3, 4]), torch.tensor([4])],
41+
[torch.tensor([4]), torch.tensor([4, 5, 6, 7, 8])],
42+
]
43+
44+
for list1, list2 in zip(x_indices, expected_indices):
45+
for t1, t2 in zip(list1, list2):
46+
assertTensorAlmostEqual(self, t1, t2)
47+
48+
49+
class TestExtractGridVectors(BaseTest):
50+
def test_extract_grid_vectors(self) -> None:
51+
x_raw = torch.arange(0, 4 * 3 * 3).view(3 * 3, 4).float()
52+
x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float()
53+
x = atlas.normalize_grid(x)
54+
x_indices = atlas.grid_indices(x, size=(2, 2))
55+
56+
x_vecs, vec_coords = atlas.extract_grid_vectors(
57+
x_indices, x_raw, size=(2, 2), min_density=2
58+
)
59+
60+
expected_vecs = torch.tensor([[8.0, 9.0, 10.0, 11.0], [24.0, 25.0, 26.0, 27.0]])
61+
expected_coords = [(0, 0), (1, 1)]
62+
63+
assertTensorAlmostEqual(self, x_vecs, expected_vecs)
64+
self.assertEqual(vec_coords, expected_coords)
65+
66+
67+
class TestCreateAtlasVectors(BaseTest):
68+
def test_create_atlas_vectors(self) -> None:
69+
x_raw = torch.arange(0, 4 * 3 * 3).view(3 * 3, 4).float()
70+
x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float()
71+
x_vecs, vec_coords = atlas.create_atlas_vectors(
72+
x, x_raw, size=(2, 2), min_density=2, normalize=True
73+
)
74+
75+
expected_vecs = torch.tensor([[8.0, 9.0, 10.0, 11.0], [24.0, 25.0, 26.0, 27.0]])
76+
expected_coords = [(0, 0), (1, 1)]
77+
78+
assertTensorAlmostEqual(self, x_vecs, expected_vecs)
79+
self.assertEqual(vec_coords, expected_coords)
80+
81+
82+
class TestCreateAtlas(BaseTest):
83+
def test_create_atlas(self) -> None:
84+
img_list = [torch.ones(1, 3, 4, 4)] * 2
85+
expected_coords = [(0, 0), (1, 1)]
86+
canvas = atlas.create_atlas(img_list, expected_coords, grid_size=(2, 2))
87+
assertTensorAlmostEqual(self, canvas, torch.ones_like(canvas))
88+
89+
90+
if __name__ == "__main__":
91+
unittest.main()

0 commit comments

Comments
 (0)