From 49f7e5adefd2ad4043593b2027f150f3d94be4eb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 25 Oct 2022 17:48:52 +0100 Subject: [PATCH 01/15] Change random generator for ColorJitter. --- torchvision/prototype/transforms/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 3647365c3fb..0dcf636c3db 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -51,7 +51,7 @@ def _check_input( @staticmethod def _generate_value(left: float, right: float) -> float: - return float(torch.distributions.Uniform(left, right).sample()) + return torch.empty(1).uniform_(left, right).item() def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: fn_idx = torch.randperm(4) From 99b1685ed6976607637869ddc3cf3391e1a99175 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 11:24:51 +0100 Subject: [PATCH 02/15] Move `_convert_fill_arg` from runtime to constructor. --- test/test_prototype_transforms.py | 26 +++++++++---------- test/test_prototype_transforms_consistency.py | 2 -- .../prototype/transforms/_auto_augment.py | 3 +-- torchvision/prototype/transforms/_geometry.py | 9 ------- torchvision/prototype/transforms/_utils.py | 25 +++++++++++++++--- .../transforms/functional/_geometry.py | 14 ---------- 6 files changed, 35 insertions(+), 44 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 4334d157e40..fab4cc0ddd6 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -389,7 +389,7 @@ def test__transform(self, padding, fill, padding_mode, mocker): inpt = mocker.MagicMock(spec=features.Image) _ = transform(inpt) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) if isinstance(padding, tuple): padding = list(padding) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) @@ -405,14 +405,14 @@ def test__transform_image_mask(self, fill, mocker): _ = transform(inpt) if isinstance(fill, int): - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) calls = [ mocker.call(image, padding=1, fill=fill, padding_mode="constant"), mocker.call(mask, padding=1, fill=fill, padding_mode="constant"), ] else: - fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) - fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) + fill_img = transforms._utils._convert_fill_arg(fill[type(image)]) + fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)]) calls = [ mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"), mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"), @@ -466,7 +466,7 @@ def test__transform(self, fill, side_range, mocker): torch.rand(1) # random apply changes random state params = transform._get_params([inpt]) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, fill=fill) @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) @@ -485,14 +485,14 @@ def test__transform_image_mask(self, fill, mocker): params = transform._get_params(inpt) if isinstance(fill, int): - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) calls = [ mocker.call(image, **params, fill=fill), mocker.call(mask, **params, fill=fill), ] else: - fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)]) - fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)]) + fill_img = transforms._utils._convert_fill_arg(fill[type(image)]) + fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)]) calls = [ mocker.call(image, **params, fill=fill_img), mocker.call(mask, **params, fill=fill_mask), @@ -556,7 +556,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): torch.manual_seed(12) params = transform._get_params(inpt) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center) @pytest.mark.parametrize("angle", [34, -87]) @@ -694,7 +694,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker torch.manual_seed(12) params = transform._get_params([inpt]) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) @@ -939,7 +939,7 @@ def test__transform(self, distortion_scale, mocker): torch.rand(1) # random apply changes random state params = transform._get_params([inpt]) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) @@ -1009,7 +1009,7 @@ def test__transform(self, alpha, sigma, mocker): transform._get_params = mocker.MagicMock() _ = transform(inpt) params = transform._get_params([inpt]) - fill = transforms.functional._geometry._convert_fill_arg(fill) + fill = transforms._utils._convert_fill_arg(fill) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) @@ -1632,7 +1632,7 @@ def test__transform(self, mocker, needs): if not needs_crop: assert args[0] is inpt_sentinel assert args[1] is padding_sentinel - fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel) + fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel) assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) else: mock_pad.assert_not_called() diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index b0022baaa37..a23783b0037 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -983,8 +983,6 @@ def _transform(self, inpt, params): return inpt fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) - return F.pad(inpt, padding=params["padding"], fill=fill) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 56d581eff9e..d1667556f1c 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -71,10 +71,9 @@ def _apply_image_or_video_transform( transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Dict[Type, features.FillType], + fill: Dict[Type, features.FillTypeJIT], ) -> Union[features.ImageType, features.VideoType]: fill_ = fill[type(image)] - fill_ = F._geometry._convert_fill_arg(fill_) if transform_id == "Identity": return image diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 440e23ab631..78081a865d9 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -235,7 +235,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if not isinstance(padding, int): padding = list(padding) - fill = F._geometry._convert_fill_arg(fill) return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) @@ -274,7 +273,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) return F.pad(inpt, **params, fill=fill) @@ -305,7 +303,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) return F.rotate( inpt, **params, @@ -384,7 +381,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) return F.affine( inpt, **params, @@ -478,8 +474,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) - inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: @@ -535,7 +529,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) return F.perspective( inpt, **params, @@ -584,7 +577,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) return F.elastic( inpt, **params, @@ -855,7 +847,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = self.fill[type(inpt)] - fill = F._geometry._convert_fill_arg(fill) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index cff439b8872..55cce2f6e36 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -7,7 +7,7 @@ from torchvision._utils import sequence_to_str from torchvision.prototype import features -from torchvision.prototype.features._feature import FillType +from torchvision.prototype.features._feature import FillType, FillTypeJIT from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 @@ -55,13 +55,30 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: return defaultdict(functools.partial(_default_arg, default)) -def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: +def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: + # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 + # So, we can't reassign fill to 0 + # if fill is None: + # fill = 0 + if fill is None: + return fill + + # This cast does Sequence -> List[float] to please mypy and torch.jit.script + if not isinstance(fill, (int, float)): + fill = [float(v) for v in list(fill)] + return fill + + +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]: _check_fill_arg(fill) if isinstance(fill, dict): - return fill + fill_copy = {} + for k, v in fill.items(): + fill_copy[k] = _convert_fill_arg(v) + return fill_copy - return _get_defaultdict(fill) + return _get_defaultdict(_convert_fill_arg(fill)) def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index a112db7e127..7f709b73b4b 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -470,20 +470,6 @@ def affine_video( ) -def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: - # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 - # So, we can't reassign fill to 0 - # if fill is None: - # fill = 0 - if fill is None: - return fill - - # This cast does Sequence -> List[float] to please mypy and torch.jit.script - if not isinstance(fill, (int, float)): - fill = [float(v) for v in list(fill)] - return fill - - def affine( inpt: features.InputTypeJIT, angle: Union[int, float], From 17d818428a75e25a6583f441ccb7b1c269c4d465 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 11:34:48 +0100 Subject: [PATCH 03/15] Remove unnecessary TypeVars. --- torchvision/prototype/transforms/_auto_augment.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index d1667556f1c..75ffa7a93f6 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,5 +1,5 @@ import math -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union import PIL.Image import torch @@ -11,9 +11,6 @@ from ._utils import _isinstance, _setup_fill_arg -K = TypeVar("K") -V = TypeVar("V") - class _AutoAugmentBase(Transform): def __init__( @@ -26,7 +23,7 @@ def __init__( self.interpolation = interpolation self.fill = _setup_fill_arg(fill) - def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: + def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: keys = tuple(dct.keys()) key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] From 5e0be6e8a0daf7b3879bff2cb7f7766239030897 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 11:39:18 +0100 Subject: [PATCH 04/15] Remove unnecessary casts --- .../prototype/transforms/_auto_augment.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 75ffa7a93f6..12ed8bc4a63 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,5 +1,5 @@ import math -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import PIL.Image import torch @@ -166,9 +166,7 @@ class AutoAugment(_AutoAugmentBase): "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) - .round() - .int(), + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), @@ -323,9 +321,7 @@ class RandAugment(_AutoAugmentBase): "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) - .round() - .int(), + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), @@ -379,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase): "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) - .round() - .int(), + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(), False, ), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), @@ -426,9 +420,7 @@ class AugMix(_AutoAugmentBase): "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True), "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), "Posterize": ( - lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) - .round() - .int(), + lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), False, ), "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False), From 08ae56f08d53a845041ffd2ffd18cb41f1cfbf75 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 12:25:02 +0100 Subject: [PATCH 05/15] Update comments. --- torchvision/prototype/transforms/_auto_augment.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 12ed8bc4a63..3f19fda67d0 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -505,7 +505,13 @@ def forward(self, *inputs: Any) -> Any: aug = self._apply_image_or_video_transform( aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill ) - mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) + mix.add_( + # The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()` + # Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`. + # TODO: change this once all ops in `F` support float inputs. + combined_weights[:, i].reshape(batch_dims) + * aug + ) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) if isinstance(orig_image_or_video, (features.Image, features.Video)): From 7b8be17b50caf717a4790770247d3072cd66e81b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 12:46:55 +0100 Subject: [PATCH 06/15] Minor code-quality changes on Geometical Transforms. --- torchvision/prototype/transforms/_geometry.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 78081a865d9..a4db1c775fa 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -223,19 +223,16 @@ def __init__( _check_padding_arg(padding) _check_padding_mode_arg(padding_mode) + # This cast does Sequence[int] -> List[int] and is required to make mypy happy + if not isinstance(padding, int): + padding = list(padding) self.padding = padding self.fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = self.fill[type(inpt)] - - # This cast does Sequence[int] -> List[int] and is required to make mypy happy - padding = self.padding - if not isinstance(padding, int): - padding = list(padding) - - return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) + return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) class RandomZoomOut(_RandomApplyTransform): @@ -298,7 +295,7 @@ def __init__( self.center = center def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -355,7 +352,7 @@ def __init__( def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_spatial_size(flat_inputs) - angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() if self.translate is not None: max_dx = float(self.translate[0] * width) max_dy = float(self.translate[1] * height) @@ -366,15 +363,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: translate = (0, 0) if self.scale is not None: - scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()) + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() else: scale = 1.0 shear_x = shear_y = 0.0 if self.shear is not None: - shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()) + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() if len(self.shear) == 4: - shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) @@ -451,12 +448,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: needs_pad = any(padding) needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + (True, torch.randint(0, padded_height - cropped_height + 1, size=()).item()) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + (True, torch.randint(0, padded_width - cropped_width + 1, size=()).item()) if padded_width > cropped_width else (False, 0) ) @@ -506,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: half_height = height // 2 half_width = width // 2 + bound_height = int(distortion_scale * half_height) + 1 + bound_width = int(distortion_scale * half_width) + 1 topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + torch.randint(0, bound_width, size=(1,)).item(), + torch.randint(0, bound_height, size=(1,)).item(), ] topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + torch.randint(width - bound_width, width, size=(1,)).item(), + torch.randint(0, bound_height, size=(1,)).item(), ] botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + torch.randint(width - bound_width, width, size=(1,)).item(), + torch.randint(height - bound_height, height, size=(1,)).item(), ] botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + torch.randint(0, bound_width, size=(1,)).item(), + torch.randint(height - bound_height, height, size=(1,)).item(), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] @@ -623,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: while True: # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + idx = torch.randint(low=0, high=len(self.options), size=(1,)).item() min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() From 5f7e1ee237a2b534bb96584465773264a4d499dc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 13:51:10 +0100 Subject: [PATCH 07/15] Fixing linter and other minor fixes. --- torchvision/prototype/transforms/_geometry.py | 4 ++-- torchvision/prototype/transforms/_type_conversion.py | 8 ++++---- torchvision/prototype/transforms/_utils.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index a4db1c775fa..27b1cf0aad2 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -522,7 +522,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: torch.randint(height - bound_height, height, size=(1,)).item(), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] - endpoints = [topleft, topright, botright, botleft] + endpoints: List[List[int]] = [topleft, topright, botright, botleft] # type: ignore[list-item] perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) return dict(perspective_coeffs=perspective_coeffs) @@ -622,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: while True: # sample an option - idx = torch.randint(low=0, high=len(self.options), size=(1,)).item() + idx: int = torch.randint(low=0, high=len(self.options), size=(1,)).item() # type: ignore[assignment] min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index d4ee7387126..d0b11d53a8f 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,4 +1,4 @@ -from typing import Any, cast, Dict, Optional, Union +from typing import Any, Dict, Optional, Union import numpy as np import PIL.Image @@ -13,7 +13,7 @@ class DecodeImage(Transform): _transformed_types = (features.EncodedImage,) def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image: - return cast(features.Image, F.decode_image_with_pil(inpt)) + return F.decode_image_with_pil(inpt) # type: ignore[no-any-return] class LabelToOneHot(Transform): @@ -27,7 +27,7 @@ def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.O num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories) - output = one_hot(inpt, num_classes=num_categories) + output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories) return features.OneHotLabel(output, categories=inpt.categories) def extra_repr(self) -> str: @@ -50,7 +50,7 @@ class ToImageTensor(Transform): def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> features.Image: - return cast(features.Image, F.to_image_tensor(inpt)) + return F.to_image_tensor(inpt) # type: ignore[no-any-return] class ToImagePIL(Transform): diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 55cce2f6e36..1ff2da1dd2b 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -97,7 +97,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox: - bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)} + bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)] if not bounding_boxes: raise TypeError("No bounding box was found in the sample") elif len(bounding_boxes) > 1: From b0b9b5527216de23452f1147e4a0430b02bc255f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 14:48:56 +0100 Subject: [PATCH 08/15] Change mitigation for mypy.` --- torchvision/prototype/transforms/_geometry.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 27b1cf0aad2..0753006b402 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -506,23 +506,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bound_height = int(distortion_scale * half_height) + 1 bound_width = int(distortion_scale * half_width) + 1 topleft = [ - torch.randint(0, bound_width, size=(1,)).item(), - torch.randint(0, bound_height, size=(1,)).item(), + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), ] topright = [ - torch.randint(width - bound_width, width, size=(1,)).item(), - torch.randint(0, bound_height, size=(1,)).item(), + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), ] botright = [ - torch.randint(width - bound_width, width, size=(1,)).item(), - torch.randint(height - bound_height, height, size=(1,)).item(), + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), ] botleft = [ - torch.randint(0, bound_width, size=(1,)).item(), - torch.randint(height - bound_height, height, size=(1,)).item(), + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] - endpoints: List[List[int]] = [topleft, topright, botright, botleft] # type: ignore[list-item] + endpoints = [topleft, topright, botright, botleft] perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) return dict(perspective_coeffs=perspective_coeffs) @@ -622,7 +622,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: while True: # sample an option - idx: int = torch.randint(low=0, high=len(self.options), size=(1,)).item() # type: ignore[assignment] + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() From ee3196904253f46548a20bc4f2b8cbb8363d7f4a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 15:03:36 +0100 Subject: [PATCH 09/15] Fixing the tests. --- test/test_prototype_transforms_consistency.py | 2 +- torchvision/prototype/transforms/_utils.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index a23783b0037..43e1891af7a 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1040,7 +1040,7 @@ def check(self, t, t_ref, data_kwargs=None): seg_transforms.RandomCrop(size=480), prototype_transforms.Compose( [ - PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), + PadIfSmaller(size=480, fill={features.Mask: 255, features.Image: 0, PIL.Image.Image: 0, torch.Tensor: 0}), prototype_transforms.RandomCrop(size=480), ] ), diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 1ff2da1dd2b..c1176771ef4 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -33,13 +33,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: - if isinstance(fill, dict): + if type(fill) == dict: + # Do exact type check to avoid accepting default dicts from the user. DefaultDict values can be verified only + # at runtime not at construction type. for key, value in fill.items(): # Check key for type _check_fill_arg(value) else: if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") + raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") T = TypeVar("T") @@ -73,10 +75,9 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F _check_fill_arg(fill) if isinstance(fill, dict): - fill_copy = {} for k, v in fill.items(): - fill_copy[k] = _convert_fill_arg(v) - return fill_copy + fill[k] = _convert_fill_arg(v) + return fill return _get_defaultdict(_convert_fill_arg(fill)) From 88328f5a81a3b4aa9585a1b523e19ac80d32f96e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 15:15:24 +0100 Subject: [PATCH 10/15] Fixing the tests. --- test/test_prototype_transforms_consistency.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 43e1891af7a..ca28de173bd 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -1040,7 +1040,9 @@ def check(self, t, t_ref, data_kwargs=None): seg_transforms.RandomCrop(size=480), prototype_transforms.Compose( [ - PadIfSmaller(size=480, fill={features.Mask: 255, features.Image: 0, PIL.Image.Image: 0, torch.Tensor: 0}), + PadIfSmaller( + size=480, fill={features.Mask: 255, features.Image: 0, PIL.Image.Image: 0, torch.Tensor: 0} + ), prototype_transforms.RandomCrop(size=480), ] ), From e15f53639eeaa9b7acdebf2ca2878a9c409e95a6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 15:50:50 +0100 Subject: [PATCH 11/15] Fix linter --- test/test_prototype_transforms_consistency.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index ca28de173bd..5b1dcef6f23 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -2,7 +2,6 @@ import inspect import random import re -from collections import defaultdict from importlib.machinery import SourceFileLoader from pathlib import Path From 53f12bb77b6248cdbd66c378b14de2a570e7e311 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 15:52:15 +0100 Subject: [PATCH 12/15] Restore dict copy. --- torchvision/prototype/transforms/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index c1176771ef4..bfa7ac39891 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -75,9 +75,10 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F _check_fill_arg(fill) if isinstance(fill, dict): + fill_copy = {} for k, v in fill.items(): - fill[k] = _convert_fill_arg(v) - return fill + fill_copy[k] = _convert_fill_arg(v) + return fill_copy return _get_defaultdict(_convert_fill_arg(fill)) From 843bcc960a05899ea090aabcc46f6dcc6b9e0068 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 16:39:06 +0100 Subject: [PATCH 13/15] Handling of defaultdicts --- test/test_prototype_transforms_consistency.py | 5 ++--- torchvision/prototype/transforms/_utils.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 5b1dcef6f23..a23783b0037 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -2,6 +2,7 @@ import inspect import random import re +from collections import defaultdict from importlib.machinery import SourceFileLoader from pathlib import Path @@ -1039,9 +1040,7 @@ def check(self, t, t_ref, data_kwargs=None): seg_transforms.RandomCrop(size=480), prototype_transforms.Compose( [ - PadIfSmaller( - size=480, fill={features.Mask: 255, features.Image: 0, PIL.Image.Image: 0, torch.Tensor: 0} - ), + PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), prototype_transforms.RandomCrop(size=480), ] ), diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index bfa7ac39891..2272396f766 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -33,12 +33,13 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: - if type(fill) == dict: - # Do exact type check to avoid accepting default dicts from the user. DefaultDict values can be verified only - # at runtime not at construction type. + if isinstance(fill, dict): for key, value in fill.items(): # Check key for type _check_fill_arg(value) + if isinstance(fill, defaultdict) and callable(fill.default_factory): + default_value = fill.default_factory() + _check_fill_arg(default_value) else: if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") @@ -75,10 +76,13 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F _check_fill_arg(fill) if isinstance(fill, dict): - fill_copy = {} for k, v in fill.items(): - fill_copy[k] = _convert_fill_arg(v) - return fill_copy + fill[k] = _convert_fill_arg(v) + if isinstance(fill, defaultdict) and callable(fill.default_factory): + default_value = fill.default_factory() + sanitized_default = _convert_fill_arg(default_value) + fill.default_factory = functools.partial(_default_arg, sanitized_default) + return fill # type: ignore[return-value] return _get_defaultdict(_convert_fill_arg(fill)) From 8e6af8d696d482a62ce6fde5e11e86f1ff298016 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 16:42:12 +0100 Subject: [PATCH 14/15] restore int idiom --- torchvision/prototype/transforms/_geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 0753006b402..c5ab38d8418 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -448,12 +448,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: needs_pad = any(padding) needs_vert_crop, top = ( - (True, torch.randint(0, padded_height - cropped_height + 1, size=()).item()) + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, torch.randint(0, padded_width - cropped_width + 1, size=()).item()) + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) if padded_width > cropped_width else (False, 0) ) From 11af094b930a378c6a91afc85f4e6c932203e750 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 26 Oct 2022 16:58:03 +0100 Subject: [PATCH 15/15] Update todo --- torchvision/prototype/transforms/_auto_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 3f19fda67d0..3714fc13682 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -508,7 +508,7 @@ def forward(self, *inputs: Any) -> Any: mix.add_( # The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()` # Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`. - # TODO: change this once all ops in `F` support float inputs. + # TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840 combined_weights[:, i].reshape(batch_dims) * aug )