From b7aa77d319b081fc63af2cc9527619bdda25c9d5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 25 Oct 2022 10:41:18 +0200
Subject: [PATCH 1/3] unwrap features in dispatchers

---
 .../prototype/transforms/functional/_augment.py    |  2 ++
 .../prototype/transforms/functional/_deprecated.py | 13 +++++++------
 .../prototype/transforms/functional/_geometry.py   |  8 ++++++--
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index 20e5ac91676..930b35d2a05 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -35,6 +35,8 @@ def erase(
     inplace: bool = False,
 ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor):
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            inpt = inpt.as_subclass(torch.Tensor)
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
             output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index e18c267e84c..57382a4d857 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
 def rgb_to_grayscale(
     inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
 ) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
-    old_color_space = (
-        features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-        if isinstance(inpt, torch.Tensor)
-        and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
-        else None
-    )
+    if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+        inpt = inpt.as_subclass(torch.Tensor)
+        old_color_space = None
+    elif isinstance(inpt, torch.Tensor):
+        old_color_space = features._image._from_tensor_shape(inpt.shape)
+    else:
+        old_color_space = None

     call = ", num_output_channels=3" if num_output_channels == 3 else ""
     replacement = (
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 1451b83cf26..b9229c0ba16 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1401,10 +1401,12 @@ def five_crop(
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
     # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            inpt = inpt.as_subclass(torch.Tensor)
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
-            output = tmp  # type: ignore[assignment]
+            output = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
+            # output = tmp  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
@@ -1445,6 +1447,8 @@ def ten_crop(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
     if isinstance(inpt, torch.Tensor):
+        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
+            inpt = inpt.as_subclass(torch.Tensor)
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
             output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]

From 05b3de54d8626cfc9982ef4edcbb7d544151f6e5 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 25 Oct 2022 10:44:20 +0200
Subject: [PATCH 2/3] cleanup

---
 torchvision/prototype/transforms/functional/_deprecated.py | 2 +-
 torchvision/prototype/transforms/functional/_geometry.py   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index 57382a4d857..1075e9a64ca 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -29,7 +29,7 @@ def rgb_to_grayscale(
         inpt = inpt.as_subclass(torch.Tensor)
         old_color_space = None
     elif isinstance(inpt, torch.Tensor):
-        old_color_space = features._image._from_tensor_shape(inpt.shape)
+        old_color_space = features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
     else:
         old_color_space = None

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index b9229c0ba16..26b43fc8474 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1405,8 +1405,8 @@ def five_crop(
         inpt = inpt.as_subclass(torch.Tensor)
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
-            # output = tmp  # type: ignore[assignment]
+            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
+            output = tmp  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)

From 3f5f7e45404b6f858bc73c9149fbc9d4618c6e23 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 25 Oct 2022 12:00:28 +0200
Subject: [PATCH 3/3] align erase / five_crop / ten_crop with other dispatchers

---
 .../transforms/functional/_augment.py  | 17 +++++----
 .../transforms/functional/_geometry.py | 35 +++++++++++--------
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index 930b35d2a05..baa3e157385 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -34,12 +34,15 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
-    if isinstance(inpt, torch.Tensor):
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            inpt = inpt.as_subclass(torch.Tensor)
-        output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    elif isinstance(inpt, features.Image):
+        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        return features.Image.wrap_like(inpt, output)
+    elif isinstance(inpt, features.Video):
+        output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        return features.Video.wrap_like(inpt, output)
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 26b43fc8474..a112db7e127 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1400,14 +1400,16 @@ def five_crop(
     inpt: ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
     # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
-    if isinstance(inpt, torch.Tensor):
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            inpt = inpt.as_subclass(torch.Tensor)
-        output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
-            output = tmp  # type: ignore[assignment]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return five_crop_image_tensor(inpt, size)
+    elif isinstance(inpt, features.Image):
+        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
+    elif isinstance(inpt, features.Video):
+        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)

@@ -1446,12 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
 def ten_crop(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
-    if isinstance(inpt, torch.Tensor):
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            inpt = inpt.as_subclass(torch.Tensor)
-        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+    elif isinstance(inpt, features.Image):
+        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Image.wrap_like(inpt, item) for item in output]
+    elif isinstance(inpt, features.Video):
+        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Video.wrap_like(inpt, item) for item in output]
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
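
For readers following the series, here is a minimal, self-contained sketch of the unwrap -> kernel -> wrap_like dispatch pattern that PATCH 3 converges on. The `Image` class and the `erase_image_tensor` kernel below are simplified stand-ins for `torchvision.prototype.features.Image` and the real kernel; the `torch.jit.is_scripting()` short-circuit, the `features.Video` branch, and the PIL branch are omitted. It is illustrative only, not library code.

import torch


class Image(torch.Tensor):
    # Simplified stand-in for torchvision.prototype.features.Image; the real
    # class also carries metadata such as the color space.

    @classmethod
    def wrap_like(cls, other: "Image", tensor: torch.Tensor) -> "Image":
        # Re-attach the feature subclass to a plain tensor returned by a kernel.
        return tensor.as_subclass(cls)


def erase_image_tensor(image, i, j, h, w, v):
    # Plain-tensor kernel: it never sees the feature subclass.
    output = image.clone()
    output[..., i : i + h, j : j + w] = v
    return output


def erase(inpt, i, j, h, w, v):
    # The dispatch layout from PATCH 3: one branch per concrete type. The real
    # dispatcher's first branch also returns early under torch.jit.is_scripting(),
    # and a final else handles PIL images.
    if isinstance(inpt, torch.Tensor) and not isinstance(inpt, Image):
        return erase_image_tensor(inpt, i, j, h, w, v)
    elif isinstance(inpt, Image):
        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i, j, h, w, v)
        return Image.wrap_like(inpt, output)
    else:
        raise TypeError(f"unsupported input type {type(inpt).__name__}")


img = torch.rand(3, 8, 8).as_subclass(Image)
out = erase(img, i=2, j=2, h=4, w=4, v=torch.zeros(1))
assert type(out) is Image  # the subclass survives the round trip
assert type(erase(torch.rand(3, 8, 8), 2, 2, 4, 4, torch.zeros(1))) is torch.Tensor

The design point of PATCH 3 is visible here: each branch tests a concrete type and rewraps through the class-level `wrap_like`, so the instance-level `inpt.wrap_like(inpt, ...)` calls and their `# type: ignore` comments from PATCH 1 are no longer needed.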