diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 2bb98002e12..fac2eb0bd94 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -1,12 +1,22 @@
 import enum
 import inspect
+from importlib.machinery import SourceFileLoader
+from pathlib import Path
 
 import numpy as np
 import PIL.Image
 import pytest
 
 import torch
-from prototype_common_utils import ArgsKwargs, assert_equal, make_images
+from prototype_common_utils import (
+    ArgsKwargs,
+    assert_equal,
+    make_bounding_box,
+    make_detection_mask,
+    make_image,
+    make_images,
+    make_label,
+)
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
@@ -840,3 +850,84 @@ def test_aa(self, inpt, interpolation):
         output = t(inpt)
 
         assert_equal(expected_output, output)
+
+
+# Import reference detection transforms here for consistency checks
+# torchvision/references/detection/transforms.py
+ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
+det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+class TestRefDetTransforms:
+    def make_datapoints(self, with_mask=True):
+        size = (600, 800)
+        num_objects = 22
+
+        # Yield the same sample in three flavors: PIL image, plain uint8 tensor, and features.Image
+        pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (pil_image, target)
+
+        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (tensor_image, target)
+
+        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (feature_image, target)
+
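+    # Each case pairs a reference transform from references/detection/transforms.py
+    # with its prototype counterpart; both run under the same manual seed below so
+    # their random parameters, and hence their outputs, can be compared directly.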
+    @pytest.mark.parametrize(
+        "t_ref, t, data_kwargs",
+        [
+            (det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
+            (det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
+            (det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
+            (det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
+            (
+                det_transforms.FixedSizeCrop((1024, 1024), fill=0),
+                prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
+                {},
+            ),
+            (
+                det_transforms.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                ),
+                prototype_transforms.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                ),
+                {},
+            ),
+        ],
+    )
+    def test_transform(self, t_ref, t, data_kwargs):
+        for dp in self.make_datapoints(**data_kwargs):
+
+            # Run the prototype transform first, since the reference transform updates the target in place
+            torch.manual_seed(12)
+            output = t(dp)
+
+            torch.manual_seed(12)
+            expected_output = t_ref(*dp)
+
+            assert_equal(expected_output, output)