Skip to content

Commit 35b0b9e

Browse files
authored
improve prototype transforms kernel tests (#6596)
* fix PIL and tensor mask comparison * introduce kernel_name field * add dtype consistency test * port some tests from old framework * add kernel infos for conversion kernels * cleanup * use nearest and bicubic for resize image sample inputs * make parametrization id more obvious * use named sentinel instead of None for random image size
1 parent c0911e3 commit 35b0b9e

File tree

4 files changed

+286
-230
lines changed

4 files changed

+286
-230
lines changed

test/prototype_common_utils.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch.testing._comparison import (
1515
assert_equal as _assert_equal,
1616
BooleanPair,
17+
ErrorMeta,
1718
NonePair,
1819
NumberPair,
1920
TensorLikePair,
@@ -70,6 +71,19 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
7071
actual, expected = [
7172
to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
7273
]
74+
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
75+
# image to a tensor adds a singleton leading dimension.
76+
# Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
77+
# `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
78+
# shape check that will fail if we don't broadcast before.
79+
try:
80+
actual, expected = torch.broadcast_tensors(actual, expected)
81+
except RuntimeError:
82+
raise ErrorMeta(
83+
AssertionError,
84+
f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
85+
id=id,
86+
) from None
7387
return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
7488

7589
def _equalize_attributes(self, actual, expected):
@@ -165,12 +179,12 @@ def load(self, device="cpu"):
165179
DEFAULT_SQUARE_IMAGE_SIZE = 15
166180
DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
167181
DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
168-
DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, None)
182+
DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")
169183

170184

171185
def _parse_image_size(size, *, name="size"):
172-
if size is None:
173-
return tuple(torch.randint(16, 33, (2,)).tolist())
186+
if size == "random":
187+
return tuple(torch.randint(15, 33, (2,)).tolist())
174188
elif isinstance(size, int) and size > 0:
175189
return (size, size)
176190
elif (
@@ -181,8 +195,8 @@ def _parse_image_size(size, *, name="size"):
181195
return tuple(size)
182196
else:
183197
raise pytest.UsageError(
184-
f"'{name}' can either be `None`, a positive integer, or a sequence of two positive integers,"
185-
f"but got {size} instead"
198+
f"'{name}' can either be `'random'`, a positive integer, or a sequence of two positive integers,"
199+
f"but got {size} instead."
186200
)
187201

188202

@@ -228,7 +242,7 @@ def __post_init__(self):
228242

229243

230244
def make_image_loader(
231-
size=None,
245+
size="random",
232246
*,
233247
color_space=features.ColorSpace.RGB,
234248
extra_dims=(),
@@ -298,7 +312,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
298312
).reshape(low.shape)
299313

300314

301-
def make_bounding_box_loader(*, extra_dims=(), format, image_size=None, dtype=torch.float32):
315+
def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
302316
if isinstance(format, str):
303317
format = features.BoundingBoxFormat[format]
304318
if format not in {
@@ -355,7 +369,7 @@ def make_bounding_box_loaders(
355369
*,
356370
extra_dims=DEFAULT_EXTRA_DIMS,
357371
formats=tuple(features.BoundingBoxFormat),
358-
image_size=None,
372+
image_size="random",
359373
dtypes=(torch.float32, torch.int64),
360374
):
361375
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
@@ -440,10 +454,10 @@ class MaskLoader(TensorLoader):
440454
pass
441455

442456

443-
def make_detection_mask_loader(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
457+
def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
444458
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
445459
size = _parse_image_size(size)
446-
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
460+
num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects
447461

448462
def fn(shape, dtype, device):
449463
data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
@@ -457,7 +471,7 @@ def fn(shape, dtype, device):
457471

458472
def make_detection_mask_loaders(
459473
sizes=DEFAULT_IMAGE_SIZES,
460-
num_objects=(1, 0, None),
474+
num_objects=(1, 0, "random"),
461475
extra_dims=DEFAULT_EXTRA_DIMS,
462476
dtypes=(torch.uint8,),
463477
):
@@ -468,10 +482,10 @@ def make_detection_mask_loaders(
468482
make_detection_masks = from_loaders(make_detection_mask_loaders)
469483

470484

471-
def make_segmentation_mask_loader(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
485+
def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
472486
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
473487
size = _parse_image_size(size)
474-
num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
488+
num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories
475489

476490
def fn(shape, dtype, device):
477491
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
@@ -486,7 +500,7 @@ def fn(shape, dtype, device):
486500
def make_segmentation_mask_loaders(
487501
*,
488502
sizes=DEFAULT_IMAGE_SIZES,
489-
num_categories=(1, 2, None),
503+
num_categories=(1, 2, "random"),
490504
extra_dims=DEFAULT_EXTRA_DIMS,
491505
dtypes=(torch.uint8,),
492506
):
@@ -500,8 +514,8 @@ def make_segmentation_mask_loaders(
500514
def make_mask_loaders(
501515
*,
502516
sizes=DEFAULT_IMAGE_SIZES,
503-
num_objects=(1, 0, None),
504-
num_categories=(1, 2, None),
517+
num_objects=(1, 0, "random"),
518+
num_categories=(1, 2, "random"),
505519
extra_dims=DEFAULT_EXTRA_DIMS,
506520
dtypes=(torch.uint8,),
507521
):

0 commit comments

Comments
 (0)