diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 5928e6718c1..351430e1c9d 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -18,10 +18,12 @@
     make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
+    make_video,
     make_videos,
 )
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features, transforms
+from torchvision.prototype.transforms._utils import _isinstance
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes):
             assert transformed_value.dtype is expected_dtypes[value_type]
         else:
             assert transformed_value is value
+
+
+@pytest.mark.parametrize(
+    ("dims", "inverse_dims"),
+    [
+        (
+            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None},
+            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None},
+        ),
+        (
+            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)},
+            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)},
+        ),
+    ],
+)
+def test_permute_dimensions(dims, inverse_dims):
+    sample = dict(
+        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
+        image=make_image(),
+        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
+        video=make_video(),
+        str="str",
+        int=0,
+    )
+
+    transform = transforms.PermuteDimensions(dims)
+    transformed_sample = transform(sample)
+
+    for key, value in sample.items():
+        value_type = type(value)
+        transformed_value = transformed_sample[key]
+
+        if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
+            if transform.dims.get(value_type) is not None:
+                assert transformed_value.permute(inverse_dims[value_type]).equal(value)
+            assert type(transformed_value) == torch.Tensor
+        else:
+            assert transformed_value is value
+
+
+@pytest.mark.parametrize(
+    "dims",
+    [
+        (-1, -2),
+        {torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None},
+    ],
+)
+def test_transpose_dimensions(dims):
+    sample = dict(
+        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
+        image=make_image(),
+        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
+        video=make_video(),
+        str="str",
+        int=0,
+    )
+
+    transform = transforms.TransposeDimensions(dims)
+    transformed_sample = transform(sample)
+
+    for key, value in sample.items():
+        value_type = type(value)
+        transformed_value = transformed_sample[key]
+
+        transposed_dims = transform.dims.get(value_type)
+        if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
+            if transposed_dims is not None:
+                assert transformed_value.transpose(*transposed_dims).equal(value)
+            assert type(transformed_value) == torch.Tensor
+        else:
+            assert transformed_value is value
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 5324db63496..5bf5a12cd78 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -40,7 +40,17 @@
     TenCrop,
 )
 from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
-from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
+from ._misc import (
+    GaussianBlur,
+    Identity,
+    Lambda,
+    LinearTransformation,
+    Normalize,
+    PermuteDimensions,
+    RemoveSmallBoundingBoxes,
+    ToDtype,
+    TransposeDimensions,
+)
 from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
 
 from ._deprecated import Grayscale, RandomGrayscale, ToTensor  # usort: skip
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index bf7af5c26c7..aad684bf1a8 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -1,6 +1,4 @@
-import functools
-from collections import defaultdict
-from typing import Any, Callable, Dict, List, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 import PIL.Image
 
@@ -9,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 
-from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
+from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box
 
 
 class Identity(Transform):
@@ -145,15 +143,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class ToDtype(Transform):
     _transformed_types = (torch.Tensor,)
 
-    def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
-        return dtype
-
-    def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
+    def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
         super().__init__()
         if not isinstance(dtype, dict):
-            # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
-            # If it were possible, we could replace this with `defaultdict(lambda: dtype)`
-            dtype = defaultdict(functools.partial(self._default_dtype, dtype))
+            dtype = _get_defaultdict(dtype)
         self.dtype = dtype
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -163,6 +156,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt.to(dtype=dtype)
 
 
+class PermuteDimensions(Transform):
+    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+
+    def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
+        super().__init__()
+        if not isinstance(dims, dict):
+            dims = _get_defaultdict(dims)
+        self.dims = dims
+
+    def _transform(
+        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+    ) -> torch.Tensor:
+        dims = self.dims[type(inpt)]
+        if dims is None:
+            return inpt.as_subclass(torch.Tensor)
+        return inpt.permute(*dims)
+
+
+class TransposeDimensions(Transform):
+    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
+
+    def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
+        super().__init__()
+        if not isinstance(dims, dict):
+            dims = _get_defaultdict(dims)
+        self.dims = dims
+
+    def _transform(
+        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+    ) -> torch.Tensor:
+        dims = self.dims[type(inpt)]
+        if dims is None:
+            return inpt.as_subclass(torch.Tensor)
+        return inpt.transpose(*dims)
+
+
 class RemoveSmallBoundingBoxes(Transform):
     _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index b3e241d166b..cff439b8872 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -1,7 +1,7 @@
 import functools
 import numbers
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
 
 import PIL.Image
 
@@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
         raise TypeError("Got inappropriate fill arg")
 
 
-def _default_fill(fill: FillType) -> FillType:
-    return fill
+T = TypeVar("T")
+
+
+def _default_arg(value: T) -> T:
+    return value
+
+
+def _get_defaultdict(default: T) -> Dict[Any, T]:
+    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
+    # If it were possible, we could replace this with `defaultdict(lambda: default)`
+    return defaultdict(functools.partial(_default_arg, default))
 
 
 def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
@@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
     if isinstance(fill, dict):
         return fill
 
-    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
-    # If it were possible, we could replace this with `defaultdict(lambda: fill)`
-    return defaultdict(functools.partial(_default_fill, fill))
+    return _get_defaultdict(fill)
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
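
Not part of the patch, but a minimal usage sketch of the new PermuteDimensions and TransposeDimensions transforms defined in _misc.py above. It assumes the prototype features.Image and features.Video constructors accept a raw tensor; the per-type dict mirrors the parametrization used in the tests.

import torch
from torchvision.prototype import features, transforms

sample = {
    "image": features.Image(torch.rand(3, 32, 32)),     # CHW
    "video": features.Video(torch.rand(8, 3, 32, 32)),  # TCHW
    "plain_tensor": torch.rand(3, 32, 32),
}

# Per-type dims: plain tensors and images go CHW -> HWC, videos are left
# untouched (dims=None) but are still unwrapped to a plain torch.Tensor via
# the as_subclass(torch.Tensor) branch in _transform.
permute = transforms.PermuteDimensions(
    {torch.Tensor: (1, 2, 0), features.Image: (1, 2, 0), features.Video: None}
)
out = permute(sample)
# out["image"] now has shape (32, 32, 3) and, as the tests above assert, is a
# plain torch.Tensor rather than a features.Image.

# Passing a single tuple instead of a dict applies it to every supported type
# through the defaultdict produced by _get_defaultdict.
transpose = transforms.TransposeDimensions((-1, -2))
out = transpose(sample)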
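
A side note on _get_defaultdict: building the defaultdict around functools.partial of a module-level function (instead of a lambda) keeps transforms constructed from a single non-dict value picklable, which matters e.g. when a DataLoader with spawned workers has to serialize its transforms. A quick check of that property, assuming the transform carries no other unpicklable state:

import pickle
from torchvision.prototype import transforms

t = transforms.TransposeDimensions((-1, -2))
restored = pickle.loads(pickle.dumps(t))
# The defaultdict hands back the single configured value for any type key.
assert restored.dims[object] == (-1, -2)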