Skip to content

Commit 7bd5976

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] port tests for rgb_to_grayscale functional and transforms (#7967)
Reviewed By: matteobettini Differential Revision: D49600771 fbshipit-source-id: a4d6ed523f2ed0919fa2f73884429e89ecb8b27d
1 parent e195bf5 commit 7bd5976

File tree

3 files changed

+55
-24
lines changed

3 files changed

+55
-24
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,9 @@ class TestSmoke:
116116
(transforms.RandAugment(), auto_augment_adapter),
117117
(transforms.TrivialAugmentWide(), auto_augment_adapter),
118118
(transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
119-
(transforms.Grayscale(), None),
120119
(transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
121120
(transforms.RandomAutocontrast(p=1.0), None),
122121
(transforms.RandomEqualize(p=1.0), None),
123-
(transforms.RandomGrayscale(p=1.0), None),
124122
(transforms.RandomInvert(p=1.0), None),
125123
(transforms.RandomChannelPermutation(), None),
126124
(transforms.RandomPhotometricDistort(p=1.0), None),

test/test_transforms_v2_consistency.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,6 @@ def __init__(
122122
(torch.float32, torch.float64),
123123
]
124124
],
125-
ConsistencyConfig(
126-
v2_transforms.Grayscale,
127-
legacy_transforms.Grayscale,
128-
[
129-
ArgsKwargs(num_output_channels=1),
130-
ArgsKwargs(num_output_channels=3),
131-
],
132-
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
133-
# Use default tolerances of `torch.testing.assert_close`
134-
closeness_kwargs=dict(rtol=None, atol=None),
135-
),
136125
ConsistencyConfig(
137126
v2_transforms.ToPILImage,
138127
legacy_transforms.ToPILImage,
@@ -217,17 +206,6 @@ def __init__(
217206
],
218207
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
219208
),
220-
ConsistencyConfig(
221-
v2_transforms.RandomGrayscale,
222-
legacy_transforms.RandomGrayscale,
223-
[
224-
ArgsKwargs(p=0),
225-
ArgsKwargs(p=1),
226-
],
227-
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
228-
# Use default tolerances of `torch.testing.assert_close`
229-
closeness_kwargs=dict(rtol=None, atol=None),
230-
),
231209
ConsistencyConfig(
232210
v2_transforms.PILToTensor,
233211
legacy_transforms.PILToTensor,

test/test_transforms_v2_refactored.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3945,3 +3945,58 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue):
39453945

39463946
mae = (actual.float() - expected.float()).abs().mean()
39473947
assert mae < 2
3948+
3949+
3950+
class TestRgbToGrayscale:
3951+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
3952+
@pytest.mark.parametrize("device", cpu_and_cuda())
3953+
def test_kernel_image(self, dtype, device):
3954+
check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
3955+
3956+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
3957+
def test_functional(self, make_input):
3958+
check_functional(F.rgb_to_grayscale, make_input())
3959+
3960+
@pytest.mark.parametrize(
3961+
("kernel", "input_type"),
3962+
[
3963+
(F.rgb_to_grayscale_image, torch.Tensor),
3964+
(F._rgb_to_grayscale_image_pil, PIL.Image.Image),
3965+
(F.rgb_to_grayscale_image, tv_tensors.Image),
3966+
],
3967+
)
3968+
def test_functional_signature(self, kernel, input_type):
3969+
check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
3970+
3971+
@pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
3972+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
3973+
def test_transform(self, transform, make_input):
3974+
check_transform(transform, make_input())
3975+
3976+
@pytest.mark.parametrize("num_output_channels", [1, 3])
3977+
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
3978+
def test_image_correctness(self, num_output_channels, fn):
3979+
image = make_image(dtype=torch.uint8, device="cpu")
3980+
3981+
actual = fn(image, num_output_channels=num_output_channels)
3982+
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
3983+
3984+
assert_equal(actual, expected, rtol=0, atol=1)
3985+
3986+
@pytest.mark.parametrize("num_input_channels", [1, 3])
3987+
def test_random_transform_correctness(self, num_input_channels):
3988+
image = make_image(
3989+
color_space={
3990+
1: "GRAY",
3991+
3: "RGB",
3992+
}[num_input_channels],
3993+
dtype=torch.uint8,
3994+
device="cpu",
3995+
)
3996+
3997+
transform = transforms.RandomGrayscale(p=1)
3998+
3999+
actual = transform(image)
4000+
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_input_channels))
4001+
4002+
assert_equal(actual, expected, rtol=0, atol=1)

0 commit comments

Comments
 (0)