From 044d1200df040dd33543a0f44b82dcb00d695017 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 15 Sep 2022 16:42:08 +0200
Subject: [PATCH 1/3] add segmentation reference consistency tests

---
 test/prototype_common_utils.py                |   3 +-
 test/test_prototype_transforms_consistency.py | 120 +++++++++++++++++-
 2 files changed, 116 insertions(+), 7 deletions(-)

diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py
index 297b103248f..297fdb5a179 100644
--- a/test/prototype_common_utils.py
+++ b/test/prototype_common_utils.py
@@ -68,7 +68,8 @@ def __init__(
 
     def _process_inputs(self, actual, expected, *, id, allow_subclasses):
         actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
+            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
+            for input in [actual, expected]
         ]
         return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
 
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index fac2eb0bd94..b1f65bf6476 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -1,5 +1,6 @@
 import enum
 import inspect
+import random
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
 
@@ -16,6 +17,7 @@
     make_image,
     make_images,
     make_label,
+    make_segmentation_mask,
 )
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
@@ -852,10 +854,12 @@ def test_aa(self, inpt, interpolation):
         assert_equal(expected_output, output)
 
 
-# Import reference detection transforms here for consistency checks
-# torchvision/references/detection/transforms.py
-ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
-det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+def import_transforms_from_references(reference):
+    ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
+    return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+det_transforms = import_transforms_from_references("detection")
 
 
 class TestRefDetTransforms:
@@ -873,7 +877,7 @@ def make_datapoints(self, with_mask=True):
 
         yield (pil_image, target)
 
-        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +887,7 @@ def make_datapoints(self, with_mask=True):
 
         yield (tensor_image, target)
 
-        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
         target = {
             "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
             "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +931,107 @@ def test_transform(self, t_ref, t, data_kwargs):
         expected_output = t_ref(*dp)
 
         assert_equal(expected_output, output)
+
+
+seg_transforms = import_transforms_from_references("segmentation")
+
+
+class TestSegDetTransforms:
+    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
+        size = (256, 640)
+        num_categories = 21
+
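+        # yield each sample once per supported input type: PIL image (if supported), plain tensor, features.Image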
+        conv_fns = []
+        if supports_pil:
+            conv_fns.append(to_image_pil)
+        conv_fns.extend([torch.Tensor, lambda x: x])
+
+        for conv_fn in conv_fns:
+            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
+            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+
+            dp = (conv_fn(feature_image), feature_mask)
+            dp_ref = (
+                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
+                to_image_pil(feature_mask),
+            )
+
+            yield dp, dp_ref
+
+    def set_seed(self, seed):
+        torch.manual_seed(seed)
+        random.seed(seed)
+
+    def check(self, t, t_ref, data_kwargs=None):
+        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
+
+            self.set_seed(12)
+            output = t(dp)
+
+            self.set_seed(12)
+            expected_output = t_ref(*dp_ref)
+
+            assert_equal(output, expected_output)
+
+    @pytest.mark.parametrize(
+        ("t_ref", "t", "data_kwargs"),
+        [
+            (
+                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
+                prototype_transforms.RandomHorizontalFlip(p=1.0),
+                dict(),
+            ),
+            (
+                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
+                prototype_transforms.RandomHorizontalFlip(p=0.0),
+                dict(),
+            ),
+            # (
+            #     seg_transforms.RandomCrop(size=480),
+            #     prototype_transforms.RandomCrop(
+            #         size=480, pad_if_needed=True, fill=defaultdict(lambda: 0, {features.Mask: 255})
+            #     ),
+            #     dict(),
+            # ),
+            (
+                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+                dict(supports_pil=False, image_dtype=torch.float),
+            ),
+        ],
+    )
+    def test_common(self, t_ref, t, data_kwargs):
+        self.check(t, t_ref, data_kwargs)
+
+    def test_random_resize_train(self, mocker):
+        base_size = 520
+        min_size = base_size // 2
+        max_size = base_size * 2
+
+        randint = torch.randint
+
+        def patched_randint(a, b, *other_args, **kwargs):
+            if kwargs or len(other_args) > 1 or other_args[0] != ():
+                return randint(a, b, *other_args, **kwargs)
+
+            return random.randint(a, b)
+
+        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
+        mocker.patch(
+            "torchvision.prototype.transforms._geometry.torch.randint",
+            new=patched_randint,
+        )
+
+        t_ref = det_transforms.RandomResize(min_size=min_size, max_size=max_size)
+
+        self.check(t, t_ref)
+
+    def test_random_resize_eval(self):
+        torch.manual_seed(0)
+        base_size = 520
+
+        t = prototype_transforms.Resize(size=base_size, antialias=True)
+
+        t_ref = det_transforms.RandomResize(min_size=base_size, max_size=base_size)
+
+        self.check(t, t_ref)

From 7a9eb0cbfb81e54f34bba9fb08d36b93d748cfa1 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 19 Sep 2022 10:27:29 +0200
Subject: [PATCH 2/3] fall back to smoke tests for resize

---
 test/test_prototype_transforms_consistency.py | 47 +++++++++++++++----
 1 file changed, 38 insertions(+), 9 deletions(-)

diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index b1f65bf6476..75fc571edad 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -936,7 +936,7 @@ def test_transform(self, t_ref, t, data_kwargs):
 seg_transforms = import_transforms_from_references("segmentation")
 
 
-class TestSegDetTransforms:
+class TestRefSegTransforms:
     def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
         size = (256, 640)
         num_categories = 21
@@ -958,17 +958,17 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
 
             yield dp, dp_ref
 
-    def set_seed(self, seed):
+    def set_seed(self, seed=12):
         torch.manual_seed(seed)
         random.seed(seed)
 
     def check(self, t, t_ref, data_kwargs=None):
         for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
 
-            self.set_seed(12)
+            self.set_seed()
             output = t(dp)
 
-            self.set_seed(12)
+            self.set_seed()
             expected_output = t_ref(*dp_ref)
 
             assert_equal(output, expected_output)
@@ -1003,6 +1003,33 @@ def check(self, t, t_ref, data_kwargs=None):
     def test_common(self, t_ref, t, data_kwargs):
         self.check(t, t_ref, data_kwargs)
 
+    def check_resize(self, mocker, t_ref, t):
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        mock_ref = mocker.patch("torchvision.transforms.functional.resize")
+
+        for dp, dp_ref in self.make_datapoints():
+            mock.reset_mock()
+            mock_ref.reset_mock()
+
+            self.set_seed()
+            t(dp)
+            assert mock.call_count == 2
+            assert all(
+                actual is expected
+                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
+            )
+
+            self.set_seed()
+            t_ref(*dp_ref)
+            assert mock_ref.call_count == 2
+            assert all(
+                actual is expected
+                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
+            )
+
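+            # resize itself is mocked out, so this is only a smoke test: both pipelines must request the
+            # same target size; the prototype transform passes it as a list, the reference as a plain int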
+            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
+                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
+
     def test_random_resize_train(self, mocker):
         base_size = 520
         min_size = base_size // 2
@@ -1016,22 +1043,24 @@ def patched_randint(a, b, *other_args, **kwargs):
 
             return random.randint(a, b)
 
+        # We patch torch.randint -> random.randint here, because we cannot patch the reference transforms
+        # module directly, since it is loaded through SourceFileLoader rather than imported normally.
         t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
         mocker.patch(
             "torchvision.prototype.transforms._geometry.torch.randint",
             new=patched_randint,
         )
 
-        t_ref = det_transforms.RandomResize(min_size=min_size, max_size=max_size)
+        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
 
-        self.check(t, t_ref)
+        self.check_resize(mocker, t_ref, t)
 
-    def test_random_resize_eval(self):
+    def test_random_resize_eval(self, mocker):
         torch.manual_seed(0)
         base_size = 520
 
         t = prototype_transforms.Resize(size=base_size, antialias=True)
 
-        t_ref = det_transforms.RandomResize(min_size=base_size, max_size=base_size)
+        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
 
-        self.check(t, t_ref)
+        self.check_resize(mocker, t_ref, t)

From 318b15c9bb24888aeba690ac65eeec465b17c03c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 20 Sep 2022 14:11:52 +0200
Subject: [PATCH 3/3] add test for RandomCrop

---
 test/test_prototype_transforms_consistency.py | 47 +++++++++++++++----
 1 file changed, 39 insertions(+), 8 deletions(-)

diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 75fc571edad..9e2e3051189 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -1,6 +1,7 @@
 import enum
 import inspect
 import random
+from collections import defaultdict
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
 
@@ -22,9 +23,10 @@
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms._utils import query_chw
 from torchvision.prototype.transforms.functional import to_image_pil
-
 
 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
 
 
@@ -936,6 +938,32 @@ def test_transform(self, t_ref, t, data_kwargs):
 seg_transforms = import_transforms_from_references("segmentation")
 
 
+# We need this transform for two reasons:
+# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its
+#    namesake in the segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
+# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
+class PadIfSmaller(prototype_transforms.Transform):
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = size
+        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)
+
+    def _get_params(self, sample):
+        _, height, width = query_chw(sample)
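+        # pad only on the right and bottom, matching the padding scheme of the segmentation references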
+        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
+        needs_padding = any(padding)
+        return dict(padding=padding, needs_padding=needs_padding)
+
+    def _transform(self, inpt, params):
+        if not params["needs_padding"]:
+            return inpt
+
+        fill = self.fill[type(inpt)]
+        fill = F._geometry._convert_fill_arg(fill)
+
+        return F.pad(inpt, padding=params["padding"], fill=fill)
+
+
 class TestRefSegTransforms:
     def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
         size = (256, 640)
@@ -986,13 +1014,16 @@ def check(self, t, t_ref, data_kwargs=None):
                 prototype_transforms.RandomHorizontalFlip(p=0.0),
                 dict(),
             ),
-            # (
-            #     seg_transforms.RandomCrop(size=480),
-            #     prototype_transforms.RandomCrop(
-            #         size=480, pad_if_needed=True, fill=defaultdict(lambda: 0, {features.Mask: 255})
-            #     ),
-            #     dict(),
-            # ),
+            (
+                seg_transforms.RandomCrop(size=480),
+                prototype_transforms.Compose(
+                    [
+                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
+                        prototype_transforms.RandomCrop(size=480),
+                    ]
+                ),
+                dict(),
+            ),
             (
                 seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                 prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),