Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
make_masks,
make_one_hot_labels,
make_segmentation_mask,
make_video,
make_videos,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
Expand Down Expand Up @@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value


@pytest.mark.parametrize(
    ("dims", "inverse_dims"),
    [
        (
            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None},
            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None},
        ),
        (
            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)},
            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)},
        ),
    ],
)
def test_permute_dimensions(dims, inverse_dims):
    """PermuteDimensions should permute tensor-like inputs (plain tensors, images,
    videos) per their configured permutation and pass every other value through
    untouched. Applying the known inverse permutation must recover the input.
    """
    sample = dict(
        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
        image=make_image(),
        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
        video=make_video(),
        str="str",
        int=0,
    )

    transform = transforms.PermuteDimensions(dims)
    output = transform(sample)

    for name, original in sample.items():
        result = output[name]

        # Values the transform does not handle must come back as the same object.
        if not _isinstance(original, (features.Image, features.is_simple_tensor, features.Video)):
            assert result is original
            continue

        if transform.dims.get(type(original)) is not None:
            # Undoing the permutation must reproduce the original data.
            assert result.permute(inverse_dims[type(original)]).equal(original)
        # Handled inputs are always unwrapped to plain tensors.
        assert type(result) == torch.Tensor


@pytest.mark.parametrize(
    "dims",
    [
        (-1, -2),
        {torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None},
    ],
)
def test_transpose_dimensions(dims):
    """TransposeDimensions should swap the configured pair of dimensions of
    tensor-like inputs (plain tensors, images, videos), unwrap them to plain
    tensors, and leave every other value untouched.
    """
    sample = dict(
        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
        image=make_image(),
        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
        video=make_video(),
        str="str",
        int=0,
    )

    transform = transforms.TransposeDimensions(dims)
    transformed_sample = transform(sample)

    for key, value in sample.items():
        value_type = type(value)
        transformed_value = transformed_sample[key]

        if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
            # BUGFIX: the original read `transform.dims.get(value_type)`, but
            # `defaultdict.get` never invokes the default factory, so for the
            # plain-tuple `dims` parametrization it always returned None and the
            # transpose-equality assertion below was silently skipped.
            # `__getitem__` does trigger the factory; for the dict
            # parametrization all three handled types are explicit keys, so
            # indexing cannot raise here.
            transposed_dims = transform.dims[value_type]
            if transposed_dims is not None:
                # Transposing is its own inverse, so applying it again must
                # reproduce the original data.
                assert transformed_value.transpose(*transposed_dims).equal(value)
            # Handled inputs are always unwrapped to plain tensors.
            assert type(transformed_value) == torch.Tensor
        else:
            # Unhandled values must come back as the same object.
            assert transformed_value is value
12 changes: 11 additions & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,17 @@
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
from ._misc import (
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
ToDtype,
TransposeDimensions,
)
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
51 changes: 40 additions & 11 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image

Expand All @@ -9,7 +7,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform

from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box


class Identity(Transform):
Expand Down Expand Up @@ -145,15 +143,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)

def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype

def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
dtype = _get_defaultdict(dtype)
self.dtype = dtype

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
Expand All @@ -163,6 +156,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(dtype=dtype)


class PermuteDimensions(Transform):
    """Reorder the dimensions of images, videos, and plain tensors.

    ``dims`` is either a single permutation applied to every supported input
    type, or a mapping from input type to a permutation, where ``None``
    disables the permutation for that type. When a permutation is skipped,
    the input is still explicitly unwrapped to a plain ``torch.Tensor``.
    """

    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

    def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
        super().__init__()
        # A non-dict spec applies to every type; wrap it in a picklable defaultdict.
        self.dims = dims if isinstance(dims, dict) else _get_defaultdict(dims)

    def _transform(
        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
    ) -> torch.Tensor:
        permutation = self.dims[type(inpt)]
        if permutation is None:
            # No permutation configured for this type: just strip the feature wrapper.
            return inpt.as_subclass(torch.Tensor)
        return inpt.permute(*permutation)


class TransposeDimensions(Transform):
    """Swap two dimensions of images, videos, and plain tensors.

    ``dims`` is either a single ``(dim0, dim1)`` pair applied to every
    supported input type, or a mapping from input type to such a pair, where
    ``None`` disables the swap for that type. When the swap is skipped, the
    input is still explicitly unwrapped to a plain ``torch.Tensor``.
    """

    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

    def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
        super().__init__()
        # A non-dict spec applies to every type; wrap it in a picklable defaultdict.
        self.dims = dims if isinstance(dims, dict) else _get_defaultdict(dims)

    def _transform(
        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
    ) -> torch.Tensor:
        swap = self.dims[type(inpt)]
        if swap is None:
            # No swap configured for this type: just strip the feature wrapper.
            return inpt.as_subclass(torch.Tensor)
        return inpt.transpose(*swap)


class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)

Expand Down
19 changes: 13 additions & 6 deletions torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image

Expand Down Expand Up @@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
raise TypeError("Got inappropriate fill arg")


def _default_fill(fill: FillType) -> FillType:
return fill
T = TypeVar("T")


def _default_arg(value: T) -> T:
return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
Expand All @@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
if isinstance(fill, dict):
return fill

# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: fill)`
return defaultdict(functools.partial(_default_fill, fill))
return _get_defaultdict(fill)


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
Expand Down