From 35b51131f608ae50d0fdcae20cc296fd2e048f0d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 28 Feb 2022 14:18:11 +0100 Subject: [PATCH 1/2] expand prototype functional scriptability tests --- test/test_prototype_transforms_functional.py | 53 ++++++++++++------- .../transforms/functional/_type_conversion.py | 4 +- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 4bfca28ae37..9daa4ccbc3a 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -199,21 +199,38 @@ def resize_bounding_box(): yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) -class TestKernelsCommon: - @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name) - def test_scriptable(self, functional_info): - jit.script(functional_info.functional) - - @pytest.mark.parametrize( - ("functional_info", "sample_input"), - [ - pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") - for functional_info in FUNCTIONAL_INFOS - for idx, sample_input in enumerate(functional_info.sample_inputs()) - ], - ) - def test_eager_vs_scripted(self, functional_info, sample_input): - eager = functional_info(sample_input) - scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) - - torch.testing.assert_close(eager, scripted) +@pytest.mark.parametrize( + "kernel", + [ + pytest.param(kernel, id=name) + for name, kernel in F.__dict__.items() + if not name.startswith("_") + and callable(kernel) + and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) + and "pil" not in name + and ( + name + not in { + "get_image_size", + "get_image_num_channels", + } + ) + ], +) +def test_scriptable(kernel): + jit.script(kernel) + + +@pytest.mark.parametrize( + ("functional_info", "sample_input"), + [ + pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") + for functional_info in FUNCTIONAL_INFOS + for idx, sample_input in enumerate(functional_info.sample_inputs()) + ], +) +def test_eager_vs_scripted(functional_info, sample_input): + eager = functional_info(sample_input) + scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) + + torch.testing.assert_close(eager, scripted) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 09cb61b8a21..06b2daaf6f1 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Dict, Any, Tuple, cast +from typing import Dict, Any, Tuple import numpy as np import PIL.Image @@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: - return cast(torch.Tensor, one_hot(label, num_classes=num_categories)) + return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return] From 347ed617ff4537989860b0f112063475cc963cda Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Mar 2022 13:26:06 +0100 Subject: [PATCH 2/2] remove obsolete skips --- test/test_prototype_transforms_functional.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 9daa4ccbc3a..409a855e23f 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -208,13 +208,6 @@ def resize_bounding_box(): and callable(kernel) and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) and "pil" not in name - and ( - name - not in { - "get_image_size", - "get_image_num_channels", - } - ) ], ) def test_scriptable(kernel):