@@ -52,7 +52,7 @@ def __init__(
5252 # values to be tested. If not specified, `sample_inputs_fn` will be used.
5353 reference_inputs_fn = None ,
5454 # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
55- # the reference inputs. This is usually used whenever we use a PIL kernel as reference.
55+ # reference inputs. This is usually used whenever we use a PIL kernel as reference.
5656 # Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
5757 # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
5858 # dtype.
@@ -73,8 +73,8 @@ def __init__(
7373 self .float32_vs_uint8 = float32_vs_uint8
7474
7575
76- def _pixel_difference_closeness_kwargs (uint8_atol , * , dtype = torch .uint8 , agg_method = None ):
77- return dict (atol = uint8_atol / 255 * get_max_value (dtype ), rtol = 0 , agg_method = agg_method )
76+ def _pixel_difference_closeness_kwargs (uint8_atol , * , dtype = torch .uint8 , mae = False ):
77+ return dict (atol = uint8_atol / 255 * get_max_value (dtype ), rtol = 0 , mae = mae )
7878
7979
8080def cuda_vs_cpu_pixel_difference (atol = 1 ):
@@ -84,21 +84,21 @@ def cuda_vs_cpu_pixel_difference(atol=1):
8484 }
8585
8686
87- def pil_reference_pixel_difference (atol = 1 , agg_method = None ):
87+ def pil_reference_pixel_difference (atol = 1 , mae = False ):
8888 return {
8989 (("TestKernels" , "test_against_reference" ), torch .uint8 , "cpu" ): _pixel_difference_closeness_kwargs (
90- atol , agg_method = agg_method
90+ atol , mae = mae
9191 )
9292 }
9393
9494
95- def float32_vs_uint8_pixel_difference (atol = 1 , agg_method = None ):
95+ def float32_vs_uint8_pixel_difference (atol = 1 , mae = False ):
9696 return {
9797 (
9898 ("TestKernels" , "test_float32_vs_uint8" ),
9999 torch .float32 ,
100100 "cpu" ,
101- ): _pixel_difference_closeness_kwargs (atol , dtype = torch .float32 , agg_method = agg_method )
101+ ): _pixel_difference_closeness_kwargs (atol , dtype = torch .float32 , mae = mae )
102102 }
103103
104104
@@ -359,9 +359,9 @@ def reference_inputs_resize_bounding_box():
359359 reference_inputs_fn = reference_inputs_resize_image_tensor ,
360360 float32_vs_uint8 = True ,
361361 closeness_kwargs = {
362- ** pil_reference_pixel_difference (10 , agg_method = "mean" ),
362+ ** pil_reference_pixel_difference (10 , mae = True ),
363363 ** cuda_vs_cpu_pixel_difference (),
364- ** float32_vs_uint8_pixel_difference (1 , agg_method = "mean" ),
364+ ** float32_vs_uint8_pixel_difference (1 , mae = True ),
365365 },
366366 test_marks = [
367367 xfail_jit_python_scalar_arg ("size" ),
@@ -613,7 +613,7 @@ def sample_inputs_affine_video():
613613 reference_fn = pil_reference_wrapper (F .affine_image_pil ),
614614 reference_inputs_fn = reference_inputs_affine_image_tensor ,
615615 float32_vs_uint8 = True ,
616- closeness_kwargs = pil_reference_pixel_difference (10 , agg_method = "mean" ),
616+ closeness_kwargs = pil_reference_pixel_difference (10 , mae = True ),
617617 test_marks = [
618618 xfail_jit_python_scalar_arg ("shear" ),
619619 xfail_jit_tuple_instead_of_list ("fill" ),
@@ -869,7 +869,7 @@ def sample_inputs_rotate_video():
869869 reference_fn = pil_reference_wrapper (F .rotate_image_pil ),
870870 reference_inputs_fn = reference_inputs_rotate_image_tensor ,
871871 float32_vs_uint8 = True ,
872- closeness_kwargs = pil_reference_pixel_difference (1 , agg_method = "mean" ),
872+ closeness_kwargs = pil_reference_pixel_difference (1 , mae = True ),
873873 test_marks = [
874874 xfail_jit_tuple_instead_of_list ("fill" ),
875875 # TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -1054,8 +1054,8 @@ def sample_inputs_resized_crop_video():
10541054 float32_vs_uint8 = True ,
10551055 closeness_kwargs = {
10561056 ** cuda_vs_cpu_pixel_difference (),
1057- ** pil_reference_pixel_difference (3 , agg_method = "mean" ),
1058- ** float32_vs_uint8_pixel_difference (3 , agg_method = "mean" ),
1057+ ** pil_reference_pixel_difference (3 , mae = True ),
1058+ ** float32_vs_uint8_pixel_difference (3 , mae = True ),
10591059 },
10601060 ),
10611061 KernelInfo (
@@ -1288,7 +1288,7 @@ def sample_inputs_perspective_video():
12881288 reference_inputs_fn = reference_inputs_perspective_image_tensor ,
12891289 float32_vs_uint8 = float32_vs_uint8_fill_adapter ,
12901290 closeness_kwargs = {
1291- ** pil_reference_pixel_difference (2 , agg_method = "mean" ),
1291+ ** pil_reference_pixel_difference (2 , mae = True ),
12921292 ** cuda_vs_cpu_pixel_difference (),
12931293 ** float32_vs_uint8_pixel_difference (),
12941294 },
@@ -1371,7 +1371,7 @@ def sample_inputs_elastic_video():
13711371 reference_inputs_fn = reference_inputs_elastic_image_tensor ,
13721372 float32_vs_uint8 = float32_vs_uint8_fill_adapter ,
13731373 closeness_kwargs = {
1374- ** float32_vs_uint8_pixel_difference (6 , agg_method = "mean" ),
1374+ ** float32_vs_uint8_pixel_difference (6 , mae = True ),
13751375 ** cuda_vs_cpu_pixel_difference (),
13761376 },
13771377 ),
@@ -2028,7 +2028,7 @@ def sample_inputs_adjust_hue_video():
20282028 reference_inputs_fn = reference_inputs_adjust_hue_image_tensor ,
20292029 float32_vs_uint8 = True ,
20302030 closeness_kwargs = {
2031- ** pil_reference_pixel_difference (2 , agg_method = "mean" ),
2031+ ** pil_reference_pixel_difference (2 , mae = True ),
20322032 ** float32_vs_uint8_pixel_difference (),
20332033 },
20342034 ),
0 commit comments