From 575101955b0d1af4704d15579342503b6307bff1 Mon Sep 17 00:00:00 2001 From: vedrenne Date: Thu, 23 May 2024 17:14:57 +0200 Subject: [PATCH 1/4] fix __getitem__ for classes inheriting Transform3d --- pytorch3d/transforms/transform3d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index b2ee25937..444c86ce1 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -198,7 +198,9 @@ def __getitem__( """ if isinstance(index, int): index = [index] - return self.__class__(matrix=self.get_matrix()[index]) + instance = self.__class__.__new__(self.__class__) + instance._matrix = self.get_matrix()[index] + return instance def compose(self, *others: "Transform3d") -> "Transform3d": """ From f1ba05e4f1af7a980554a9db75686feaf641dd43 Mon Sep 17 00:00:00 2001 From: vedrenne Date: Tue, 28 May 2024 15:08:15 +0200 Subject: [PATCH 2/4] fix __getitem__ for classes inheriting Transform3d --- pytorch3d/transforms/transform3d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 444c86ce1..aa522538f 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -200,6 +200,8 @@ def __getitem__( index = [index] instance = self.__class__.__new__(self.__class__) instance._matrix = self.get_matrix()[index] + for attr in ('_transforms', '_lu', 'device', 'dtype'): + setattr(instance, attr, getattr(self, attr)) return instance def compose(self, *others: "Transform3d") -> "Transform3d": From ec0c7826e2a8d5140909b18d1ba0f8378e7be28e Mon Sep 17 00:00:00 2001 From: vedrenne Date: Mon, 10 Jun 2024 14:51:49 +0200 Subject: [PATCH 3/4] fix transforms indexing by implementing __getitem__ for each subclass --- pytorch3d/transforms/transform3d.py | 58 ++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index aa522538f..fef298455 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -198,11 +198,7 @@ def __getitem__( """ if isinstance(index, int): index = [index] - instance = self.__class__.__new__(self.__class__) - instance._matrix = self.get_matrix()[index] - for attr in ('_transforms', '_lu', 'device', 'dtype'): - setattr(instance, attr, getattr(self, attr)) - return instance + return self.__class__(matrix=self.get_matrix()[index]) def compose(self, *others: "Transform3d") -> "Transform3d": """ @@ -568,6 +564,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: i_matrix = self._matrix * inv_mask return i_matrix + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, 3, :3]) + class Scale(Transform3d): def __init__( @@ -617,6 +629,26 @@ def _get_matrix_inverse(self) -> torch.Tensor: imat = torch.diag_embed(ixyz, dim1=1, dim2=2) return imat + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + mat = self.get_matrix()[index] + x = mat[:, 0, 0] + y = mat[:, 1, 1] + z = mat[:, 2, 2] + return self.__class__(x, y, z) + class Rotate(Transform3d): def __init__( @@ -659,6 +691,22 @@ def _get_matrix_inverse(self) -> torch.Tensor: """ return self._matrix.permute(0, 2, 1).contiguous() + def __getitem__( + self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor] + ) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(self.get_matrix()[index, :3, :3]) + class RotateAxisAngle(Rotate): def __init__( From 23d3cb50bac9ad37475ba38096d680f8346abd5b Mon Sep 17 00:00:00 2001 From: vedrenne Date: Tue, 11 Jun 2024 09:11:32 +0200 Subject: [PATCH 4/4] add getitem tests for Transform3d subclasses --- tests/test_transforms.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5a2d729f7..6851afbf0 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -685,6 +685,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + xyz = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Translate(xyz) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Translate) + class TestScale(unittest.TestCase): def test_single_python_scalar(self): @@ -871,6 +880,15 @@ def test_inverse(self): self.assertTrue(torch.allclose(im, im_comp)) self.assertTrue(torch.allclose(im, im_2)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + s = torch.randn(size=[batch_size, 3], device=device, dtype=torch.float32) + t3d = Scale(s) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Scale) + class TestTransformBroadcast(unittest.TestCase): def test_broadcast_transform_points(self): @@ -986,6 +1004,15 @@ def test_inverse(self, batch_size=5): self.assertTrue(torch.allclose(im, im_comp, atol=1e-4)) self.assertTrue(torch.allclose(im, im_2, atol=1e-4)) + def test_get_item(self, batch_size=5): + device = torch.device("cuda:0") + r = random_rotations(batch_size, dtype=torch.float32, device=device) + t3d = Rotate(r) + index = 1 + t3d_selected = t3d[index] + self.assertEqual(len(t3d_selected), 1) + self.assertIsInstance(t3d_selected, Rotate) + class TestRotateAxisAngle(unittest.TestCase): def test_rotate_x_python_scalar(self):