From 4a91171235534a711020aca010fe68074299734d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 28 Oct 2022 15:59:39 +0200 Subject: [PATCH 1/2] fix reference tests for convert_format_bounding_box --- test/prototype_transforms_kernel_infos.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 287b1acaa27..ba9bfe8ee48 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -536,12 +536,12 @@ def sample_inputs_convert_format_bounding_box(): def reference_convert_format_bounding_box(bounding_box, old_format, new_format): return torchvision.ops.box_convert( - bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower() - ) + bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() + ).to(bounding_box.dtype) def reference_inputs_convert_format_bounding_box(): - for args_kwargs in sample_inputs_convert_color_space_image_tensor(): + for args_kwargs in sample_inputs_convert_format_bounding_box(): if len(args_kwargs.args[0].shape) == 2: yield args_kwargs From c41439f3c85b3f6a1392c0bde882920c60c50768 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 28 Oct 2022 16:11:50 +0200 Subject: [PATCH 2/2] add check to prevent empty args_kwargs_fns in the future --- test/test_prototype_transforms_functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index c739598a169..6746c09e0f5 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -38,6 +38,10 @@ def script(fn): def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None): args_kwargs = list(args_kwargs_fn(info)) + if not args_kwargs: + raise pytest.UsageError( + f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}" + ) idx_field_len = len(str(len(args_kwargs))) return [ pytest.param(