@@ -184,13 +184,18 @@ def load(self, device="cpu"):
184184 return args , kwargs
185185
186186
187- DEFAULT_SQUARE_IMAGE_SIZE = 15
188- DEFAULT_LANDSCAPE_IMAGE_SIZE = (7 , 33 )
189- DEFAULT_PORTRAIT_IMAGE_SIZE = (31 , 9 )
190- DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE , DEFAULT_PORTRAIT_IMAGE_SIZE , DEFAULT_SQUARE_IMAGE_SIZE , "random" )
187+ DEFAULT_SQUARE_SPATIAL_SIZE = 15
188+ DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7 , 33 )
189+ DEFAULT_PORTRAIT_SPATIAL_SIZE = (31 , 9 )
190+ DEFAULT_SPATIAL_SIZES = (
191+ DEFAULT_LANDSCAPE_SPATIAL_SIZE ,
192+ DEFAULT_PORTRAIT_SPATIAL_SIZE ,
193+ DEFAULT_SQUARE_SPATIAL_SIZE ,
194+ "random" ,
195+ )
191196
192197
193- def _parse_image_size (size , * , name = "size" ):
198+ def _parse_spatial_size (size , * , name = "size" ):
194199 if size == "random" :
195200 return tuple (torch .randint (15 , 33 , (2 ,)).tolist ())
196201 elif isinstance (size , int ) and size > 0 :
@@ -246,11 +251,11 @@ def load(self, device):
246251@dataclasses .dataclass
247252class ImageLoader (TensorLoader ):
248253 color_space : features .ColorSpace
249- image_size : Tuple [int , int ] = dataclasses .field (init = False )
254+ spatial_size : Tuple [int , int ] = dataclasses .field (init = False )
250255 num_channels : int = dataclasses .field (init = False )
251256
252257 def __post_init__ (self ):
253- self .image_size = self .shape [- 2 :]
258+ self .spatial_size = self .shape [- 2 :]
254259 self .num_channels = self .shape [- 3 ]
255260
256261
@@ -277,7 +282,7 @@ def make_image_loader(
277282 dtype = torch .float32 ,
278283 constant_alpha = True ,
279284):
280- size = _parse_image_size (size )
285+ size = _parse_spatial_size (size )
281286 num_channels = get_num_channels (color_space )
282287
283288 def fn (shape , dtype , device ):
@@ -295,7 +300,7 @@ def fn(shape, dtype, device):
295300
296301def make_image_loaders (
297302 * ,
298- sizes = DEFAULT_IMAGE_SIZES ,
303+ sizes = DEFAULT_SPATIAL_SIZES ,
299304 color_spaces = (
300305 features .ColorSpace .GRAY ,
301306 features .ColorSpace .GRAY_ALPHA ,
@@ -316,7 +321,7 @@ def make_image_loaders(
316321@dataclasses .dataclass
317322class BoundingBoxLoader (TensorLoader ):
318323 format : features .BoundingBoxFormat
319- image_size : Tuple [int , int ]
324+ spatial_size : Tuple [int , int ]
320325
321326
322327def randint_with_tensor_bounds (arg1 , arg2 = None , ** kwargs ):
@@ -331,7 +336,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
331336 ).reshape (low .shape )
332337
333338
334- def make_bounding_box_loader (* , extra_dims = (), format , image_size = "random" , dtype = torch .float32 ):
339+ def make_bounding_box_loader (* , extra_dims = (), format , spatial_size = "random" , dtype = torch .float32 ):
335340 if isinstance (format , str ):
336341 format = features .BoundingBoxFormat [format ]
337342 if format not in {
@@ -341,7 +346,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtyp
341346 }:
342347 raise pytest .UsageError (f"Can't make bounding box in format { format } " )
343348
344- image_size = _parse_image_size ( image_size , name = "image_size " )
349+ spatial_size = _parse_spatial_size ( spatial_size , name = "spatial_size " )
345350
346351 def fn (shape , dtype , device ):
347352 * extra_dims , num_coordinates = shape
@@ -350,10 +355,10 @@ def fn(shape, dtype, device):
350355
351356 if any (dim == 0 for dim in extra_dims ):
352357 return features .BoundingBox (
353- torch .empty (* extra_dims , 4 , dtype = dtype , device = device ), format = format , image_size = image_size
358+ torch .empty (* extra_dims , 4 , dtype = dtype , device = device ), format = format , spatial_size = spatial_size
354359 )
355360
356- height , width = image_size
361+ height , width = spatial_size
357362
358363 if format == features .BoundingBoxFormat .XYXY :
359364 x1 = torch .randint (0 , width // 2 , extra_dims )
@@ -375,10 +380,10 @@ def fn(shape, dtype, device):
375380 parts = (cx , cy , w , h )
376381
377382 return features .BoundingBox (
378- torch .stack (parts , dim = - 1 ).to (dtype = dtype , device = device ), format = format , image_size = image_size
383+ torch .stack (parts , dim = - 1 ).to (dtype = dtype , device = device ), format = format , spatial_size = spatial_size
379384 )
380385
381- return BoundingBoxLoader (fn , shape = (* extra_dims , 4 ), dtype = dtype , format = format , image_size = image_size )
386+ return BoundingBoxLoader (fn , shape = (* extra_dims , 4 ), dtype = dtype , format = format , spatial_size = spatial_size )
382387
383388
384389make_bounding_box = from_loader (make_bounding_box_loader )
@@ -388,11 +393,11 @@ def make_bounding_box_loaders(
388393 * ,
389394 extra_dims = DEFAULT_EXTRA_DIMS ,
390395 formats = tuple (features .BoundingBoxFormat ),
391- image_size = "random" ,
396+ spatial_size = "random" ,
392397 dtypes = (torch .float32 , torch .int64 ),
393398):
394399 for params in combinations_grid (extra_dims = extra_dims , format = formats , dtype = dtypes ):
395- yield make_bounding_box_loader (** params , image_size = image_size )
400+ yield make_bounding_box_loader (** params , spatial_size = spatial_size )
396401
397402
398403make_bounding_boxes = from_loaders (make_bounding_box_loaders )
@@ -475,7 +480,7 @@ class MaskLoader(TensorLoader):
475480
476481def make_detection_mask_loader (size = "random" , * , num_objects = "random" , extra_dims = (), dtype = torch .uint8 ):
477482 # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
478- size = _parse_image_size (size )
483+ size = _parse_spatial_size (size )
479484 num_objects = int (torch .randint (1 , 11 , ())) if num_objects == "random" else num_objects
480485
481486 def fn (shape , dtype , device ):
@@ -489,7 +494,7 @@ def fn(shape, dtype, device):
489494
490495
491496def make_detection_mask_loaders (
492- sizes = DEFAULT_IMAGE_SIZES ,
497+ sizes = DEFAULT_SPATIAL_SIZES ,
493498 num_objects = (1 , 0 , "random" ),
494499 extra_dims = DEFAULT_EXTRA_DIMS ,
495500 dtypes = (torch .uint8 ,),
@@ -503,7 +508,7 @@ def make_detection_mask_loaders(
503508
504509def make_segmentation_mask_loader (size = "random" , * , num_categories = "random" , extra_dims = (), dtype = torch .uint8 ):
505510 # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
506- size = _parse_image_size (size )
511+ size = _parse_spatial_size (size )
507512 num_categories = int (torch .randint (1 , 11 , ())) if num_categories == "random" else num_categories
508513
509514 def fn (shape , dtype , device ):
@@ -518,7 +523,7 @@ def fn(shape, dtype, device):
518523
519524def make_segmentation_mask_loaders (
520525 * ,
521- sizes = DEFAULT_IMAGE_SIZES ,
526+ sizes = DEFAULT_SPATIAL_SIZES ,
522527 num_categories = (1 , 2 , "random" ),
523528 extra_dims = DEFAULT_EXTRA_DIMS ,
524529 dtypes = (torch .uint8 ,),
@@ -532,7 +537,7 @@ def make_segmentation_mask_loaders(
532537
533538def make_mask_loaders (
534539 * ,
535- sizes = DEFAULT_IMAGE_SIZES ,
540+ sizes = DEFAULT_SPATIAL_SIZES ,
536541 num_objects = (1 , 0 , "random" ),
537542 num_categories = (1 , 2 , "random" ),
538543 extra_dims = DEFAULT_EXTRA_DIMS ,
@@ -559,7 +564,7 @@ def make_video_loader(
559564 extra_dims = (),
560565 dtype = torch .uint8 ,
561566):
562- size = _parse_image_size (size )
567+ size = _parse_spatial_size (size )
563568 num_frames = int (torch .randint (1 , 5 , ())) if num_frames == "random" else num_frames
564569
565570 def fn (shape , dtype , device ):
@@ -576,7 +581,7 @@ def fn(shape, dtype, device):
576581
577582def make_video_loaders (
578583 * ,
579- sizes = DEFAULT_IMAGE_SIZES ,
584+ sizes = DEFAULT_SPATIAL_SIZES ,
580585 color_spaces = (
581586 features .ColorSpace .GRAY ,
582587 features .ColorSpace .RGB ,