diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d395c224785..43a7df4f3a2 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,5 +1,3 @@ -import itertools - import re import PIL.Image @@ -19,7 +17,6 @@ from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms -from torchvision.transforms.v2._utils import _convert_fill_arg from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil from torchvision.transforms.v2.utils import check_type, is_simple_tensor @@ -187,66 +184,6 @@ def test__get_params(self, mocker): assert params["needs_pad"] assert any(pad > 0 for pad in params["padding"]) - @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) - def test__transform(self, mocker, needs): - fill_sentinel = 12 - padding_mode_sentinel = mocker.MagicMock() - - transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) - transform._transformed_types = (mocker.MagicMock,) - mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) - - needs_crop, needs_pad = needs - top_sentinel = mocker.MagicMock() - left_sentinel = mocker.MagicMock() - height_sentinel = mocker.MagicMock() - width_sentinel = mocker.MagicMock() - is_valid = mocker.MagicMock() if needs_crop else None - padding_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", - return_value=dict( - needs_crop=needs_crop, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - is_valid=is_valid, - padding=padding_sentinel, - needs_pad=needs_pad, - ), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop") - mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad") - transform(inpt_sentinel) - - if needs_crop: - mock_crop.assert_called_once_with( - inpt_sentinel, - top=top_sentinel, - left=left_sentinel, - height=height_sentinel, - width=width_sentinel, - ) - else: - mock_crop.assert_not_called() - - if needs_pad: - # If we cropped before, the input to F.pad is no longer inpt_sentinel. 
Thus, we can't use - # `MagicMock.assert_called_once_with` and have to perform the checks manually - mock_pad.assert_called_once() - args, kwargs = mock_pad.call_args - if not needs_crop: - assert args[0] is inpt_sentinel - assert args[1] is padding_sentinel - fill_sentinel = _convert_fill_arg(fill_sentinel) - assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel) - else: - mock_pad.assert_not_called() - def test__transform_culling(self, mocker): batch_size = 10 canvas_size = (10, 10) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5f4a9b62898..4db2abe7fc4 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -27,7 +27,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import datapoints from torchvision.ops.boxes import box_iou -from torchvision.transforms.functional import InterpolationMode, to_pil_image +from torchvision.transforms.functional import to_pil_image from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw @@ -419,46 +419,6 @@ def test_assertions(self): with pytest.raises(ValueError, match="Padding mode should be either"): transforms.Pad(12, padding_mode="abc") - @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]]) - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) - def test__transform(self, padding, fill, padding_mode, mocker): - transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode) - - fn = mocker.patch("torchvision.transforms.v2.functional.pad") - inpt = mocker.MagicMock(spec=datapoints.Image) - _ = transform(inpt) - - fill = transforms._utils._convert_fill_arg(fill) - if isinstance(padding, tuple): - padding = list(padding) - fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - - @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}]) - def test__transform_image_mask(self, fill, mocker): - transform = transforms.Pad(1, fill=fill, padding_mode="constant") - - fn = mocker.patch("torchvision.transforms.v2.functional.pad") - image = datapoints.Image(torch.rand(3, 32, 32)) - mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) - inpt = [image, mask] - _ = transform(inpt) - - if isinstance(fill, int): - fill = transforms._utils._convert_fill_arg(fill) - calls = [ - mocker.call(image, padding=1, fill=fill, padding_mode="constant"), - mocker.call(mask, padding=1, fill=fill, padding_mode="constant"), - ] - else: - fill_img = transforms._utils._convert_fill_arg(fill[type(image)]) - fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)]) - calls = [ - mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"), - mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"), - ] - fn.assert_has_calls(calls) - class TestRandomZoomOut: def test_assertions(self): @@ -487,56 +447,6 @@ def test__get_params(self, fill, side_range): assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__transform(self, fill, side_range, mocker): - inpt = make_image((24, 32)) - - transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) - - fn = mocker.patch("torchvision.transforms.v2.functional.pad") - 
# vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params([inpt]) - - fill = transforms._utils._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, fill=fill) - - @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}]) - def test__transform_image_mask(self, fill, mocker): - transform = transforms.RandomZoomOut(fill=fill, p=1.0) - - fn = mocker.patch("torchvision.transforms.v2.functional.pad") - image = datapoints.Image(torch.rand(3, 32, 32)) - mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32))) - inpt = [image, mask] - - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params(inpt) - - if isinstance(fill, int): - fill = transforms._utils._convert_fill_arg(fill) - calls = [ - mocker.call(image, **params, fill=fill), - mocker.call(mask, **params, fill=fill), - ] - else: - fill_img = transforms._utils._convert_fill_arg(fill[type(image)]) - fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)]) - calls = [ - mocker.call(image, **params, fill=fill_img), - mocker.call(mask, **params, fill=fill_mask), - ] - fn.assert_has_calls(calls) - class TestRandomCrop: def test_assertions(self): @@ -599,51 +509,6 @@ def test__get_params(self, padding, pad_if_needed, size): assert params["needs_pad"] is any(padding) assert params["padding"] == padding - @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]]) - @pytest.mark.parametrize("pad_if_needed", [False, True]) - @pytest.mark.parametrize("fill", [False, True]) - @pytest.mark.parametrize("padding_mode", ["constant", "edge"]) - def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): - output_size = [10, 12] - transform = transforms.RandomCrop( - output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode - ) - - h, w = size = (32, 32) - inpt = make_image(size) - - if isinstance(padding, int): - new_size = (h + padding, w + padding) - elif isinstance(padding, list): - new_size = (h + sum(padding[0::2]), w + sum(padding[1::2])) - else: - new_size = size - expected = make_image(new_size) - _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected) - fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop") - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params([inpt]) - if padding is None and not pad_if_needed: - fn_crop.assert_called_once_with( - inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] - ) - elif not pad_if_needed: - fn_crop.assert_called_once_with( - expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1] - ) - elif padding is None: - # vfdev-5: I do not know how to mock and test this case - pass - else: - # vfdev-5: I do not know how to mock and test this case - pass - class TestGaussianBlur: def test_assertions(self): @@ -675,62 +540,6 @@ def test__get_params(self, sigma): assert sigma[0] <= params["sigma"][0] <= sigma[1] assert sigma[0] <= 
params["sigma"][1] <= sigma[1] - @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)]) - @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]]) - def test__transform(self, kernel_size, sigma, mocker): - transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma) - - if isinstance(kernel_size, (tuple, list)): - assert transform.kernel_size == kernel_size - else: - kernel_size = (kernel_size, kernel_size) - assert transform.kernel_size == kernel_size - - if isinstance(sigma, (tuple, list)): - assert transform.sigma == sigma - else: - assert transform.sigma == [sigma, sigma] - - fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur") - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.canvas_size = (24, 32) - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params([inpt]) - - fn.assert_called_once_with(inpt, kernel_size, **params) - - -class TestRandomColorOp: - @pytest.mark.parametrize("p", [0.0, 1.0]) - @pytest.mark.parametrize( - "transform_cls, func_op_name, kwargs", - [ - (transforms.RandomEqualize, "equalize", {}), - (transforms.RandomInvert, "invert", {}), - (transforms.RandomAutocontrast, "autocontrast", {}), - (transforms.RandomPosterize, "posterize", {"bits": 4}), - (transforms.RandomSolarize, "solarize", {"threshold": 0.5}), - (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}), - ], - ) - def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker): - transform = transform_cls(p=p, **kwargs) - - fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}") - inpt = mocker.MagicMock(spec=datapoints.Image) - _ = transform(inpt) - if p > 0.0: - fn.assert_called_once_with(inpt, **kwargs) - else: - assert fn.call_count == 0 - class TestRandomPerspective: def test_assertions(self): @@ -751,28 +560,6 @@ def test__get_params(self): assert "coefficients" in params assert len(params["coefficients"]) == 8 - @pytest.mark.parametrize("distortion_scale", [0.1, 0.7]) - def test__transform(self, distortion_scale, mocker): - interpolation = InterpolationMode.BILINEAR - fill = 12 - transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation) - - fn = mocker.patch("torchvision.transforms.v2.functional.perspective") - - inpt = make_image((24, 32)) - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - torch.rand(1) # random apply changes random state - params = transform._get_params([inpt]) - - fill = transforms._utils._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation) - class TestElasticTransform: def test_assertions(self): @@ -813,35 +600,6 @@ def test__get_params(self): assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() - @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]]) - @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]]) - def test__transform(self, alpha, sigma, mocker): - interpolation = InterpolationMode.BILINEAR - fill = 12 - 
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation) - - if isinstance(alpha, float): - assert transform.alpha == [alpha, alpha] - else: - assert transform.alpha == alpha - - if isinstance(sigma, float): - assert transform.sigma == [sigma, sigma] - else: - assert transform.sigma == sigma - - fn = mocker.patch("torchvision.transforms.v2.functional.elastic") - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.canvas_size = (24, 32) - - # Let's mock transform._get_params to control the output: - transform._get_params = mocker.MagicMock() - _ = transform(inpt) - params = transform._get_params([inpt]) - fill = transforms._utils._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) - class TestRandomErasing: def test_assertions(self): @@ -889,40 +647,6 @@ def test__get_params(self, value): assert 0 <= i <= height - h assert 0 <= j <= width - w - @pytest.mark.parametrize("p", [0, 1]) - def test__transform(self, mocker, p): - transform = transforms.RandomErasing(p=p) - transform._transformed_types = (mocker.MagicMock,) - - i_sentinel = mocker.MagicMock() - j_sentinel = mocker.MagicMock() - h_sentinel = mocker.MagicMock() - w_sentinel = mocker.MagicMock() - v_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.transforms.v2._augment.RandomErasing._get_params", - return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.transforms.v2._augment.F.erase") - output = transform(inpt_sentinel) - - if p: - mock.assert_called_once_with( - inpt_sentinel, - i=i_sentinel, - j=j_sentinel, - h=h_sentinel, - w=w_sentinel, - v=v_sentinel, - inplace=transform.inplace, - ) - else: - mock.assert_not_called() - assert output is inpt_sentinel - class TestTransform: @pytest.mark.parametrize( @@ -1111,23 +835,12 @@ def test__transform(self, mocker): sample = [image, bboxes, masks] - fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x) is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) transform._get_params = mocker.MagicMock(return_value=params) output = transform(sample) - assert fn.call_count == 3 - - expected_calls = [ - mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), - ] - - fn.assert_has_calls(expected_calls) - # check number of bboxes vs number of labels: output_bboxes = output[1] assert isinstance(output_bboxes, datapoints.BoundingBoxes) @@ -1164,29 +877,6 @@ def test__get_params(self): assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max) assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max) - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) - antialias_sentinel = mocker.MagicMock() - - transform = transforms.ScaleJitter( - target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - 
"torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel) - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestRandomShortestSize: @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) @@ -1211,30 +901,6 @@ def test__get_params(self, min_size, max_size): else: assert shorter in min_size - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomShortestSize( - min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.transforms.v2._geometry.RandomShortestSize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestLinearTransformation: def test_assertions(self): @@ -1260,7 +926,7 @@ def test__transform(self, inpt): transform = transforms.LinearTransformation(m, v) if isinstance(inpt, PIL.Image.Image): - with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"): + with pytest.raises(TypeError, match="does not support PIL images"): transform(inpt) else: output = transform(inpt) @@ -1284,30 +950,6 @@ def test__get_params(self): assert min_size <= size < max_size - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomResize( - min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.transforms.v2._geometry.RandomResize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock_resize.assert_called_with( - inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestUniformTemporalSubsample: @pytest.mark.parametrize( diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index f5ea69279a1..bcab4355c54 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1259,68 +1259,6 @@ def check(self, t, t_ref, data_kwargs=None): def test_common(self, t_ref, t, data_kwargs): self.check(t, t_ref, data_kwargs) - def check_resize(self, mocker, t_ref, t): - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - mock_ref = mocker.patch("torchvision.transforms.functional.resize") - - for dp, dp_ref in self.make_datapoints(): - mock.reset_mock() - mock_ref.reset_mock() - - self.set_seed() - t(dp) - assert mock.call_count == 2 - assert all( - actual is expected - for actual, expected in zip([call_args[0][0] for call_args in 
mock.call_args_list], dp) - ) - - self.set_seed() - t_ref(*dp_ref) - assert mock_ref.call_count == 2 - assert all( - actual is expected - for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref) - ) - - for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list): - assert args_kwargs[0][1] == [args_kwargs_ref[0][1]] - - def test_random_resize_train(self, mocker): - base_size = 520 - min_size = base_size // 2 - max_size = base_size * 2 - - randint = torch.randint - - def patched_randint(a, b, *other_args, **kwargs): - if kwargs or len(other_args) > 1 or other_args[0] != (): - return randint(a, b, *other_args, **kwargs) - - return random.randint(a, b) - - # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported - # normally - t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) - mocker.patch( - "torchvision.transforms.v2._geometry.torch.randint", - new=patched_randint, - ) - - t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size) - - self.check_resize(mocker, t_ref, t) - - def test_random_resize_eval(self, mocker): - torch.manual_seed(0) - base_size = 520 - - t = v2_transforms.Resize(size=base_size, antialias=True) - - t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) - - self.check_resize(mocker, t_ref, t) - @pytest.mark.parametrize( ("legacy_dispatcher", "name_only_params"), diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c1a21b6346e..1e78c5ed6c5 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,7 +39,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal +from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal @pytest.fixture(autouse=True) @@ -376,35 +376,6 @@ def transform(bbox): return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape) -@pytest.mark.parametrize( - ("dispatcher", "registered_input_types"), - [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], -) -def test_exhaustive_kernel_registration(dispatcher, registered_input_types): - missing = { - torch.Tensor, - PIL.Image.Image, - datapoints.Image, - datapoints.BoundingBoxes, - datapoints.Mask, - datapoints.Video, - } - registered_input_types - if missing: - names = sorted(str(t) for t in missing) - raise AssertionError( - "\n".join( - [ - f"The dispatcher '{dispatcher.__name__}' has no kernel registered for", - "", - *[f"- {name}" for name in names], - "", - f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).", - f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})", - ] - ) - ) - - class TestResize: INPUT_SIZE = (17, 11) OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)] @@ -2128,9 +2099,20 @@ def test_errors(self): with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): F.register_kernel(F.resize, object) - with pytest.raises(ValueError, match="already has a kernel registered for type"): + with pytest.raises(ValueError, 
match="cannot be registered for the builtin datapoint classes"): F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + class CustomDatapoint(datapoints.Datapoint): + pass + + def resize_custom_datapoint(): + pass + + F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint) + + with pytest.raises(ValueError, match="already has a kernel registered for type"): + F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint) + class TestGetKernel: # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination @@ -2152,13 +2134,7 @@ class MyPILImage(PIL.Image.Image): pass for input_type in [str, int, object, MyTensor, MyPILImage]: - with pytest.raises( - TypeError, - match=( - "supports inputs of type torch.Tensor, PIL.Image.Image, " - "and subclasses of torchvision.datapoints.Datapoint" - ), - ): + with pytest.raises(TypeError, match="supports inputs of type"): _get_kernel(F.resize, input_type) def test_exact_match(self): @@ -2211,8 +2187,8 @@ def test_datapoint_subclass(self): class MyDatapoint(datapoints.Datapoint): pass - # Note that this will be an error in the future - assert _get_kernel(F.resize, MyDatapoint) is _noop + with pytest.raises(TypeError, match="supports inputs of type"): + _get_kernel(F.resize, MyDatapoint) def resize_my_datapoint(): pass diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 1a2802db0ac..fe2e8df47eb 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -101,7 +101,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_crop"]: - inpt = F.crop( + inpt = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], @@ -120,6 +121,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) - inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) + inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 87a43b118ce..9be7a40e8ca 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import PIL.Image import torch @@ -91,6 +91,14 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." 
+ ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) @@ -131,7 +139,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: - inpt = F.erase(inpt, **params, inplace=self.inplace) + inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace) return inpt diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 90e3ce2ff2c..a3792797959 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -1,13 +1,12 @@ import collections.abc from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import PIL.Image import torch -from torchvision import datapoints, transforms as _transforms +from torchvision import transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform from ._transform import _RandomApplyTransform -from .utils import is_simple_tensor, query_chw +from .utils import query_chw class Grayscale(Transform): @@ -24,19 +23,12 @@ class Grayscale(Transform): _v1_transform_cls = _transforms.Grayscale - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, num_output_channels: int = 1): super().__init__() self.num_output_channels = num_output_channels def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) + return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) class RandomGrayscale(_RandomApplyTransform): @@ -55,13 +47,6 @@ class RandomGrayscale(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomGrayscale - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, p: float = 0.1) -> None: super().__init__(p=p) @@ -70,7 +55,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(num_input_channels=num_input_channels) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) + return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) class ColorJitter(Transform): @@ -167,13 +152,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: hue_factor = params["hue_factor"] for fn_id in params["fn_idx"]: if fn_id == 0 and brightness_factor is not None: - output = F.adjust_brightness(output, brightness_factor=brightness_factor) + output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor) elif fn_id == 1 and contrast_factor is not None: - output = F.adjust_contrast(output, contrast_factor=contrast_factor) + output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor) elif fn_id == 2 and saturation_factor is not None: - output = F.adjust_saturation(output, saturation_factor=saturation_factor) + output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor) elif fn_id == 3 and hue_factor is not None: - output = F.adjust_hue(output, hue_factor=hue_factor) + output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor) return output @@ -183,19 +168,12 @@ class RandomChannelPermutation(Transform): .. 
v2betastatus:: RandomChannelPermutation transform """ - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) return dict(permutation=torch.randperm(num_channels)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.permute_channels(inpt, params["permutation"]) + return self._call_kernel(F.permute_channels, inpt, params["permutation"]) class RandomPhotometricDistort(Transform): @@ -224,13 +202,6 @@ class RandomPhotometricDistort(Transform): Default is 0.5. """ - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__( self, brightness: Tuple[float, float] = (0.875, 1.125), @@ -263,17 +234,17 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness_factor"] is not None: - inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"]) + inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) if params["contrast_factor"] is not None and params["contrast_before"]: - inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) + inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) if params["saturation_factor"] is not None: - inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"]) + inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"]) if params["hue_factor"] is not None: - inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"]) + inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"]) if params["contrast_factor"] is not None and not params["contrast_before"]: - inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) + inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: - inpt = F.permute_channels(inpt, permutation=params["channel_permutation"]) + inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"]) return inpt @@ -293,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomEqualize def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.equalize(inpt) + return self._call_kernel(F.equalize, inpt) class RandomInvert(_RandomApplyTransform): @@ -312,7 +283,7 @@ class RandomInvert(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomInvert def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.invert(inpt) + return self._call_kernel(F.invert, inpt) class RandomPosterize(_RandomApplyTransform): @@ -337,7 +308,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None: self.bits = bits def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.posterize(inpt, bits=self.bits) + return self._call_kernel(F.posterize, inpt, bits=self.bits) class RandomSolarize(_RandomApplyTransform): @@ -362,7 +333,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None: self.threshold = threshold def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.solarize(inpt, threshold=self.threshold) + return self._call_kernel(F.solarize, inpt, threshold=self.threshold) class 
RandomAutocontrast(_RandomApplyTransform): @@ -381,7 +352,7 @@ class RandomAutocontrast(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomAutocontrast def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.autocontrast(inpt) + return self._call_kernel(F.autocontrast, inpt) class RandomAdjustSharpness(_RandomApplyTransform): @@ -406,4 +377,4 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: self.sharpness_factor = sharpness_factor def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 5c285056928..b209140614e 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union import PIL.Image import torch @@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.horizontal_flip(inpt) + return self._call_kernel(F.horizontal_flip, inpt) class RandomVerticalFlip(_RandomApplyTransform): @@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.vertical_flip(inpt) + return self._call_kernel(F.vertical_flip, inpt) class Resize(Transform): @@ -152,7 +152,8 @@ def __init__( self.antialias = antialias def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize( + return self._call_kernel( + F.resize, inpt, self.size, interpolation=self.interpolation, @@ -186,7 +187,7 @@ def __init__(self, size: Union[int, Sequence[int]]): self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.center_crop(inpt, output_size=self.size) + return self._call_kernel(F.center_crop, inpt, output_size=self.size) class RandomResizedCrop(Transform): @@ -307,8 +308,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resized_crop( - inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + return self._call_kernel( + F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias ) @@ -357,8 +358,16 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." 
+ ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.five_crop(inpt, self.size) + return self._call_kernel(F.five_crop, inpt, self.size) def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): @@ -396,12 +405,20 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." + ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) + return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip) class Pad(Transform): @@ -475,7 +492,7 @@ def __init__( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] class RandomZoomOut(_RandomApplyTransform): @@ -545,7 +562,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.pad(inpt, **params, fill=fill) + return self._call_kernel(F.pad, inpt, **params, fill=fill) class RandomRotation(Transform): @@ -611,7 +628,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.rotate( + return self._call_kernel( + F.rotate, inpt, **params, interpolation=self.interpolation, @@ -733,7 +751,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.affine( + return self._call_kernel( + F.affine, inpt, **params, interpolation=self.interpolation, @@ -889,10 +908,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) - inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) + inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: - inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + inpt = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) return inpt @@ -973,7 +994,8 @@ def _get_params(self, 
flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.perspective( + return self._call_kernel( + F.perspective, inpt, None, None, @@ -1050,7 +1072,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # if kernel size is even we have to make it odd if kx % 2 == 0: kx += 1 - dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / size[0] dy = torch.rand([1, 1] + size) * 2 - 1 @@ -1059,14 +1081,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # if kernel size is even we have to make it odd if ky % 2 == 0: ky += 1 - dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma)) dy = dy * self.alpha[1] / size[1] displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.elastic( + return self._call_kernel( + F.elastic, inpt, **params, fill=fill, @@ -1164,7 +1187,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # check for any valid boxes with centers within the crop area xyxy_bboxes = F.convert_format_bounding_boxes( - bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY + bboxes.as_subclass(torch.Tensor), + bboxes.format, + datapoints.BoundingBoxFormat.XYXY, ) cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) @@ -1188,7 +1213,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if len(params) < 1: return inpt - output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + output = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) if isinstance(output, datapoints.BoundingBoxes): # We "mark" the invalid boxes as degenreate, and they can be @@ -1262,7 +1289,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_kernel( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) class RandomShortestSize(Transform): @@ -1330,7 +1359,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_kernel( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) class RandomResize(Transform): @@ -1400,4 +1431,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=[size]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_kernel( + F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias + ) diff --git 
a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index da71cebb416..780d9f99446 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -106,7 +106,7 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso

     def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
-            raise TypeError("LinearTransformation does not work on PIL Images")
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         shape = inpt.shape
@@ -157,7 +157,6 @@ class Normalize(Transform):
     """

     _v1_transform_cls = _transforms.Normalize
-    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
         super().__init__()
@@ -170,7 +169,7 @@ def _check_inputs(self, sample: Any) -> Any:
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)


 class GaussianBlur(Transform):
@@ -217,7 +216,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, self.kernel_size, **params)
+        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)


 class ToDtype(Transform):
@@ -290,7 +289,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             )
             return inpt

-        return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
+        return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)


 class ConvertImageDtype(Transform):
@@ -320,14 +319,12 @@ class ConvertImageDtype(Transform):

     _v1_transform_cls = _transforms.ConvertImageDtype

-    _transformed_types = (is_simple_tensor, datapoints.Image)
-
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.to_dtype(inpt, dtype=self.dtype, scale=True)
+        return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBoxes(Transform):
diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py
index 591341e7cc7..df39cde0ecd 100644
--- a/torchvision/transforms/v2/_temporal.py
+++ b/torchvision/transforms/v2/_temporal.py
@@ -25,4 +25,4 @@ def __init__(self, num_samples: int):
         self.num_samples = num_samples

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.uniform_temporal_subsample(inpt, self.num_samples)
+        return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index f83ed5d6e11..5a310ddbd4c 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -11,6 +11,8 @@
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

+from .functional._utils import _get_kernel
+

 class Transform(nn.Module):
@@ -28,6 +30,10 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+        return kernel(inpt, *args, **kwargs)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py
index 1497638f6b3..4a927be9777 100644
--- a/torchvision/transforms/v2/functional/_augment.py
+++ b/torchvision/transforms/v2/functional/_augment.py
@@ -5,10 +5,9 @@
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
     inpt: torch.Tensor,
     i: int,
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index 9ba88d31b94..4c087965f6c 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -10,12 +10,10 @@
 from torchvision.utils import _log_api_usage_once

 from ._misc import _num_value_bits, to_dtype_image_tensor
-
 from ._type_conversion import pil_to_tensor, to_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
@@ -70,8 +68,8 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
     return output if fp else output.to(image1.dtype)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+
     if torch.jit.is_scripting():
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
@@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
     return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
@@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
     return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
@@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
     return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
@@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
     return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
@@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return adjust_hue_image_tensor(video, hue_factor=hue_factor)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
@@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
     return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return posterize_image_tensor(inpt, bits=bits)
@@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image_tensor(video, bits=bits)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return solarize_image_tensor(inpt, threshold=threshold)
@@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image_tensor(video, threshold=threshold)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return autocontrast_image_tensor(inpt)
@@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return equalize_image_tensor(inpt)
@@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return invert_image_tensor(inpt)
@@ -643,7 +631,6 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image_tensor(video)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.
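The net effect of dropping these explicit no-op registrations is that the color dispatchers themselves no longer silently accept datapoints they have no kernel for; the pass-through behaviour moves into Transform._call_kernel, which looks kernels up with _get_kernel(..., allow_passthrough=True) (see the _transform.py and _utils.py hunks in this patch). A minimal sketch of the behaviour implied by this diff — the tensor values and p=1.0 are illustrative only:

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms
    from torchvision.transforms.v2 import functional as F

    mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32), dtype=torch.uint8))

    # Calling the dispatcher directly now raises, since no kernel (and no explicit
    # no-op) is registered for Mask anymore:
    try:
        F.equalize(mask)
    except TypeError:
        pass  # "Dispatcher F.equalize supports inputs of type ..., but got ... instead."

    # The transform routes through Transform._call_kernel(..., allow_passthrough=True)
    # and returns the mask unchanged instead of erroring:
    out = transforms.RandomEqualize(p=1.0)(mask)
    assert isinstance(out, datapoints.Mask) and torch.equal(out, mask)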
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 6416a143c03..f8f3b1da0b3 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,13 +25,7 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import ( - _FillTypeJIT, - _get_kernel, - _register_explicit_noop, - _register_five_ten_crop_kernel, - _register_kernel_internal, -) +from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -2203,7 +2197,6 @@ def resized_crop_video( ) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def five_crop( inpt: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2230,8 +2223,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size -@_register_five_ten_crop_kernel(five_crop, torch.Tensor) -@_register_five_ten_crop_kernel(five_crop, datapoints.Image) +@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2250,7 +2243,7 @@ def five_crop_image_tensor( return tl, tr, bl, br, center -@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image) +@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image) def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: @@ -2269,14 +2262,13 @@ def five_crop_image_pil( return tl, tr, bl, br, center -@_register_five_ten_crop_kernel(five_crop, datapoints.Video) +@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video) def five_crop_video( video: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return five_crop_image_tensor(video, size) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def ten_crop( inpt: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2300,8 +2292,8 @@ def ten_crop( return kernel(inpt, size=size, vertical_flip=vertical_flip) -@_register_five_ten_crop_kernel(ten_crop, torch.Tensor) -@_register_five_ten_crop_kernel(ten_crop, datapoints.Image) +@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2328,7 +2320,7 @@ def ten_crop_image_tensor( return non_flipped + flipped -@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image) +@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image) def ten_crop_image_pil( image: PIL.Image.Image, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2355,7 +2347,7 @@ def ten_crop_image_pil( return non_flipped + flipped -@_register_five_ten_crop_kernel(ten_crop, datapoints.Video) +@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video) def ten_crop_video( video: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git 
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index a7177ab04e9..82891b8cc8b 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -8,10 +8,9 @@

 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
@@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image_tensor(video)


-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
@@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
     return list(bounding_box.canvas_size)


-@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index ec9c194d51d..658b61cedb0 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -11,11 +11,9 @@

 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-@_register_unsupported_type(PIL.Image.Image)
 def normalize(
     inpt: torch.Tensor,
     mean: List[float],
@@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
     return normalize_image_tensor(video, mean, std, inplace=inplace)


-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
@@ -182,7 +179,6 @@ def gaussian_blur_video(
     return gaussian_blur_image_tensor(video, kernel_size, sigma)


-@_register_unsupported_type(PIL.Image.Image)
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
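The `_meta.py` and `_misc.py` dispatchers likewise lose their `_register_unsupported_type` / `_register_explicit_noop` decorations; the kernels themselves are untouched. As a refresher on what the first of these dispatchers computes, here is a minimal sketch of per-channel normalization on a plain tensor (the math only, not torchvision's kernel, which also handles batched and video inputs):

```python
import torch


# normalize: subtract a per-channel mean and divide by a per-channel std.
# Assumes a (C, H, W) float tensor for simplicity.
def normalize_sketch(image: torch.Tensor, mean, std) -> torch.Tensor:
    mean = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
    return (image - mean) / std


img = torch.rand(3, 32, 32)
out = normalize_sketch(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(out.shape)  # torch.Size([3, 32, 32])
```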
diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py
index 78dcfc1ef92..8edd66c6600 100644
--- a/torchvision/transforms/v2/functional/_temporal.py
+++ b/torchvision/transforms/v2/functional/_temporal.py
@@ -1,16 +1,12 @@
-import PIL.Image
 import torch

 from torchvision import datapoints

 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal


-@_register_explicit_noop(
-    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
-)
 def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
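`uniform_temporal_subsample` similarly loses its warn-and-passthrough registration for images, bounding boxes, and masks; the video kernel is unchanged. As a reminder of what the op does, a sketch of uniform temporal subsampling, simplified to a tensor whose leading dimension is time (it mirrors the idea, not the exact kernel, which works along the time dimension of the `(..., T, C, H, W)` layout):

```python
import torch


# Pick num_samples frame indices spread evenly across the clip and gather them.
def uniform_temporal_subsample_sketch(video: torch.Tensor, num_samples: int) -> torch.Tensor:
    t = video.shape[0]
    indices = torch.linspace(0, t - 1, num_samples).round().long()
    return video[indices]


video = torch.rand(32, 3, 8, 8)  # 32 frames
clip = uniform_temporal_subsample_sketch(video, num_samples=8)
print(clip.shape)  # torch.Size([8, 3, 8, 8])
```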
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index ce1c320a745..8c95828ee4d 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -1,5 +1,4 @@
 import functools
-import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import torch
@@ -53,6 +52,11 @@ def _name_to_dispatcher(name):
         ) from None


+_BUILTIN_DATAPOINT_TYPES = {
+    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
+}
+
+
 def register_kernel(dispatcher, datapoint_cls):
     """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.

@@ -70,20 +74,19 @@ def register_kernel(dispatcher, datapoint_cls):
             f"but got {dispatcher}."
         )

-    if not (
-        isinstance(datapoint_cls, type)
-        and issubclass(datapoint_cls, datapoints.Datapoint)
-        and datapoint_cls is not datapoints.Datapoint
-    ):
+    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
         raise ValueError(
             f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
             f"but got {datapoint_cls}."
         )

+    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
+        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
+
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)


-def _get_kernel(dispatcher, input_type):
+def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
     registry = _KERNEL_REGISTRY.get(dispatcher)
     if not registry:
         raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
@@ -104,78 +107,18 @@ def _get_kernel(dispatcher, input_type):
         elif cls in registry:
             return registry[cls]

-    # Note that in the future we are not going to return a noop here, but rather raise the error below
-    return _noop
+    if allow_passthrough:
+        return lambda inpt, *args, **kwargs: inpt

     raise TypeError(
-        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
-        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )


-# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
-# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
-
-
-# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
-# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
-# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
-# register_kernel.
-def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
-    """
-    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
-    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
-
-    For example, without explicit no-op registration the following would be valid user code:
-
-    .. code::
-        from torchvision.transforms.v2 import functional as F
-
-        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
-        def lol(...):
-            ...
-    """
-
-    def decorator(dispatcher):
-        for cls in datapoints_classes:
-            msg = (
-                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
-                f"This will likely change in the future."
-            )
-            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
-                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
-            )
-        return dispatcher
-
-    return decorator
-
-
-def _noop(inpt, *args, __msg__=None, **kwargs):
-    if __msg__:
-        warnings.warn(__msg__, UserWarning, stacklevel=2)
-    return inpt
-
-
-# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
-# to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*input_types):
-    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
-        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
-
-    def decorator(dispatcher):
-        for input_type in input_types:
-            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
-                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
-            )
-        return dispatcher
-
-    return decorator
-
-
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
 # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel(dispatcher, input_type):
+def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
     if input_type in registry:
         raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")