Skip to content

Commit 74ea933

Browse files
authored
Cleanup prototype transforms tests (#6984)
* minor cleanup of the prototype transforms tests * refactor ImagePair * pretty format enum
1 parent 4df1a85 commit 74ea933

File tree

3 files changed

+43
-63
lines changed

3 files changed

+43
-63
lines changed

test/prototype_common_utils.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import dataclasses
5+
import enum
56
import functools
67
import pathlib
78
from collections import defaultdict
@@ -53,45 +54,31 @@ def __init__(
5354
actual,
5455
expected,
5556
*,
56-
agg_method=None,
57-
allowed_percentage_diff=None,
57+
mae=False,
5858
**other_parameters,
5959
):
6060
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
6161
actual, expected = [to_image_tensor(input) for input in [actual, expected]]
6262

6363
super().__init__(actual, expected, **other_parameters)
64-
self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
65-
self.allowed_percentage_diff = allowed_percentage_diff
64+
self.mae = mae
6665

6766
def compare(self) -> None:
6867
actual, expected = self.actual, self.expected
6968

7069
self._compare_attributes(actual, expected)
71-
7270
actual, expected = self._equalize_attributes(actual, expected)
73-
actual, expected = self._promote_for_comparison(actual, expected)
74-
abs_diff = torch.abs(actual - expected)
7571

76-
if self.allowed_percentage_diff is not None:
77-
percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
78-
if percentage_diff > self.allowed_percentage_diff:
72+
if self.mae:
73+
actual, expected = self._promote_for_comparison(actual, expected)
74+
mae = float(torch.abs(actual - expected).float().mean())
75+
if mae > self.atol:
7976
raise self._make_error_meta(
8077
AssertionError,
81-
f"{percentage_diff:.1%} elements differ, "
82-
f"but only {self.allowed_percentage_diff:.1%} is allowed",
78+
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
8379
)
84-
85-
if self.agg_method is None:
86-
super()._compare_values(actual, expected)
8780
else:
88-
agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
89-
if agg_abs_diff > self.atol:
90-
raise self._make_error_meta(
91-
AssertionError,
92-
f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
93-
f"but only {self.atol} is allowed.",
94-
)
81+
super()._compare_values(actual, expected)
9582

9683

9784
def assert_close(
@@ -142,6 +129,8 @@ def parametrized_error_message(*args, **kwargs):
142129
def to_str(obj):
143130
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
144131
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
132+
elif isinstance(obj, enum.Enum):
133+
return f"{type(obj).__name__}.{obj.name}"
145134
else:
146135
return repr(obj)
147136

@@ -174,11 +163,13 @@ def __iter__(self):
174163
yield self.kwargs
175164

176165
def load(self, device="cpu"):
177-
args = tuple(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args)
178-
kwargs = {
179-
keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg for keyword, arg in self.kwargs.items()
180-
}
181-
return args, kwargs
166+
return ArgsKwargs(
167+
*(arg.load(device) if isinstance(arg, TensorLoader) else arg for arg in self.args),
168+
**{
169+
keyword: arg.load(device) if isinstance(arg, TensorLoader) else arg
170+
for keyword, arg in self.kwargs.items()
171+
},
172+
)
182173

183174

184175
DEFAULT_SQUARE_SPATIAL_SIZE = 15

test/prototype_transforms_kernel_infos.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
# values to be tested. If not specified, `sample_inputs_fn` will be used.
5353
reference_inputs_fn=None,
5454
# If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
55-
# the reference inputs. This is usually used whenever we use a PIL kernel as reference.
55+
# reference inputs. This is usually used whenever we use a PIL kernel as reference.
5656
# Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
5757
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
5858
# dtype.
@@ -73,8 +73,8 @@ def __init__(
7373
self.float32_vs_uint8 = float32_vs_uint8
7474

7575

76-
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, agg_method=None):
77-
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, agg_method=agg_method)
76+
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
77+
return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
7878

7979

8080
def cuda_vs_cpu_pixel_difference(atol=1):
@@ -84,21 +84,21 @@ def cuda_vs_cpu_pixel_difference(atol=1):
8484
}
8585

8686

87-
def pil_reference_pixel_difference(atol=1, agg_method=None):
87+
def pil_reference_pixel_difference(atol=1, mae=False):
8888
return {
8989
(("TestKernels", "test_against_reference"), torch.uint8, "cpu"): _pixel_difference_closeness_kwargs(
90-
atol, agg_method=agg_method
90+
atol, mae=mae
9191
)
9292
}
9393

9494

95-
def float32_vs_uint8_pixel_difference(atol=1, agg_method=None):
95+
def float32_vs_uint8_pixel_difference(atol=1, mae=False):
9696
return {
9797
(
9898
("TestKernels", "test_float32_vs_uint8"),
9999
torch.float32,
100100
"cpu",
101-
): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, agg_method=agg_method)
101+
): _pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
102102
}
103103

104104

@@ -359,9 +359,9 @@ def reference_inputs_resize_bounding_box():
359359
reference_inputs_fn=reference_inputs_resize_image_tensor,
360360
float32_vs_uint8=True,
361361
closeness_kwargs={
362-
**pil_reference_pixel_difference(10, agg_method="mean"),
362+
**pil_reference_pixel_difference(10, mae=True),
363363
**cuda_vs_cpu_pixel_difference(),
364-
**float32_vs_uint8_pixel_difference(1, agg_method="mean"),
364+
**float32_vs_uint8_pixel_difference(1, mae=True),
365365
},
366366
test_marks=[
367367
xfail_jit_python_scalar_arg("size"),
@@ -613,7 +613,7 @@ def sample_inputs_affine_video():
613613
reference_fn=pil_reference_wrapper(F.affine_image_pil),
614614
reference_inputs_fn=reference_inputs_affine_image_tensor,
615615
float32_vs_uint8=True,
616-
closeness_kwargs=pil_reference_pixel_difference(10, agg_method="mean"),
616+
closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
617617
test_marks=[
618618
xfail_jit_python_scalar_arg("shear"),
619619
xfail_jit_tuple_instead_of_list("fill"),
@@ -869,7 +869,7 @@ def sample_inputs_rotate_video():
869869
reference_fn=pil_reference_wrapper(F.rotate_image_pil),
870870
reference_inputs_fn=reference_inputs_rotate_image_tensor,
871871
float32_vs_uint8=True,
872-
closeness_kwargs=pil_reference_pixel_difference(1, agg_method="mean"),
872+
closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
873873
test_marks=[
874874
xfail_jit_tuple_instead_of_list("fill"),
875875
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -1054,8 +1054,8 @@ def sample_inputs_resized_crop_video():
10541054
float32_vs_uint8=True,
10551055
closeness_kwargs={
10561056
**cuda_vs_cpu_pixel_difference(),
1057-
**pil_reference_pixel_difference(3, agg_method="mean"),
1058-
**float32_vs_uint8_pixel_difference(3, agg_method="mean"),
1057+
**pil_reference_pixel_difference(3, mae=True),
1058+
**float32_vs_uint8_pixel_difference(3, mae=True),
10591059
},
10601060
),
10611061
KernelInfo(
@@ -1288,7 +1288,7 @@ def sample_inputs_perspective_video():
12881288
reference_inputs_fn=reference_inputs_perspective_image_tensor,
12891289
float32_vs_uint8=float32_vs_uint8_fill_adapter,
12901290
closeness_kwargs={
1291-
**pil_reference_pixel_difference(2, agg_method="mean"),
1291+
**pil_reference_pixel_difference(2, mae=True),
12921292
**cuda_vs_cpu_pixel_difference(),
12931293
**float32_vs_uint8_pixel_difference(),
12941294
},
@@ -1371,7 +1371,7 @@ def sample_inputs_elastic_video():
13711371
reference_inputs_fn=reference_inputs_elastic_image_tensor,
13721372
float32_vs_uint8=float32_vs_uint8_fill_adapter,
13731373
closeness_kwargs={
1374-
**float32_vs_uint8_pixel_difference(6, agg_method="mean"),
1374+
**float32_vs_uint8_pixel_difference(6, mae=True),
13751375
**cuda_vs_cpu_pixel_difference(),
13761376
},
13771377
),
@@ -2028,7 +2028,7 @@ def sample_inputs_adjust_hue_video():
20282028
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
20292029
float32_vs_uint8=True,
20302030
closeness_kwargs={
2031-
**pil_reference_pixel_difference(2, agg_method="mean"),
2031+
**pil_reference_pixel_difference(2, mae=True),
20322032
**float32_vs_uint8_pixel_difference(),
20332033
},
20342034
),

test/test_prototype_transforms_functional.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
6161
]
6262

6363

64-
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
65-
if condition is None:
66-
67-
def condition(info):
68-
return True
69-
64+
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
7065
def decorator(test_fn):
7166
parts = test_fn.__qualname__.split(".")
7267
if len(parts) == 1:
@@ -81,9 +76,6 @@ def decorator(test_fn):
8176
argnames = ("info", "args_kwargs")
8277
argvalues = []
8378
for info in infos:
84-
if not condition(info):
85-
continue
86-
8779
argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))
8880

8981
return pytest.mark.parametrize(argnames, argvalues)(test_fn)
@@ -110,9 +102,8 @@ class TestKernels:
110102
args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
111103
)
112104
reference_inputs = make_info_args_kwargs_parametrization(
113-
KERNEL_INFOS,
105+
[info for info in KERNEL_INFOS if info.reference_fn is not None],
114106
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
115-
condition=lambda info: info.reference_fn is not None,
116107
)
117108

118109
@ignore_jit_warning_no_profile
@@ -131,7 +122,7 @@ def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
131122
actual,
132123
expected,
133124
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
134-
msg=parametrized_error_message(*other_args, *kwargs),
125+
msg=parametrized_error_message(*other_args, **kwargs),
135126
)
136127

137128
def _unbatch(self, batch, *, data_dims):
@@ -188,7 +179,7 @@ def test_batched_vs_single(self, test_id, info, args_kwargs, device):
188179
actual,
189180
expected,
190181
**info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
191-
msg=parametrized_error_message(*other_args, *kwargs),
182+
msg=parametrized_error_message(*other_args, **kwargs),
192183
)
193184

194185
@sample_inputs
@@ -218,7 +209,7 @@ def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
218209
output_cpu,
219210
check_device=False,
220211
**info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
221-
msg=parametrized_error_message(*other_args, *kwargs),
212+
msg=parametrized_error_message(*other_args, **kwargs),
222213
)
223214

224215
@sample_inputs
@@ -245,7 +236,7 @@ def test_against_reference(self, test_id, info, args_kwargs):
245236
actual,
246237
expected,
247238
**info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
248-
msg=parametrized_error_message(*other_args, *kwargs),
239+
msg=parametrized_error_message(*other_args, **kwargs),
249240
)
250241

251242
@make_info_args_kwargs_parametrization(
@@ -272,7 +263,7 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):
272263
actual,
273264
expected,
274265
**info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
275-
msg=parametrized_error_message(*other_args, *kwargs),
266+
msg=parametrized_error_message(*other_args, **kwargs),
276267
)
277268

278269

@@ -290,9 +281,8 @@ def make_spy(fn, *, module=None, name=None):
290281

291282
class TestDispatchers:
292283
image_sample_inputs = make_info_args_kwargs_parametrization(
293-
DISPATCHER_INFOS,
284+
[info for info in DISPATCHER_INFOS if features.Image in info.kernels],
294285
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
295-
condition=lambda info: features.Image in info.kernels,
296286
)
297287

298288
@ignore_jit_warning_no_profile
@@ -341,9 +331,8 @@ def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
341331
spy.assert_called_once()
342332

343333
@make_info_args_kwargs_parametrization(
344-
DISPATCHER_INFOS,
334+
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
345335
args_kwargs_fn=lambda info: info.sample_inputs(features.Image),
346-
condition=lambda info: info.pil_kernel_info is not None,
347336
)
348337
def test_dispatch_pil(self, info, args_kwargs, spy_on):
349338
(image_feature, *other_args), kwargs = args_kwargs.load()

0 commit comments

Comments
 (0)