diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 5d4dc92a11b..2b2639e2cc5 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from ._utils import query_image, get_image_dimensions, has_all, has_any +from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor class RandomErasing(Transform): @@ -90,7 +90,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, features.Image): output = F.erase_image_tensor(input, **params) return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return F.erase_image_tensor(input, **params) else: return input diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index cb4e5979102..c451feb9a32 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -8,7 +8,7 @@ from torchvision.prototype.utils._internal import query_recursively from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import get_image_dimensions +from ._utils import get_image_dimensions, is_simple_tensor K = TypeVar("K") V = TypeVar("V") @@ -89,7 +89,7 @@ def _dispatch_image_kernels( if isinstance(input, features.Image): output = image_tensor_kernel(input, *args, **kwargs) return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return image_tensor_kernel(input, *args, **kwargs) else: # isinstance(input, PIL.Image.Image): return image_pil_kernel(input, *args, **kwargs) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 6f4f7a6cb4d..44e31dee856 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from ._utils import query_image, get_image_dimensions, has_any +from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor class HorizontalFlip(Transform): @@ -21,7 +21,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return features.BoundingBox.new_like(input, output) elif isinstance(input, PIL.Image.Image): return F.horizontal_flip_image_pil(input) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return F.horizontal_flip_image_tensor(input) else: return input @@ -49,7 +49,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size))) elif isinstance(input, PIL.Image.Image): return F.resize_image_pil(input, self.size, interpolation=self.interpolation) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return F.resize_image_tensor(input, self.size, interpolation=self.interpolation) else: return input @@ -64,7 +64,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, features.Image): output = F.center_crop_image_tensor(input, self.output_size) return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return F.center_crop_image_tensor(input, self.output_size) elif isinstance(input, PIL.Image.Image): return F.center_crop_image_pil(input, self.output_size) @@ -156,7 +156,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: input, **params, size=list(self.size), interpolation=self.interpolation ) return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation) elif isinstance(input, PIL.Image.Image): return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index 09d2892769c..6634a3144f7 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -6,6 +6,8 @@ from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms.functional import convert_image_dtype +from ._utils import is_simple_tensor + class ConvertBoundingBoxFormat(Transform): def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: @@ -15,7 +17,7 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: self.format = format def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.BoundingBox: + if isinstance(input, features.BoundingBox): output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"]) return features.BoundingBox.new_like(input, output, format=params["format"]) else: @@ -28,9 +30,11 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: self.dtype = dtype def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: + if isinstance(input, features.Image): output = convert_image_dtype(input, dtype=self.dtype) return features.Image.new_like(input, output, dtype=self.dtype) + elif is_simple_tensor(input): + return convert_image_dtype(input, dtype=self.dtype) else: return input @@ -57,7 +61,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: input, old_color_space=input.color_space, new_color_space=self.color_space ) return features.Image.new_like(input, output, color_space=self.color_space) - elif isinstance(input, torch.Tensor): + elif is_simple_tensor(input): if self.old_color_space is None: raise RuntimeError( f"In order to convert vanilla tensor images, `{type(self).__name__}(...)` " diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index fa49c35265e..f2dc426897b 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -6,7 +6,7 @@ class DecodeImage(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.EncodedImage: + if isinstance(input, features.EncodedImage): output = F.decode_image_with_pil(input) return features.Image(output) else: @@ -19,7 +19,7 @@ def __init__(self, num_categories: int = -1): self.num_categories = num_categories def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Label: + if isinstance(input, features.Label): num_categories = self.num_categories if num_categories == -1 and input.categories is not None: num_categories = len(input.categories) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 74cbd84a64e..0517757a758 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -46,3 +46,7 @@ def has_any(sample: Any, *types: Type) -> bool: def has_all(sample: Any, *types: Type) -> bool: return not bool(set(types) - set(_extract_types(sample))) + + +def is_simple_tensor(input: Any) -> bool: + return isinstance(input, torch.Tensor) and not isinstance(input, features._Feature)