 from torch.testing._comparison import (
     assert_equal as _assert_equal,
     BooleanPair,
+    ErrorMeta,
     NonePair,
     NumberPair,
     TensorLikePair,
@@ -70,6 +71,19 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
         actual, expected = [
             to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
         ]
+        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
+        # image to a tensor adds a singleton leading dimension.
+        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
+        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
+        # shape check that will fail if we don't broadcast before.
+        try:
+            actual, expected = torch.broadcast_tensors(actual, expected)
+        except RuntimeError:
+            raise ErrorMeta(
+                AssertionError,
+                f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
+                id=id,
+            ) from None
         return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)

     def _equalize_attributes(self, actual, expected):
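For context on the broadcast added above: a `features.Mask` can be 2D while the tensor obtained from the equivalent PIL image carries an extra leading singleton dimension, so the unconditional shape check in `super()._compare_attributes` would fail before any values are compared. A minimal standalone sketch of that mismatch (the concrete sizes here are made up for illustration):

```python
import torch

# A 2D mask-like tensor vs. the (1, H, W) tensor a PIL image converts to.
mask_like = torch.zeros(7, 33)    # (H, W)
from_pil = torch.zeros(1, 7, 33)  # (1, H, W) after the PIL -> tensor conversion

# torch.broadcast_tensors aligns the shapes, so the strict shape check downstream passes ...
a, b = torch.broadcast_tensors(mask_like, from_pil)
assert a.shape == b.shape == (1, 7, 33)

# ... and it raises RuntimeError for genuinely incompatible shapes, which the diff above
# converts into an ErrorMeta wrapping an AssertionError.
try:
    torch.broadcast_tensors(torch.zeros(7, 33), torch.zeros(9, 31))
except RuntimeError:
    pass
```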
@@ -165,12 +179,12 @@ def load(self, device="cpu"):
 DEFAULT_SQUARE_IMAGE_SIZE = 15
 DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
 DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
-DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, None)
+DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")


 def _parse_image_size(size, *, name="size"):
-    if size is None:
-        return tuple(torch.randint(16, 33, (2,)).tolist())
+    if size == "random":
+        return tuple(torch.randint(15, 33, (2,)).tolist())
     elif isinstance(size, int) and size > 0:
         return (size, size)
     elif (
@@ -181,8 +195,8 @@ def _parse_image_size(size, *, name="size"):
         return tuple(size)
     else:
         raise pytest.UsageError(
-            f"'{name}' can either be `None`, a positive integer, or a sequence of two positive integers,"
-            f"but got {size} instead"
+            f"'{name}' can either be `'random'`, a positive integer, or a sequence of two positive integers,"
+            f" but got {size} instead."
         )


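The string sentinel `"random"` that replaces `None` throughout these loaders resolves to two independently drawn side lengths. A hedged, standalone illustration of that branch (it mirrors the `torch.randint` call above rather than importing the test module):

```python
import torch

# Each side is drawn from [15, 33), i.e. at least DEFAULT_SQUARE_IMAGE_SIZE (15) and at most 32.
height, width = torch.randint(15, 33, (2,)).tolist()
assert 15 <= height <= 32 and 15 <= width <= 32
print((height, width))  # e.g. (19, 27)
```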
@@ -228,7 +242,7 @@ def __post_init__(self):


 def make_image_loader(
-    size=None,
+    size="random",
     *,
     color_space=features.ColorSpace.RGB,
     extra_dims=(),
@@ -298,7 +312,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     ).reshape(low.shape)


-def make_bounding_box_loader(*, extra_dims=(), format, image_size=None, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
     if isinstance(format, str):
         format = features.BoundingBoxFormat[format]
     if format not in {
@@ -355,7 +369,7 @@ def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(features.BoundingBoxFormat),
-    image_size=None,
+    image_size="random",
     dtypes=(torch.float32, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
@@ -440,10 +454,10 @@ class MaskLoader(TensorLoader):
     pass


-def make_detection_mask_loader(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
+def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
     size = _parse_image_size(size)
-    num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
+    num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects

     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=2, dtype=dtype, device=device)
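As the comment in `make_detection_mask_loader` states, detection masks have shape `(*, N, H, W)` with one binary plane per object. A hedged sketch of the kind of tensor the loader ends up producing (the concrete sizes are illustrative, not taken from the diff):

```python
import torch

num_objects, height, width = 3, 15, 15  # illustrative; "random" would draw num_objects from [1, 11)

# One binary (0/1) H x W plane per object, matching the low=0, high=2 call above
# (high is exclusive for integer dtypes in torch.testing.make_tensor).
detection_mask = torch.testing.make_tensor(
    (num_objects, height, width), low=0, high=2, dtype=torch.uint8, device="cpu"
)
assert detection_mask.shape == (3, 15, 15)
assert int(detection_mask.max()) <= 1
```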
@@ -457,7 +471,7 @@ def fn(shape, dtype, device):

 def make_detection_mask_loaders(
     sizes=DEFAULT_IMAGE_SIZES,
-    num_objects=(1, 0, None),
+    num_objects=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):
@@ -468,10 +482,10 @@ def make_detection_mask_loaders(
 make_detection_masks = from_loaders(make_detection_mask_loaders)


-def make_segmentation_mask_loader(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
+def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
     size = _parse_image_size(size)
-    num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
+    num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories

     def fn(shape, dtype, device):
         data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=dtype, device=device)
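Segmentation masks, in contrast, are `(*, H, W)` with the category id encoded in the pixel values. A hedged companion sketch (again with made-up sizes):

```python
import torch

num_categories, height, width = 10, 15, 15  # illustrative; "random" would draw num_categories from [1, 11)

# A single H x W map whose values are category ids in [0, num_categories), matching the call above.
segmentation_mask = torch.testing.make_tensor(
    (height, width), low=0, high=num_categories, dtype=torch.uint8, device="cpu"
)
assert segmentation_mask.shape == (15, 15)
assert int(segmentation_mask.max()) < num_categories
```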
@@ -486,7 +500,7 @@ def fn(shape, dtype, device):
 def make_segmentation_mask_loaders(
     *,
     sizes=DEFAULT_IMAGE_SIZES,
-    num_categories=(1, 2, None),
+    num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):
@@ -500,8 +514,8 @@ def make_segmentation_mask_loaders(
 def make_mask_loaders(
     *,
     sizes=DEFAULT_IMAGE_SIZES,
-    num_objects=(1, 0, None),
-    num_categories=(1, 2, None),
+    num_objects=(1, 0, "random"),
+    num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
 ):