From 618e0e963207b9dd60191f3a65e2ebe0a91dfcb5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Oct 2022 15:09:33 +0100 Subject: [PATCH 1/3] Add PermuteDimensions and TransposeDimensions transforms --- test/test_prototype_transforms.py | 74 ++++++++++++++++++++ torchvision/prototype/transforms/__init__.py | 12 +++- torchvision/prototype/transforms/_misc.py | 52 ++++++++++++-- 3 files changed, 131 insertions(+), 7 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 5928e6718c1..9a2988c796d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -18,6 +18,7 @@ make_masks, make_one_hot_labels, make_segmentation_mask, + make_video, make_videos, ) from torchvision.ops.boxes import box_iou @@ -1826,3 +1827,76 @@ def test_to_dtype(dtype, expected_dtypes): assert transformed_value.dtype is expected_dtypes[value_type] else: assert transformed_value is value + + +@pytest.mark.parametrize( + ("dims", "inverse_dims"), + [ + ( + {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None}, + {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None}, + ), + ( + {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)}, + {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)}, + ), + ], +) +def test_permute_dimensions(dims, inverse_dims): + sample = dict( + plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), + image=make_image(), + bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) + + transform = transforms.PermuteDimensions(dims) + transformed_sample = transform(sample) + + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] + + # make sure the transformation retains the type + assert isinstance(transformed_value, value_type) + + if isinstance(value, torch.Tensor) and transform.dims.get(value_type) is not None: + assert transformed_value.permute(inverse_dims[value_type]).equal(value) + else: + assert transformed_value is value + + +@pytest.mark.parametrize( + "dims", + [ + (-1, -2), + {torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None}, + ], +) +def test_transpose_dimensions(dims): + sample = dict( + plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), + image=make_image(), + bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) + + transform = transforms.TransposeDimensions(dims) + transformed_sample = transform(sample) + + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] + + # make sure the transformation retains the type + assert isinstance(transformed_value, value_type) + + transposed_dims = transform.dims.get(value_type) + if isinstance(value, torch.Tensor) and transposed_dims is not None: + assert transformed_value.transpose(*transposed_dims).equal(value) + else: + assert transformed_value is value diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5324db63496..5bf5a12cd78 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -40,7 +40,17 @@ TenCrop, ) from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype -from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype +from ._misc import ( + GaussianBlur, + Identity, + Lambda, + LinearTransformation, + Normalize, + PermuteDimensions, + RemoveSmallBoundingBoxes, + ToDtype, + TransposeDimensions, +) from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index bf7af5c26c7..c0db9c0e276 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,6 +1,6 @@ import functools from collections import defaultdict -from typing import Any, Callable, Dict, List, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -142,18 +142,19 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.gaussian_blur(inpt, self.kernel_size, **params) +def _default_arg(value: Any) -> Any: + return value + + class ToDtype(Transform): _transformed_types = (torch.Tensor,) - def _default_dtype(self, dtype: torch.dtype) -> torch.dtype: - return dtype - - def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None: + def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: super().__init__() if not isinstance(dtype, dict): # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. # If it were possible, we could replace this with `defaultdict(lambda: dtype)` - dtype = defaultdict(functools.partial(self._default_dtype, dtype)) + dtype = defaultdict(functools.partial(_default_arg, dtype)) self.dtype = dtype def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -163,6 +164,45 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt.to(dtype=dtype) +class PermuteDimensions(Transform): + _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + + def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: + super().__init__() + if not isinstance(dims, dict): + dims = defaultdict(functools.partial(_default_arg, dims)) + self.dims = dims + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + dims = self.dims[type(inpt)] + if dims is None: + return inpt + output = inpt.permute(dims) + if isinstance(inpt, (features.Image, features.Video)): + output = inpt.wrap_like(inpt, output) + return output + + +class TransposeDimensions(Transform): + _transformed_types = (features.is_simple_tensor, features.Image, features.Video) + + def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: + super().__init__() + if not isinstance(dims, dict): + dims = defaultdict(functools.partial(_default_arg, dims)) + self.dims = dims + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + dims = self.dims[type(inpt)] + if dims is None: + return inpt + dim0, dim1 = dims + output = inpt.transpose(dim0, dim1) + if isinstance(inpt, (features.Image, features.Video)): + output = inpt.wrap_like(inpt, output) + return output + + class RemoveSmallBoundingBoxes(Transform): _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel) From 912cc44ec3ef3be9e38fe819f3d53df3f072baee Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Oct 2022 15:30:51 +0100 Subject: [PATCH 2/3] Strip Subclass info. --- test/test_prototype_transforms.py | 19 +++++++++---------- torchvision/prototype/transforms/_misc.py | 23 ++++++++++------------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9a2988c796d..351430e1c9d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -23,6 +23,7 @@ ) from torchvision.ops.boxes import box_iou from torchvision.prototype import features, transforms +from torchvision.prototype.transforms._utils import _isinstance from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -1859,11 +1860,10 @@ def test_permute_dimensions(dims, inverse_dims): value_type = type(value) transformed_value = transformed_sample[key] - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) - - if isinstance(value, torch.Tensor) and transform.dims.get(value_type) is not None: - assert transformed_value.permute(inverse_dims[value_type]).equal(value) + if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)): + if transform.dims.get(value_type) is not None: + assert transformed_value.permute(inverse_dims[value_type]).equal(value) + assert type(transformed_value) == torch.Tensor else: assert transformed_value is value @@ -1892,11 +1892,10 @@ def test_transpose_dimensions(dims): value_type = type(value) transformed_value = transformed_sample[key] - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) - transposed_dims = transform.dims.get(value_type) - if isinstance(value, torch.Tensor) and transposed_dims is not None: - assert transformed_value.transpose(*transposed_dims).equal(value) + if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)): + if transposed_dims is not None: + assert transformed_value.transpose(*transposed_dims).equal(value) + assert type(transformed_value) == torch.Tensor else: assert transformed_value is value diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index c0db9c0e276..a5a6a511229 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -173,14 +173,13 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] dims = defaultdict(functools.partial(_default_arg, dims)) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def _transform( + self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] + ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt - output = inpt.permute(dims) - if isinstance(inpt, (features.Image, features.Video)): - output = inpt.wrap_like(inpt, output) - return output + return inpt.as_subclass(torch.Tensor) + return inpt.permute(*dims) class TransposeDimensions(Transform): @@ -192,15 +191,13 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i dims = defaultdict(functools.partial(_default_arg, dims)) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def _transform( + self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any] + ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt - dim0, dim1 = dims - output = inpt.transpose(dim0, dim1) - if isinstance(inpt, (features.Image, features.Video)): - output = inpt.wrap_like(inpt, output) - return output + return inpt.as_subclass(torch.Tensor) + return inpt.transpose(*dims) class RemoveSmallBoundingBoxes(Transform): From a4b5ddbb0394dd44562c87cc75932d79807fa659 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Oct 2022 08:48:49 +0100 Subject: [PATCH 3/3] Apply changes from code review. --- torchvision/prototype/transforms/_misc.py | 16 ++++------------ torchvision/prototype/transforms/_utils.py | 19 +++++++++++++------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index a5a6a511229..aad684bf1a8 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,5 +1,3 @@ -import functools -from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -9,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box +from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box class Identity(Transform): @@ -142,19 +140,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.gaussian_blur(inpt, self.kernel_size, **params) -def _default_arg(value: Any) -> Any: - return value - - class ToDtype(Transform): _transformed_types = (torch.Tensor,) def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None: super().__init__() if not isinstance(dtype, dict): - # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. - # If it were possible, we could replace this with `defaultdict(lambda: dtype)` - dtype = defaultdict(functools.partial(_default_arg, dtype)) + dtype = _get_defaultdict(dtype) self.dtype = dtype def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -170,7 +162,7 @@ class PermuteDimensions(Transform): def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: super().__init__() if not isinstance(dims, dict): - dims = defaultdict(functools.partial(_default_arg, dims)) + dims = _get_defaultdict(dims) self.dims = dims def _transform( @@ -188,7 +180,7 @@ class TransposeDimensions(Transform): def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: super().__init__() if not isinstance(dims, dict): - dims = defaultdict(functools.partial(_default_arg, dims)) + dims = _get_defaultdict(dims) self.dims = dims def _transform( diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index b3e241d166b..cff439b8872 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,7 +1,7 @@ import functools import numbers from collections import defaultdict -from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union import PIL.Image @@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: raise TypeError("Got inappropriate fill arg") -def _default_fill(fill: FillType) -> FillType: - return fill +T = TypeVar("T") + + +def _default_arg(value: T) -> T: + return value + + +def _get_defaultdict(default: T) -> Dict[Any, T]: + # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. + # If it were possible, we could replace this with `defaultdict(lambda: default)` + return defaultdict(functools.partial(_default_arg, default)) def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: @@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F if isinstance(fill, dict): return fill - # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. - # If it were possible, we could replace this with `defaultdict(lambda: fill)` - return defaultdict(functools.partial(_default_fill, fill)) + return _get_defaultdict(fill) def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: