diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py
index 6fc2fb6ea94..944ae9bd3c6 100644
--- a/torchvision/prototype/features/__init__.py
+++ b/torchvision/prototype/features/__init__.py
@@ -13,4 +13,14 @@
 )
 from ._label import Label, OneHotLabel
 from ._mask import Mask
-from ._video import ImageOrVideoType, ImageOrVideoTypeJIT, TensorImageOrVideoType, TensorImageOrVideoTypeJIT, Video
+from ._video import (
+    ImageOrVideoType,
+    ImageOrVideoTypeJIT,
+    LegacyVideoType,
+    LegacyVideoTypeJIT,
+    TensorImageOrVideoType,
+    TensorImageOrVideoTypeJIT,
+    Video,
+    VideoType,
+    VideoTypeJIT,
+)
diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py
index a58027243cf..e32c36d5d9f 100644
--- a/torchvision/prototype/features/_video.py
+++ b/torchvision/prototype/features/_video.py
@@ -238,6 +238,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
 TensorVideoType = Union[torch.Tensor, Video]
 TensorVideoTypeJIT = torch.Tensor
 
+# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
 ImageOrVideoType = Union[ImageType, VideoType]
 ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
 TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index bcab0a3f454..7b2dca8a601 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -99,6 +99,7 @@ def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) ->
         return inpt
 
 
+# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
 class _BaseMixupCutmix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 6ef9edba354..d078cb2d1cb 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -483,7 +483,8 @@ def forward(self, *inputs: Any) -> Any:
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
 
         orig_dims = list(image_or_video.shape)
-        batch = image_or_video.view([1] * max(4 - image_or_video.ndim, 0) + orig_dims)
+        expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
 
         # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
@@ -520,7 +521,7 @@ def forward(self, *inputs: Any) -> Any:
         mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)
 
         if isinstance(orig_image_or_video, (features.Image, features.Video)):
-            mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+            mix = orig_image_or_video.wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)
 
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index 67a6cc3cc3f..340e721dab9 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -119,7 +119,7 @@ def _permute_channels(
         output = inpt[..., permutation, :, :]
 
         if isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
         elif isinstance(inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
 
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 3979b178f48..f8aec22b96c 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -29,7 +29,7 @@ def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str,
 
 
 class Grayscale(Transform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         deprecation_msg = (
@@ -52,15 +52,15 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         super().__init__()
         self.num_output_channels = num_output_channels
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
 
 
 class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, p: float = 0.1) -> None:
         warnings.warn(
@@ -81,8 +81,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         num_input_channels, _, _ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, features.Image):
-            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
+        if isinstance(inpt, (features.Image, features.Video)):
+            output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
         return output
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 37e2aee0236..371ea7f69c5 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -155,12 +155,13 @@ class FiveCrop(Transform):
     """
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
-        ...         images, labels = sample
-        ...         batch_size = len(images)
-        ...         images = features.Image.wrap_like(images[0], torch.stack(images))
+        ...     def forward(self, sample: Tuple[Tuple[Union[features.Image, features.Video], ...], features.Label]):
+        ...         images_or_videos, labels = sample
+        ...         batch_size = len(images_or_videos)
+        ...         image_or_video = images_or_videos[0]
+        ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
         ...         labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
-        ...         return images, labels
+        ...         return images_or_videos, labels
         ...
         >>> image = features.Image(torch.rand(3, 256, 256))
         >>> label = features.Label(0)
@@ -172,15 +173,21 @@ class FiveCrop(Transform):
         torch.Size([5])
     """
 
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(
-        self, inpt: features.ImageType, params: Dict[str, Any]
-    ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
+        self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
+    ) -> Tuple[
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+        features.ImageOrVideoType,
+    ]:
         return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
@@ -194,14 +201,14 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """
 
-    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor)
+    _transformed_types = (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)
 
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
+    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 74fbcd60f02..e5c7d05b017 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -22,18 +22,18 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> feat
 
 
 class ConvertImageDtype(Transform):
-    _transformed_types = (features.is_simple_tensor, features.Image)
+    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
 
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype
 
-    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
+    def _transform(
+        self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
+    ) -> features.TensorImageOrVideoType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return (
-            output
-            if features.is_simple_tensor(inpt)
-            else features.Image.wrap_like(inpt, output)  # type: ignore[arg-type]
+            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
         )
 
 
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index dd1e1cdf8a1..d3c8a57dc80 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -140,6 +140,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.gaussian_blur(inpt, self.kernel_size, **params)
 
 
+# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 1e918cc3492..579442dc7b9 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -96,6 +96,7 @@
     five_crop,
     five_crop_image_pil,
     five_crop_image_tensor,
+    five_crop_video,
     hflip,  # TODO: Consider moving all pure alias definitions at the bottom of the file
     horizontal_flip,
     horizontal_flip_bounding_box,
@@ -136,6 +137,7 @@
     ten_crop,
     ten_crop_image_pil,
     ten_crop_image_tensor,
+    ten_crop_video,
     vertical_flip,
     vertical_flip_bounding_box,
     vertical_flip_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index 847343dbf20..57c3602cc14 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -35,7 +35,7 @@ def erase(
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index cbdea5130ef..854920b968a 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, List
+from typing import Any, List, Union
 
 import PIL.Image
 import torch
@@ -22,10 +22,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)
 
 
-def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
+def rgb_to_grayscale(
+    inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
+) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
     old_color_space = (
         features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-        if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
+        if isinstance(inpt, torch.Tensor)
+        and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
         else None
     )
 
@@ -56,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)
 
 
-def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
+def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 93df59ad646..44b4986aba0 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1376,16 +1376,27 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center
 
 
+def five_crop_video(
+    video: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    return five_crop_image_tensor(video, size)
+
+
 def five_crop(
-    inpt: features.ImageTypeJIT, size: List[int]
+    inpt: features.ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[
-    features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
+    features.ImageOrVideoTypeJIT,
 ]:
-    # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
+    # TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[assignment]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
+            output = tmp  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
@@ -1418,11 +1429,17 @@ def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: b
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
 
 
-def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
+def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
+
+
+def ten_crop(
+    inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
+) -> List[features.ImageOrVideoTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = [features.Image.wrap_like(inpt, item) for item in output]
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index e24b68c9fd6..c03d65c951b 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -55,6 +55,10 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]
 
 
+# TODO: Should we have get_spatial_size_video here? How about masks/bbox etc? What is the criterion for deciding when
+# a kernel will be created?
+
+
 def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return get_spatial_size_image_tensor(inpt)
@@ -246,7 +250,7 @@ def convert_color_space(
     ):
         if old_color_space is None:
             raise RuntimeError(
-                "In order to convert the color space of simple tensor images, "
+                "In order to convert the color space of simple tensors, "
                 "the `old_color_space=...` parameter needs to be passed."
             )
         return convert_color_space_image_tensor(
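For reference, a minimal usage sketch of the Video support added above. It assumes the prototype namespaces shown in this diff (torchvision.prototype.features and torchvision.prototype.transforms) are importable as-is; the clip shape and the asserts are illustrative and not taken from the test suite.

    # Hedged sketch, not part of the diff: exercises the new Video code paths.
    import torch

    from torchvision.prototype import features, transforms
    from torchvision.prototype.transforms import functional as F

    # A toy clip with 8 frames; prototype videos are tensors shaped (..., T, C, H, W).
    video = features.Video(torch.rand(8, 3, 256, 256))

    # The dispatcher now accepts Video inputs and re-wraps each crop via `wrap_like`,
    # so every element of the returned 5-tuple should still be a `features.Video`.
    crops = F.five_crop(video, [224, 224])
    assert all(isinstance(crop, features.Video) for crop in crops)

    # The transform classes dispatch on Video too, since it was added to `_transformed_types`.
    assert isinstance(transforms.FiveCrop(224)(video)[0], features.Video)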