@@ -406,26 +406,21 @@ def make_bounding_boxes(
406406 canvas_size = DEFAULT_SIZE ,
407407 * ,
408408 format = datapoints .BoundingBoxFormat .XYXY ,
409- batch_dims = (),
410409 dtype = None ,
411410 device = "cpu" ,
412411):
413412 def sample_position (values , max_value ):
414413 # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
415414 # However, if we have batch_dims, we need tensors as limits.
416- return torch .stack ([torch .randint (max_value - v , ()) for v in values .flatten (). tolist ()]). reshape ( values . shape )
415+ return torch .stack ([torch .randint (max_value - v , ()) for v in values .tolist ()])
417416
418417 if isinstance (format , str ):
419418 format = datapoints .BoundingBoxFormat [format ]
420419
421420 dtype = dtype or torch .float32
422421
423- if any (dim == 0 for dim in batch_dims ):
424- return datapoints .BoundingBoxes (
425- torch .empty (* batch_dims , 4 , dtype = dtype , device = device ), format = format , canvas_size = canvas_size
426- )
427-
428- h , w = [torch .randint (1 , c , batch_dims ) for c in canvas_size ]
422+ num_objects = 1
423+ h , w = [torch .randint (1 , c , (num_objects ,)) for c in canvas_size ]
429424 y = sample_position (h , canvas_size [0 ])
430425 x = sample_position (w , canvas_size [1 ])
431426
@@ -448,11 +443,12 @@ def sample_position(values, max_value):
448443 )
449444
450445
451- def make_detection_mask (size = DEFAULT_SIZE , * , num_objects = 5 , batch_dims = (), dtype = None , device = "cpu" ):
446+ def make_detection_mask (size = DEFAULT_SIZE , * , dtype = None , device = "cpu" ):
452447 """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
448+ num_objects = 1
453449 return datapoints .Mask (
454450 torch .testing .make_tensor (
455- (* batch_dims , num_objects , * size ),
451+ (num_objects , * size ),
456452 low = 0 ,
457453 high = 2 ,
458454 dtype = dtype or torch .bool ,
0 commit comments