diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 15ec6e9b39e..19966e22bbf 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1644,3 +1644,142 @@ def test__transform(self):
         assert isinstance(ohe_labels, features.OneHotLabel)
         assert ohe_labels.shape == (4, 3)
         assert ohe_labels.categories == labels.categories == categories
+
+
+class TestAPIConsistency:
+    @pytest.mark.parametrize("antialias", [True, False])
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_random_resized_crop(self, antialias, inpt):
+        from torchvision.transforms import transforms as ref_transforms
+
+        size = 224
+        t_ref = ref_transforms.RandomResizedCrop(size, antialias=antialias)
+        t = transforms.RandomResizedCrop(size, antialias=antialias)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_randaug(self, inpt):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        interpolation = InterpolationMode.BILINEAR
+        t_ref = ref_transforms.RandAugment(interpolation=interpolation)
+        t = transforms.RandAugment(interpolation=interpolation)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_trivial_aug(self, inpt):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        interpolation = InterpolationMode.BILINEAR
+        t_ref = ref_transforms.TrivialAugmentWide(interpolation=interpolation)
+        t = transforms.TrivialAugmentWide(interpolation=interpolation)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_augmix(self, inpt):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        interpolation = InterpolationMode.BILINEAR
+        t_ref = ref_transforms.AugMix(interpolation=interpolation)
+        t = transforms.AugMix(interpolation=interpolation)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)
+
+    @pytest.mark.parametrize(
+        "inpt",
+        [
+            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
+            PIL.Image.new("RGB", (256, 256), 123),
+            features.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+        ],
+    )
+    def test_aa(self, inpt):
+        from torchvision.transforms import autoaugment as ref_transforms
+
+        interpolation = InterpolationMode.BILINEAR
+        aa_policy = ref_transforms.AutoAugmentPolicy("imagenet")
+        t_ref = ref_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+        t = transforms.AutoAugment(aa_policy, interpolation=interpolation)
+
+        torch.manual_seed(12)
+        expected_output = t_ref(inpt)
+
+        torch.manual_seed(12)
+        output = t(inpt)
+
+        if isinstance(inpt, PIL.Image.Image):
+            expected_output = pil_to_tensor(expected_output)
+            output = pil_to_tensor(output)
+
+        torch.testing.assert_close(expected_output, output)