diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index 20e5ac91676..baa3e157385 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -34,10 +34,15 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
-    if isinstance(inpt, torch.Tensor):
-        output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    elif isinstance(inpt, features.Image):
+        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        return features.Image.wrap_like(inpt, output)
+    elif isinstance(inpt, features.Video):
+        output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        return features.Video.wrap_like(inpt, output)
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index e18c267e84c..1075e9a64ca 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
 def rgb_to_grayscale(
     inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
 ) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
-    old_color_space = (
-        features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-        if isinstance(inpt, torch.Tensor)
-        and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
-        else None
-    )
+    if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+        inpt = inpt.as_subclass(torch.Tensor)
+        old_color_space = None
+    elif isinstance(inpt, torch.Tensor):
+        old_color_space = features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
+    else:
+        old_color_space = None
 
     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = (
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 1451b83cf26..a112db7e127 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1400,12 +1400,16 @@ def five_crop(
     inpt: ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
     # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
-    if isinstance(inpt, torch.Tensor):
-        output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
-            output = tmp  # type: ignore[assignment]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return five_crop_image_tensor(inpt, size)
+    elif isinstance(inpt, features.Image):
+        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
+    elif isinstance(inpt, features.Video):
+        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
 
@@ -1444,10 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
 def ten_crop(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
-    if isinstance(inpt, torch.Tensor):
-        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+    elif isinstance(inpt, features.Image):
+        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Image.wrap_like(inpt, item) for item in output]
+    elif isinstance(inpt, features.Video):
+        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Video.wrap_like(inpt, item) for item in output]
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
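All three dispatchers in this diff follow the same pattern: plain tensors (and every TorchScript call) are routed straight to the `*_image_tensor` kernel, while `features.Image` / `features.Video` inputs are unwrapped with `as_subclass(torch.Tensor)`, sent to the matching kernel, and re-wrapped via `wrap_like` so the output keeps its feature type. Below is a minimal, self-contained sketch of that pattern; `Image`, `my_op`, and `my_op_image_tensor` are hypothetical stand-ins for illustration, not the actual torchvision classes or kernels:

```python
import torch


class Image(torch.Tensor):
    """Illustrative stand-in for torchvision's prototype ``features.Image``."""

    @classmethod
    def wrap_like(cls, other: "Image", tensor: torch.Tensor) -> "Image":
        # Re-attach the subclass to the plain-tensor kernel output
        # (the real feature classes also carry their metadata over).
        return tensor.as_subclass(cls)


def my_op_image_tensor(image: torch.Tensor) -> torch.Tensor:
    # Stand-in kernel that only ever operates on plain tensors.
    return image + 1


def my_op(inpt: torch.Tensor) -> torch.Tensor:
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, Image)
    ):
        # Plain tensors -- and all scripted calls -- hit the kernel
        # directly, so TorchScript never compiles the subclass logic.
        return my_op_image_tensor(inpt)
    elif isinstance(inpt, Image):
        # Unwrap to a plain tensor, dispatch to the kernel, re-wrap.
        output = my_op_image_tensor(inpt.as_subclass(torch.Tensor))
        return Image.wrap_like(inpt, output)
    else:
        raise TypeError(f"unsupported input type {type(inpt)}")


plain = torch.zeros(3, 4, 4)
wrapped = torch.zeros(3, 4, 4).as_subclass(Image)
assert type(my_op(plain)) is torch.Tensor  # stays a plain tensor
assert type(my_op(wrapped)) is Image       # feature wrapping is preserved
```

Checking `torch.jit.is_scripting()` before the subclass `isinstance` checks means scripted code only ever takes the plain-tensor branch, which is presumably why the refactor short-circuits the feature-type dispatch behind it.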