Skip to content

Commit b6574c9

Browse files
authored
port tests for transforms.ColorJitter (#7968)
1 parent 5fa8050 commit b6574c9

File tree

2 files changed

+64
-60
lines changed

2 files changed

+64
-60
lines changed

test/test_transforms_v2_consistency.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -228,23 +228,6 @@ def __init__(
228228
# Use default tolerances of `torch.testing.assert_close`
229229
closeness_kwargs=dict(rtol=None, atol=None),
230230
),
231-
ConsistencyConfig(
232-
v2_transforms.ColorJitter,
233-
legacy_transforms.ColorJitter,
234-
[
235-
ArgsKwargs(),
236-
ArgsKwargs(brightness=0.1),
237-
ArgsKwargs(brightness=(0.2, 0.3)),
238-
ArgsKwargs(contrast=0.4),
239-
ArgsKwargs(contrast=(0.5, 0.6)),
240-
ArgsKwargs(saturation=0.7),
241-
ArgsKwargs(saturation=(0.8, 0.9)),
242-
ArgsKwargs(hue=0.3),
243-
ArgsKwargs(hue=(-0.1, 0.2)),
244-
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
245-
],
246-
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
247-
),
248231
ConsistencyConfig(
249232
v2_transforms.PILToTensor,
250233
legacy_transforms.PILToTensor,
@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
453436
)
454437

455438

456-
get_params_parametrization = pytest.mark.parametrize(
457-
("config", "get_params_args_kwargs"),
458-
[
459-
pytest.param(
460-
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
461-
get_params_args_kwargs,
462-
id=transform_cls.__name__,
463-
)
464-
for transform_cls, get_params_args_kwargs in [
465-
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
466-
(v2_transforms.AutoAugment, ArgsKwargs(5)),
467-
]
468-
],
469-
)
470-
471-
472-
@get_params_parametrization
473-
def test_get_params_alias(config, get_params_args_kwargs):
474-
assert config.prototype_cls.get_params is config.legacy_cls.get_params
475-
476-
if not config.args_kwargs:
477-
return
478-
args, kwargs = config.args_kwargs[0]
479-
legacy_transform = config.legacy_cls(*args, **kwargs)
480-
prototype_transform = config.prototype_cls(*args, **kwargs)
481-
482-
assert prototype_transform.get_params is legacy_transform.get_params
483-
484-
485-
@get_params_parametrization
486-
def test_get_params_jit(config, get_params_args_kwargs):
487-
get_params_args, get_params_kwargs = get_params_args_kwargs
488-
489-
torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)
490-
491-
if not config.args_kwargs:
492-
return
493-
args, kwargs = config.args_kwargs[0]
494-
transform = config.prototype_cls(*args, **kwargs)
495-
496-
torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
497-
498-
499439
@pytest.mark.parametrize(
500440
("config", "args_kwargs"),
501441
[

test/test_transforms_v2_refactored.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3881,3 +3881,67 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo
38813881
)
38823882

38833883
assert_close(actual, expected, rtol=0, atol=1)
3884+
3885+
3886+
class TestColorJitter:
3887+
@pytest.mark.parametrize(
3888+
"make_input",
3889+
[make_image_tensor, make_image_pil, make_image, make_video],
3890+
)
3891+
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
3892+
@pytest.mark.parametrize("device", cpu_and_cuda())
3893+
def test_transform(self, make_input, dtype, device):
3894+
if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
3895+
pytest.skip(
3896+
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
3897+
"will degenerate to that anyway."
3898+
)
3899+
3900+
check_transform(
3901+
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25),
3902+
make_input(dtype=dtype, device=device),
3903+
)
3904+
3905+
def test_transform_noop(self):
3906+
input = make_image()
3907+
input_version = input._version
3908+
3909+
transform = transforms.ColorJitter()
3910+
output = transform(input)
3911+
3912+
assert output is input
3913+
assert output.data_ptr() == input.data_ptr()
3914+
assert output._version == input_version
3915+
3916+
def test_transform_error(self):
3917+
with pytest.raises(ValueError, match="must be non negative"):
3918+
transforms.ColorJitter(brightness=-1)
3919+
3920+
for brightness in [object(), [1, 2, 3]]:
3921+
with pytest.raises(TypeError, match="single number or a sequence with length 2"):
3922+
transforms.ColorJitter(brightness=brightness)
3923+
3924+
with pytest.raises(ValueError, match="values should be between"):
3925+
transforms.ColorJitter(brightness=(-1, 0.5))
3926+
3927+
with pytest.raises(ValueError, match="values should be between"):
3928+
transforms.ColorJitter(hue=1)
3929+
3930+
@pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)])
3931+
@pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)])
3932+
@pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)])
3933+
@pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)])
3934+
def test_transform_correctness(self, brightness, contrast, saturation, hue):
3935+
image = make_image(dtype=torch.uint8, device="cpu")
3936+
3937+
transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
3938+
3939+
with freeze_rng_state():
3940+
torch.manual_seed(0)
3941+
actual = transform(image)
3942+
3943+
torch.manual_seed(0)
3944+
expected = F.to_image(transform(F.to_pil_image(image)))
3945+
3946+
mae = (actual.float() - expected.float()).abs().mean()
3947+
assert mae < 2

0 commit comments

Comments
 (0)