-
Notifications
You must be signed in to change notification settings - Fork 7.2k
add segmentation reference consistency tests #6591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
044d120
7a9eb0c
0863741
318b15c
db9df24
8c3bfb3
641d5af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import enum | ||
| import inspect | ||
| import random | ||
| from importlib.machinery import SourceFileLoader | ||
| from pathlib import Path | ||
|
|
||
|
|
@@ -16,6 +17,7 @@ | |
| make_image, | ||
| make_images, | ||
| make_label, | ||
| make_segmentation_mask, | ||
| ) | ||
| from torchvision import transforms as legacy_transforms | ||
| from torchvision._utils import sequence_to_str | ||
|
|
@@ -852,10 +854,12 @@ def test_aa(self, inpt, interpolation): | |
| assert_equal(expected_output, output) | ||
|
|
||
|
|
||
| # Import reference detection transforms here for consistency checks | ||
| # torchvision/references/detection/transforms.py | ||
| ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py" | ||
| det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module() | ||
| def import_transforms_from_references(reference): | ||
| ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py" | ||
| return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module() | ||
|
|
||
|
|
||
| det_transforms = import_transforms_from_references("detection") | ||
|
|
||
|
|
||
| class TestRefDetTransforms: | ||
|
|
@@ -873,7 +877,7 @@ def make_datapoints(self, with_mask=True): | |
|
|
||
| yield (pil_image, target) | ||
|
|
||
| tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8) | ||
| tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) | ||
| target = { | ||
| "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), | ||
| "labels": make_label(extra_dims=(num_objects,), categories=80), | ||
|
|
@@ -883,7 +887,7 @@ def make_datapoints(self, with_mask=True): | |
|
|
||
| yield (tensor_image, target) | ||
|
|
||
| feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)) | ||
| feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) | ||
| target = { | ||
| "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), | ||
| "labels": make_label(extra_dims=(num_objects,), categories=80), | ||
|
|
@@ -927,3 +931,107 @@ def test_transform(self, t_ref, t, data_kwargs): | |
| expected_output = t_ref(*dp) | ||
|
|
||
| assert_equal(expected_output, output) | ||
|
|
||
|
|
||
| seg_transforms = import_transforms_from_references("segmentation") | ||
|
|
||
|
|
||
| class TestSegDetTransforms: | ||
| def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): | ||
| size = (256, 640) | ||
| num_categories = 21 | ||
|
|
||
| conv_fns = [] | ||
| if supports_pil: | ||
| conv_fns.append(to_image_pil) | ||
| conv_fns.extend([torch.Tensor, lambda x: x]) | ||
|
|
||
| for conv_fn in conv_fns: | ||
| feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype) | ||
| feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) | ||
|
|
||
| dp = (conv_fn(feature_image), feature_mask) | ||
| dp_ref = ( | ||
| to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image), | ||
| to_image_pil(feature_mask), | ||
| ) | ||
|
|
||
| yield dp, dp_ref | ||
|
|
||
| def set_seed(self, seed): | ||
| torch.manual_seed(seed) | ||
| random.seed(seed) | ||
|
|
||
| def check(self, t, t_ref, data_kwargs=None): | ||
| for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()): | ||
|
|
||
| self.set_seed(12) | ||
| output = t(dp) | ||
|
|
||
| self.set_seed(12) | ||
| expected_output = t_ref(*dp_ref) | ||
|
|
||
| assert_equal(output, expected_output) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| ("t_ref", "t", "data_kwargs"), | ||
| [ | ||
| ( | ||
| seg_transforms.RandomHorizontalFlip(flip_prob=1.0), | ||
| prototype_transforms.RandomHorizontalFlip(p=1.0), | ||
| dict(), | ||
| ), | ||
| ( | ||
| seg_transforms.RandomHorizontalFlip(flip_prob=0.0), | ||
| prototype_transforms.RandomHorizontalFlip(p=0.0), | ||
| dict(), | ||
| ), | ||
| # ( | ||
| # seg_transforms.RandomCrop(size=480), | ||
| # prototype_transforms.RandomCrop( | ||
| # size=480, pad_if_needed=True, fill=defaultdict(lambda: 0, {features.Mask: 255}) | ||
| # ), | ||
| # dict(), | ||
| # ), | ||
pmeier marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ( | ||
| seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | ||
| prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | ||
| dict(supports_pil=False, image_dtype=torch.float), | ||
| ), | ||
| ], | ||
| ) | ||
| def test_common(self, t_ref, t, data_kwargs): | ||
| self.check(t, t_ref, data_kwargs) | ||
|
|
||
| def test_random_resize_train(self, mocker): | ||
| base_size = 520 | ||
| min_size = base_size // 2 | ||
| max_size = base_size * 2 | ||
|
|
||
| randint = torch.randint | ||
|
|
||
| def patched_randint(a, b, *other_args, **kwargs): | ||
| if kwargs or len(other_args) > 1 or other_args[0] != (): | ||
| return randint(a, b, *other_args, **kwargs) | ||
|
|
||
| return random.randint(a, b) | ||
|
|
||
| t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) | ||
| mocker.patch( | ||
| "torchvision.prototype.transforms._geometry.torch.randint", | ||
| new=patched_randint, | ||
| ) | ||
|
|
||
| t_ref = det_transforms.RandomResize(min_size=min_size, max_size=max_size) | ||
|
|
||
| self.check(t, t_ref) | ||
|
|
||
| def test_random_resize_eval(self): | ||
|
||
| torch.manual_seed(0) | ||
| base_size = 520 | ||
|
|
||
| t = prototype_transforms.Resize(size=base_size, antialias=True) | ||
|
|
||
| t_ref = det_transforms.RandomResize(min_size=base_size, max_size=base_size) | ||
|
|
||
| self.check(t, t_ref) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this before. Let's use the utilities everywhere.