Skip to content
6 changes: 5 additions & 1 deletion test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"assert_close",
"assert_equal",
"ArgsKwargs",
"VALID_EXTRA_DIMS",
"make_image_loaders",
"make_image",
"make_images",
Expand Down Expand Up @@ -201,7 +202,10 @@ def _parse_image_size(size, *, name="size"):
)


DEFAULT_EXTRA_DIMS = ((), (0,), (4,), (2, 3), (5, 0), (0, 5))
VALID_EXTRA_DIMS = ((), (4,), (2, 3))
DEGENERATE_BATCH_DIMS = ((0,), (5, 0), (0, 5))

DEFAULT_EXTRA_DIMS = (*VALID_EXTRA_DIMS, *DEGENERATE_BATCH_DIMS)


def from_loader(loader_fn):
Expand Down
60 changes: 48 additions & 12 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,40 @@ def sample_inputs(self, *feature_types, filter_metadata=True):
yield args_kwargs


def xfail_python_scalar_arg_jit(name, *, reason=None):
def xfail_jit_python_scalar_arg(name, *, reason=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a rename that moves the jit term to the front of the name to make it clear this is only a JIT issue.

reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
return TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason=reason),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
)


def xfail_integer_size_jit(name="size"):
return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")
def xfail_jit_integer_size(name="size"):
return xfail_jit_python_scalar_arg(name, reason=f"Integer `{name}` is not supported when scripting.")


def xfail_jit_tuple_instead_of_list(name, *, reason=None):
reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
return TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason=reason),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
)


def is_list_of_ints(args_kwargs):
fill = args_kwargs.kwargs.get("fill")
return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)


def xfail_jit_list_of_ints(name, *, reason=None):
reason = reason or f"Passing a list of integers for `{name}` is not supported when scripting"
return TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason=reason),
condition=is_list_of_ints,
)


skip_dispatch_feature = TestMark(
Expand Down Expand Up @@ -123,7 +146,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
],
),
DispatcherInfo(
Expand All @@ -136,7 +159,10 @@ def fill_sequence_needs_broadcast(args_kwargs):
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
xfail_python_scalar_arg_jit("shear"),
xfail_jit_python_scalar_arg("shear"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
Comment on lines +164 to +165
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a few of these. I'm not sure if this is a regression of #6636. Will check and send a follow-up PR since this is one is only test changes.

],
),
DispatcherInfo(
Expand All @@ -156,6 +182,11 @@ def fill_sequence_needs_broadcast(args_kwargs):
features.Mask: F.rotate_mask,
},
pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
],
),
DispatcherInfo(
F.crop,
Expand Down Expand Up @@ -194,7 +225,12 @@ def fill_sequence_needs_broadcast(args_kwargs):
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
)
),
xfail_jit_python_scalar_arg("padding"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why all this xfail appeared in this PR and not before ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we didn't test scalar padding before

padding=[[1], [1, 1], [1, 1, 2, 2]],

Thus, while reducing the number of sample inputs now, the tests are actually more comprehensive.

xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
xfail_jit_list_of_ints("fill"),
],
),
DispatcherInfo(
Expand Down Expand Up @@ -227,7 +263,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
test_marks=[
xfail_integer_size_jit("output_size"),
xfail_jit_integer_size("output_size"),
],
),
DispatcherInfo(
Expand All @@ -237,8 +273,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
test_marks=[
xfail_python_scalar_arg_jit("kernel_size"),
xfail_python_scalar_arg_jit("sigma"),
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
),
DispatcherInfo(
Expand Down Expand Up @@ -335,7 +371,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
skip_dispatch_feature,
],
),
Expand All @@ -345,7 +381,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
features.Image: F.ten_crop_image_tensor,
},
test_marks=[
xfail_integer_size_jit(),
xfail_jit_integer_size(),
skip_dispatch_feature,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
Expand Down
Loading