From 0c469ac0f40fa8046437d7715186097a32401769 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 6 Oct 2022 21:22:07 +0200 Subject: [PATCH 1/3] replace new_like with wrap_like --- test/test_prototype_features.py | 4 +- test/test_prototype_transforms.py | 17 ++-- .../prototype/features/_bounding_box.py | 57 +++++++------ torchvision/prototype/features/_encoded.py | 11 ++- torchvision/prototype/features/_feature.py | 55 +++++------- torchvision/prototype/features/_image.py | 85 ++++++++++--------- torchvision/prototype/features/_label.py | 20 +++-- torchvision/prototype/features/_mask.py | 47 +++++++--- torchvision/prototype/transforms/_augment.py | 22 ++--- .../prototype/transforms/_auto_augment.py | 2 +- torchvision/prototype/transforms/_color.py | 2 +- .../prototype/transforms/_deprecated.py | 4 +- torchvision/prototype/transforms/_geometry.py | 16 ++-- torchvision/prototype/transforms/_meta.py | 10 ++- torchvision/prototype/transforms/_misc.py | 2 +- .../transforms/functional/_augment.py | 2 +- .../transforms/functional/_geometry.py | 4 +- 17 files changed, 198 insertions(+), 162 deletions(-) diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index 2701dd66be0..d2b0d2e632c 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -99,14 +99,14 @@ def test_inplace_op_no_wrapping(): assert type(label) is features.Label -def test_new_like(): +def test_wrap_like(): tensor = torch.tensor([0, 1, 0], dtype=torch.int64) label = features.Label(tensor, categories=["foo", "bar"]) # any operation besides .to() and .clone() will do here output = label * 2 - label_new = features.Label.new_like(label, output) + label_new = features.Label.wrap_like(label, output) assert type(label_new) is features.Label assert label_new.data_ptr() == output.data_ptr() diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9734a5dc30a..784b4b8bb34 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -8,6 +8,7 @@ import torch from common_utils import assert_equal, cpu_and_gpu from prototype_common_utils import ( + DEFAULT_EXTRA_DIMS, make_bounding_box, make_bounding_boxes, make_detection_mask, @@ -22,6 +23,8 @@ from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image +BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] + def make_vanilla_tensor_images(*args, **kwargs): for image in make_images(*args, **kwargs): @@ -107,13 +110,11 @@ def test_common(self, transform, input): ( transform, [ - dict( - image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float), - one_hot_label=features.OneHotLabel.new_like( - one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float - ), + dict(image=image, one_hot_label=one_hot_label) + for image, one_hot_label in itertools.product( + make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), + make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), ) - for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels()) ], ) for transform in [ @@ -293,7 +294,7 @@ def test_features_bounding_box(self, p): actual = transform(input) expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected = features.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert 
actual.format == expected.format assert actual.image_size == expected.image_size @@ -346,7 +347,7 @@ def test_features_bounding_box(self, p): actual = transform(input) expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected = features.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 9ccd4fa62ad..7b69af5f9bb 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -19,6 +19,13 @@ class BoundingBox(_Feature): format: BoundingBoxFormat image_size: Tuple[int, int] + @classmethod + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox: + bounding_box = tensor.as_subclass(cls) + bounding_box.format = format + bounding_box.image_size = image_size + return bounding_box + def __new__( cls, data: Any, @@ -29,52 +36,46 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> BoundingBox: - bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad) + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if isinstance(format, str): format = BoundingBoxFormat.from_str(format.upper()) - bounding_box.format = format - - bounding_box.image_size = image_size - return bounding_box - - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, image_size=self.image_size) + return cls._wrap(tensor, format=format, image_size=image_size) @classmethod - def new_like( + def wrap_like( cls, other: BoundingBox, - data: Any, + tensor: torch.Tensor, *, - format: Optional[Union[BoundingBoxFormat, str]] = None, + format: Optional[BoundingBoxFormat] = None, image_size: Optional[Tuple[int, int]] = None, - **kwargs: Any, ) -> BoundingBox: - return super().new_like( - other, - data, + return cls._wrap( + tensor, format=format if format is not None else other.format, image_size=image_size if image_size is not None else other.image_size, - **kwargs, ) + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr(format=self.format, image_size=self.image_size) + def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: if isinstance(format, str): format = BoundingBoxFormat.from_str(format.upper()) - return BoundingBox.new_like( + return BoundingBox.wrap_like( self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format ) def horizontal_flip(self) -> BoundingBox: output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) - return BoundingBox.new_like(self, output) + return BoundingBox.wrap_like(self, output) def vertical_flip(self) -> BoundingBox: output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) - return BoundingBox.new_like(self, output) + return BoundingBox.wrap_like(self, output) def resize( # type: ignore[override] self, @@ -84,19 +85,19 @@ def resize( # type: ignore[override] antialias: bool = False, ) -> BoundingBox: output, image_size = self._F.resize_bounding_box(self, 
image_size=self.image_size, size=size, max_size=max_size) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: output, image_size = self._F.crop_bounding_box( self, self.format, top=top, left=left, height=height, width=width ) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def center_crop(self, output_size: List[int]) -> BoundingBox: output, image_size = self._F.center_crop_bounding_box( self, format=self.format, image_size=self.image_size, output_size=output_size ) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def resized_crop( self, @@ -109,7 +110,7 @@ def resized_crop( antialias: bool = False, ) -> BoundingBox: output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def pad( self, @@ -120,7 +121,7 @@ def pad( output, image_size = self._F.pad_bounding_box( self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode ) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def rotate( self, @@ -133,7 +134,7 @@ def rotate( output, image_size = self._F.rotate_bounding_box( self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center ) - return BoundingBox.new_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, image_size=image_size) def affine( self, @@ -155,7 +156,7 @@ def affine( shear=shear, center=center, ) - return BoundingBox.new_like(self, output, dtype=output.dtype) + return BoundingBox.wrap_like(self, output) def perspective( self, @@ -164,7 +165,7 @@ def perspective( fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) - return BoundingBox.new_like(self, output, dtype=output.dtype) + return BoundingBox.wrap_like(self, output) def elastic( self, @@ -173,4 +174,4 @@ def elastic( fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.elastic_bounding_box(self, self.format, displacement) - return BoundingBox.new_like(self, output, dtype=output.dtype) + return BoundingBox.wrap_like(self, output) diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index b8b9839600f..6a76b29bb19 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -14,6 +14,10 @@ class EncodedData(_Feature): + @classmethod + def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: + return tensor.as_subclass(cls) + def __new__( cls, data: Any, @@ -22,8 +26,13 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> EncodedData: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? 
- return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad) + return cls._wrap(tensor) + + @classmethod + def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: + return cls._wrap(tensor) @classmethod def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D: diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 2da10be90ab..a56441f2967 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -21,48 +21,39 @@ def is_simple_tensor(inpt: Any) -> bool: class _Feature(torch.Tensor): __F: Optional[ModuleType] = None - def __new__( - cls: Type[F], + @staticmethod + def _to_tensor( data: Any, - *, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, - ) -> F: - return ( - torch.as_tensor( # type: ignore[return-value] - data, - dtype=dtype, - device=device, - ) - .as_subclass(cls) - .requires_grad_(requires_grad) - ) + ) -> torch.Tensor: + return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - @classmethod - def new_like( - cls: Type[F], - other: F, + # FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a + # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be + # interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one + # public again. + def __new__( + cls, data: Any, - *, dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, - requires_grad: Optional[bool] = None, - **kwargs: Any, - ) -> F: - return cls( - data, - dtype=dtype if dtype is not None else other.dtype, - device=device if device is not None else other.device, - requires_grad=requires_grad if requires_grad is not None else other.requires_grad, - **kwargs, - ) + requires_grad: bool = False, + ) -> _Feature: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + return tensor.as_subclass(_Feature) + + @classmethod + def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F: + # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. 
If that is resolved, + # this method should be made abstract + # raise NotImplementedError + return tensor.as_subclass(cls) _NO_WRAPPING_EXCEPTIONS = { - torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output), - torch.Tensor.to: lambda cls, input, output: cls.new_like( - input, output, dtype=output.dtype, device=output.device - ), + torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), + torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus # retains the type automatically torch.Tensor.requires_grad_: lambda cls, input, output: output, diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index c953ae78c2a..23f81678d79 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -62,6 +62,12 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace: class Image(_Feature): color_space: ColorSpace + @classmethod + def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image: + image = tensor.as_subclass(cls) + image.color_space = color_space + return image + def __new__( cls, data: Any, @@ -71,36 +77,33 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> Image: - data = torch.as_tensor(data, dtype=dtype, device=device) - if data.ndim < 2: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + if tensor.ndim < 2: raise ValueError - elif data.ndim == 2: - data = data.unsqueeze(0) - image = super().__new__(cls, data, requires_grad=requires_grad) + elif tensor.ndim == 2: + tensor = tensor.unsqueeze(0) if color_space is None: - color_space = ColorSpace.from_tensor_shape(image.shape) # type: ignore[arg-type] + color_space = ColorSpace.from_tensor_shape(tensor.shape) # type: ignore[arg-type] if color_space == ColorSpace.OTHER: warnings.warn("Unable to guess a specific color space. 
Consider passing it explicitly.") elif isinstance(color_space, str): color_space = ColorSpace.from_str(color_space.upper()) elif not isinstance(color_space, ColorSpace): raise ValueError - image.color_space = color_space - return image - - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(color_space=self.color_space) + return cls._wrap(tensor, color_space=color_space) @classmethod - def new_like( - cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any - ) -> Image: - return super().new_like( - other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs + def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image: + return cls._wrap( + tensor, + color_space=color_space if color_space is not None else other.color_space, ) + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr(color_space=self.color_space) + @property def image_size(self) -> Tuple[int, int]: return cast(Tuple[int, int], tuple(self.shape[-2:])) @@ -113,7 +116,7 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) if isinstance(color_space, str): color_space = ColorSpace.from_str(color_space.upper()) - return Image.new_like( + return Image.wrap_like( self, self._F.convert_color_space_image_tensor( self, old_color_space=self.color_space, new_color_space=color_space, copy=copy @@ -129,15 +132,15 @@ def show(self) -> None: def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state - return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) + return Image.wrap_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) def horizontal_flip(self) -> Image: output = self._F.horizontal_flip_image_tensor(self) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def vertical_flip(self) -> Image: output = self._F.vertical_flip_image_tensor(self) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def resize( # type: ignore[override] self, @@ -149,15 +152,15 @@ def resize( # type: ignore[override] output = self._F.resize_image_tensor( self, size, interpolation=interpolation, max_size=max_size, antialias=antialias ) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def crop(self, top: int, left: int, height: int, width: int) -> Image: output = self._F.crop_image_tensor(self, top, left, height, width) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def center_crop(self, output_size: List[int]) -> Image: output = self._F.center_crop_image_tensor(self, output_size=output_size) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def resized_crop( self, @@ -172,7 +175,7 @@ def resized_crop( output = self._F.resized_crop_image_tensor( self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias ) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def pad( self, @@ -181,7 +184,7 @@ def pad( padding_mode: str = "constant", ) -> Image: output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) - return 
Image.new_like(self, output) + return Image.wrap_like(self, output) def rotate( self, @@ -194,7 +197,7 @@ def rotate( output = self._F._geometry.rotate_image_tensor( self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def affine( self, @@ -216,7 +219,7 @@ def affine( fill=fill, center=center, ) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def perspective( self, @@ -227,7 +230,7 @@ def perspective( output = self._F._geometry.perspective_image_tensor( self, perspective_coeffs, interpolation=interpolation, fill=fill ) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def elastic( self, @@ -236,55 +239,55 @@ def elastic( fill: FillTypeJIT = None, ) -> Image: output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_brightness(self, brightness_factor: float) -> Image: output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_saturation(self, saturation_factor: float) -> Image: output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_contrast(self, contrast_factor: float) -> Image: output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_sharpness(self, sharpness_factor: float) -> Image: output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_hue(self, hue_factor: float) -> Image: output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def posterize(self, bits: int) -> Image: output = self._F.posterize_image_tensor(self, bits=bits) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def solarize(self, threshold: float) -> Image: output = self._F.solarize_image_tensor(self, threshold=threshold) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def autocontrast(self) -> Image: output = self._F.autocontrast_image_tensor(self) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def equalize(self) -> Image: output = self._F.equalize_image_tensor(self) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def invert(self) -> Image: output = self._F.invert_image_tensor(self) - return Image.new_like(self, output) + return Image.wrap_like(self, output) def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma) - return Image.new_like(self, output) + return Image.wrap_like(self, output) ImageType = Union[torch.Tensor, PIL.Image.Image, Image] diff --git a/torchvision/prototype/features/_label.py 
b/torchvision/prototype/features/_label.py index ebaa84d66ce..9c2bcfc0fb1 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -14,6 +14,12 @@ class _LabelBase(_Feature): categories: Optional[Sequence[str]] + @classmethod + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: + label_base = tensor.as_subclass(cls) + label_base.categories = categories + return label_base + def __new__( cls: Type[L], data: Any, @@ -23,16 +29,14 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> L: - label_base = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad) - - label_base.categories = categories - - return label_base + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + return cls._wrap(tensor, categories=categories) @classmethod - def new_like(cls: Type[L], other: L, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> L: - return super().new_like( - other, data, categories=categories if categories is not None else other.categories, **kwargs + def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L: + return cls._wrap( + tensor, + categories=categories if categories is not None else other.categories, ) @classmethod diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 9dd614752a6..65793dc45df 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import torch from torchvision.transforms import InterpolationMode @@ -9,13 +9,36 @@ class Mask(_Feature): + @classmethod + def _wrap(cls, tensor: torch.Tensor) -> Mask: + return tensor.as_subclass(cls) + + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: bool = False, + ) -> Mask: + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + return cls._wrap(tensor) + + @classmethod + def wrap_like( + cls, + other: Mask, + tensor: torch.Tensor, + ) -> Mask: + return cls._wrap(tensor) + def horizontal_flip(self) -> Mask: output = self._F.horizontal_flip_mask(self) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def vertical_flip(self) -> Mask: output = self._F.vertical_flip_mask(self) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def resize( # type: ignore[override] self, @@ -25,15 +48,15 @@ def resize( # type: ignore[override] antialias: bool = False, ) -> Mask: output = self._F.resize_mask(self, size, max_size=max_size) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def crop(self, top: int, left: int, height: int, width: int) -> Mask: output = self._F.crop_mask(self, top, left, height, width) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def center_crop(self, output_size: List[int]) -> Mask: output = self._F.center_crop_mask(self, output_size=output_size) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def resized_crop( self, @@ -46,7 +69,7 @@ def resized_crop( antialias: bool = False, ) -> Mask: output = self._F.resized_crop_mask(self, top, left, height, width, size=size) - 
return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def pad( self, @@ -55,7 +78,7 @@ def pad( padding_mode: str = "constant", ) -> Mask: output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def rotate( self, @@ -66,7 +89,7 @@ def rotate( center: Optional[List[float]] = None, ) -> Mask: output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def affine( self, @@ -87,7 +110,7 @@ def affine( fill=fill, center=center, ) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def perspective( self, @@ -96,7 +119,7 @@ def perspective( fill: FillTypeJIT = None, ) -> Mask: output = self._F.perspective_mask(self, perspective_coeffs, fill=fill) - return Mask.new_like(self, output) + return Mask.wrap_like(self, output) def elastic( self, @@ -105,4 +128,4 @@ def elastic( fill: FillTypeJIT = None, ) -> Mask: output = self._F.elastic_mask(self, displacement, fill=fill) - return Mask.new_like(self, output, dtype=output.dtype) + return Mask.wrap_like(self, output) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 3cd925fd996..848b69d2899 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -119,7 +119,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features raise ValueError("Need a batch of one hot labels") output = inpt.clone() output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam)) - return features.OneHotLabel.new_like(inpt, output) + return features.OneHotLabel.wrap_like(inpt, output) class RandomMixup(_BaseMixupCutmix): @@ -135,7 +135,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output) + output = features.Image.wrap_like(inpt, output) return output elif isinstance(inpt, features.OneHotLabel): @@ -178,7 +178,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output) + output = features.Image.wrap_like(inpt, output) return output elif isinstance(inpt, features.OneHotLabel): @@ -213,9 +213,11 @@ def _copy_paste( antialias: Optional[bool], ) -> Tuple[features.TensorImageType, Dict[str, Any]]: - paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) - paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) - paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection]) + paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) + paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) + paste_labels = paste_target["labels"].wrap_like( + paste_target["labels"], paste_target["labels"][random_selection] + ) masks = target["masks"] @@ -317,7 +319,7 @@ def _insert_outputs( c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): if isinstance(obj, features.Image): - flat_sample[i] = features.Image.new_like(obj, output_images[c0]) + 
flat_sample[i] = features.Image.wrap_like(obj, output_images[c0]) c0 += 1 elif isinstance(obj, PIL.Image.Image): flat_sample[i] = F.to_image_pil(output_images[c0]) @@ -326,13 +328,13 @@ def _insert_outputs( flat_sample[i] = output_images[c0] c0 += 1 elif isinstance(obj, features.BoundingBox): - flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"]) + flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"]) c1 += 1 elif isinstance(obj, features.Mask): - flat_sample[i] = features.Mask.new_like(obj, output_targets[c2]["masks"]) + flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"]) c2 += 1 elif isinstance(obj, (features.Label, features.OneHotLabel)): - flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] + flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] c3 += 1 def forward(self, *inputs: Any) -> Any: diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c98e5c36e4a..f7448ca27b5 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -518,7 +518,7 @@ def forward(self, *inputs: Any) -> Any: mix = mix.view(orig_dims).to(dtype=image.dtype) if isinstance(orig_image, features.Image): - mix = features.Image.new_like(orig_image, mix) + mix = features.Image.wrap_like(orig_image, mix) elif isinstance(orig_image, PIL.Image.Image): mix = F.to_image_pil(mix) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index e0ee8d1b96a..148a6b0c3e1 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -117,7 +117,7 @@ def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) output = inpt[..., permutation, :, :] if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER) + output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER) elif isinstance(inpt, PIL.Image.Image): output = F.to_image_pil(output) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index a9341415c1a..3979b178f48 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -55,7 +55,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) + output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) return output @@ -84,5 +84,5 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) + output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY) return output diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 
008d4d195cb..7585c7f92b2 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -158,8 +158,8 @@ class FiveCrop(Transform): ... def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]): ... images, labels = sample ... batch_size = len(images) - ... images = features.Image.new_like(images[0], torch.stack(images)) - ... labels = features.Label.new_like(labels, labels.repeat(batch_size)) + ... images = features.Image.wrap_like(images[0], torch.stack(images)) + ... labels = features.Label.wrap_like(labels, labels.repeat(batch_size)) ... return images, labels ... >>> image = features.Image(torch.rand(3, 256, 256)) @@ -677,18 +677,18 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: is_within_crop_area = params["is_within_crop_area"] if isinstance(inpt, (features.Label, features.OneHotLabel)): - return inpt.new_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] + return inpt.wrap_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) if isinstance(output, features.BoundingBox): bboxes = output[is_within_crop_area] bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size) - output = features.BoundingBox.new_like(output, bboxes) + output = features.BoundingBox.wrap_like(output, bboxes) elif isinstance(output, features.Mask): # apply is_within_crop_area if mask is one-hot encoded masks = output[is_within_crop_area] - output = features.Mask.new_like(output, masks) + output = features.Mask.wrap_like(output, masks) return output @@ -801,7 +801,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: bounding_boxes = cast( features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width) ) - bounding_boxes = features.BoundingBox.new_like( + bounding_boxes = features.BoundingBox.wrap_like( bounding_boxes, F.clamp_bounding_box( bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size @@ -840,9 +840,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["is_valid"] is not None: if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)): - inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] + inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] elif isinstance(inpt, features.BoundingBox): - inpt = features.BoundingBox.new_like( + inpt = features.BoundingBox.wrap_like( inpt, F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size), ) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 2ea3014aa6c..f87fc8cbb48 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, cast, Dict, Optional, Union import PIL.Image @@ -18,7 +18,7 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox: output = F.convert_format_bounding_box(inpt, old_format=inpt.format, new_format=params["format"]) - return features.BoundingBox.new_like(inpt, output, format=params["format"]) + return features.BoundingBox.wrap_like(inpt, output, format=params["format"]) class ConvertImageDtype(Transform): @@ -30,7 +30,9 
@@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType: output = F.convert_image_dtype(inpt, dtype=self.dtype) - return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type] + return ( + output if features.is_simple_tensor(inpt) else features.Image.wrap_like(cast(features.Image, inpt), output) + ) class ConvertColorSpace(Transform): @@ -65,4 +67,4 @@ class ClampBoundingBoxes(Transform): def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox: output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size) - return features.BoundingBox.new_like(inpt, output) + return features.BoundingBox.wrap_like(inpt, output) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 976e9f8b5ff..ed62b84d0a1 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -171,4 +171,4 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(valid_indices=valid_indices) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return inpt.new_like(inpt, inpt[params["valid_indices"]]) + return inpt.wrap_like(inpt, inpt[params["valid_indices"]]) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index fb48c35888d..a50cc05da90 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -29,7 +29,7 @@ def erase( if isinstance(inpt, torch.Tensor): output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output) + output = features.Image.wrap_like(inpt, output) return output else: # isinstance(inpt, PIL.Image.Image): return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 7a291967bfd..3c5f4378d96 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1272,7 +1272,7 @@ def five_crop( if isinstance(inpt, torch.Tensor): output = five_crop_image_tensor(inpt, size) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): - output = tuple(features.Image.new_like(inpt, item) for item in output) # type: ignore[assignment] + output = tuple(features.Image.wrap_like(inpt, item) for item in output) # type: ignore[assignment] return output else: # isinstance(inpt, PIL.Image.Image): return five_crop_image_pil(inpt, size) @@ -1309,7 +1309,7 @@ def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = if isinstance(inpt, torch.Tensor): output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): - output = [features.Image.new_like(inpt, item) for item in output] + output = [features.Image.wrap_like(inpt, item) for item in output] return output else: # isinstance(inpt, PIL.Image.Image): return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) From e69f20af39221e3f629ecfb4c8a9558476f694ff Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 7 Oct 
2022 16:36:18 +0200 Subject: [PATCH 2/3] fix videos --- torchvision/prototype/features/_video.py | 74 +++++++++++++----------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py index e19b6f7ed1c..a58027243cf 100644 --- a/torchvision/prototype/features/_video.py +++ b/torchvision/prototype/features/_video.py @@ -13,6 +13,12 @@ class Video(_Feature): color_space: ColorSpace + @classmethod + def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video: + image = tensor.as_subclass(cls) + image.color_space = color_space + return image + def __new__( cls, data: Any, @@ -22,7 +28,7 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> Video: - data = torch.as_tensor(data, dtype=dtype, device=device) + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: raise ValueError video = super().__new__(cls, data, requires_grad=requires_grad) @@ -35,21 +41,19 @@ def __new__( color_space = ColorSpace.from_str(color_space.upper()) elif not isinstance(color_space, ColorSpace): raise ValueError - video.color_space = color_space - - return video - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(color_space=self.color_space) + return cls._wrap(tensor, color_space=color_space) @classmethod - def new_like( - cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any - ) -> Video: - return super().new_like( - other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs + def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video: + return cls._wrap( + tensor, + color_space=color_space if color_space is not None else other.color_space, ) + def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] + return self._make_repr(color_space=self.color_space) + # TODO: rename this (and all instances of this term to spatial size) @property def image_size(self) -> Tuple[int, int]: @@ -67,7 +71,7 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) if isinstance(color_space, str): color_space = ColorSpace.from_str(color_space.upper()) - return Video.new_like( + return Video.wrap_like( self, self._F.convert_color_space_video( self, old_color_space=self.color_space, new_color_space=color_space, copy=copy @@ -77,11 +81,11 @@ def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) def horizontal_flip(self) -> Video: output = self._F.horizontal_flip_video(self) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def vertical_flip(self) -> Video: output = self._F.vertical_flip_video(self) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def resize( # type: ignore[override] self, @@ -91,15 +95,15 @@ def resize( # type: ignore[override] antialias: bool = False, ) -> Video: output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def crop(self, top: int, left: int, height: int, width: int) -> Video: output = self._F.crop_video(self, top, left, height, width) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def center_crop(self, output_size: List[int]) 
-> Video: output = self._F.center_crop_video(self, output_size=output_size) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def resized_crop( self, @@ -114,7 +118,7 @@ def resized_crop( output = self._F.resized_crop_video( self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias ) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def pad( self, @@ -123,7 +127,7 @@ def pad( padding_mode: str = "constant", ) -> Video: output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def rotate( self, @@ -136,7 +140,7 @@ def rotate( output = self._F._geometry.rotate_video( self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def affine( self, @@ -158,7 +162,7 @@ def affine( fill=fill, center=center, ) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def perspective( self, @@ -167,7 +171,7 @@ def perspective( fill: FillTypeJIT = None, ) -> Video: output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def elastic( self, @@ -176,55 +180,55 @@ def elastic( fill: FillTypeJIT = None, ) -> Video: output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_brightness(self, brightness_factor: float) -> Video: output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_saturation(self, saturation_factor: float) -> Video: output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_contrast(self, contrast_factor: float) -> Video: output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_sharpness(self, sharpness_factor: float) -> Video: output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_hue(self, hue_factor: float) -> Video: output = self._F.adjust_hue_video(self, hue_factor=hue_factor) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def adjust_gamma(self, gamma: float, gain: float = 1) -> Video: output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def posterize(self, bits: int) -> Video: output = self._F.posterize_video(self, bits=bits) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def solarize(self, threshold: float) -> Video: output = self._F.solarize_video(self, threshold=threshold) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def autocontrast(self) -> Video: output = self._F.autocontrast_video(self) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def equalize(self) -> Video: output = self._F.equalize_video(self) - return Video.new_like(self, output) + return 
Video.wrap_like(self, output) def invert(self) -> Video: output = self._F.invert_video(self) - return Video.new_like(self, output) + return Video.wrap_like(self, output) def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video: output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma) - return Video.new_like(self, output) + return Video.wrap_like(self, output) VideoType = Union[torch.Tensor, Video] From fcada94306db2845eb95f2370e9378a160b36a55 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 7 Oct 2022 16:47:05 +0200 Subject: [PATCH 3/3] revert casting in favor of ignoring mypy --- torchvision/prototype/transforms/_meta.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index d6d995cf5e1..74fbcd60f02 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -1,4 +1,4 @@ -from typing import Any, cast, Dict, Optional, Union +from typing import Any, Dict, Optional, Union import PIL.Image @@ -31,7 +31,9 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType: output = F.convert_image_dtype(inpt, dtype=self.dtype) return ( - output if features.is_simple_tensor(inpt) else features.Image.wrap_like(cast(features.Image, inpt), output) + output + if features.is_simple_tensor(inpt) + else features.Image.wrap_like(inpt, output) # type: ignore[arg-type] )
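
Usage sketch (editor's note, not part of the patch series): the snippet below condenses how call sites migrate from new_like to wrap_like, mirroring the updated test_wrap_like test and the BoundingBox changes above. The concrete values and the stand-in for a resize kernel's output are illustrative assumptions, not code from the patch.

    import torch
    from torchvision.prototype import features

    label = features.Label(torch.tensor([0, 1, 0], dtype=torch.int64), categories=["foo", "bar"])

    # .clone() and .to() still return a Label automatically via the
    # _NO_WRAPPING_EXCEPTIONS table in _feature.py ...
    assert type(label.clone()) is features.Label

    # ... but any other operation returns a plain tensor that has to be re-wrapped
    # explicitly, now with wrap_like(other, tensor) instead of new_like(other, data):
    output = label * 2
    label_new = features.Label.wrap_like(label, output)
    assert type(label_new) is features.Label
    assert label_new.data_ptr() == output.data_ptr()  # no copy, only the metadata is re-attached

    # Metadata can still be overridden explicitly, e.g. when a kernel changes the image size:
    bbox = features.BoundingBox([[0, 10, 20, 30]], format="XYXY", image_size=(32, 32))
    doubled = bbox * 2  # stands in for the output of a resize kernel that doubled the spatial size
    bbox_new = features.BoundingBox.wrap_like(bbox, doubled, image_size=(64, 64))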
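
For reference (editor's note, also not part of the patch): the three-method pattern the diffs above establish for every metadata-carrying feature, condensed into one hypothetical subclass. ExampleFeature and its meta attribute are invented for illustration, and the sketch assumes _Feature and its new _to_tensor helper are importable from torchvision.prototype.features as in this patch.

    from __future__ import annotations

    from typing import Any, Optional, Union

    import torch

    from torchvision.prototype.features import _Feature


    class ExampleFeature(_Feature):
        meta: Optional[str]

        @classmethod
        def _wrap(cls, tensor: torch.Tensor, *, meta: Optional[str]) -> ExampleFeature:
            # attach metadata to an existing tensor without copying it
            feature = tensor.as_subclass(cls)
            feature.meta = meta
            return feature

        def __new__(
            cls,
            data: Any,
            *,
            meta: Optional[str] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[Union[torch.device, str, int]] = None,
            requires_grad: bool = False,
        ) -> ExampleFeature:
            # __new__ is the only place that converts arbitrary input data to a tensor
            tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
            return cls._wrap(tensor, meta=meta)

        @classmethod
        def wrap_like(
            cls, other: ExampleFeature, tensor: torch.Tensor, *, meta: Optional[str] = None
        ) -> ExampleFeature:
            # unlike the removed new_like, wrap_like never touches dtype/device/requires_grad
            # of `tensor`; it only fills in metadata, defaulting to the values of `other`
            return cls._wrap(tensor, meta=meta if meta is not None else other.meta)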