@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
19351935@pytest .mark .parametrize (
19361936 "labels_getter" , ("default" , "labels" , lambda inputs : inputs ["labels" ], None , lambda inputs : None )
19371937)
1938- def test_sanitize_bounding_boxes (min_size , labels_getter ):
1938+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
1939+ def test_sanitize_bounding_boxes (min_size , labels_getter , sample_type ):
1940+
1941+ if sample_type is tuple and not isinstance (labels_getter , str ):
1942+ # The "lambda inputs: inputs["labels"]" labels_getter used in this test
1943+ # doesn't work if the input is a tuple.
1944+ return
1945+
19391946 H , W = 256 , 128
19401947
19411948 boxes_and_validity = [
@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
19701977 )
19711978
19721979 masks = datapoints .Mask (torch .randint (0 , 2 , size = (boxes .shape [0 ], H , W )))
1973-
1980+ whatever = torch .rand (10 )
1981+ input_img = torch .randint (0 , 256 , size = (1 , 3 , H , W ), dtype = torch .uint8 )
19741982 sample = {
1975- "image" : torch . randint ( 0 , 256 , size = ( 1 , 3 , H , W ), dtype = torch . uint8 ) ,
1983+ "image" : input_img ,
19761984 "labels" : labels ,
19771985 "boxes" : boxes ,
1978- "whatever" : torch . rand ( 10 ) ,
1986+ "whatever" : whatever ,
19791987 "None" : None ,
19801988 "masks" : masks ,
19811989 }
19821990
1991+ if sample_type is tuple :
1992+ img = sample .pop ("image" )
1993+ sample = (img , sample )
1994+
19831995 out = transforms .SanitizeBoundingBoxes (min_size = min_size , labels_getter = labels_getter )(sample )
19841996
1985- assert out ["image" ] is sample ["image" ]
1986- assert out ["whatever" ] is sample ["whatever" ]
1997+ if sample_type is tuple :
1998+ out_image = out [0 ]
1999+ out_labels = out [1 ]["labels" ]
2000+ out_boxes = out [1 ]["boxes" ]
2001+ out_masks = out [1 ]["masks" ]
2002+ out_whatever = out [1 ]["whatever" ]
2003+ else :
2004+ out_image = out ["image" ]
2005+ out_labels = out ["labels" ]
2006+ out_boxes = out ["boxes" ]
2007+ out_masks = out ["masks" ]
2008+ out_whatever = out ["whatever" ]
2009+
2010+ assert out_image is input_img
2011+ assert out_whatever is whatever
19872012
19882013 if labels_getter is None or (callable (labels_getter ) and labels_getter ({"labels" : "blah" }) is None ):
1989- assert out [ "labels" ] is sample [ " labels" ]
2014+ assert out_labels is labels
19902015 else :
1991- assert isinstance (out [ "labels" ] , torch .Tensor )
1992- assert out [ "boxes" ] .shape [0 ] == out [ "labels" ] .shape [0 ] == out [ "masks" ] .shape [0 ]
2016+ assert isinstance (out_labels , torch .Tensor )
2017+ assert out_boxes .shape [0 ] == out_labels .shape [0 ] == out_masks .shape [0 ]
19932018 # This works because we conveniently set labels to arange(num_boxes)
1994- assert out [ "labels" ] .tolist () == valid_indices
2019+ assert out_labels .tolist () == valid_indices
19952020
19962021
19972022@pytest .mark .parametrize ("key" , ("labels" , "LABELS" , "LaBeL" , "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT" ))
1998- def test_sanitize_bounding_boxes_default_heuristic (key ):
2023+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
2024+ def test_sanitize_bounding_boxes_default_heuristic (key , sample_type ):
19992025 labels = torch .arange (10 )
2000- d = {key : labels }
2001- assert transforms .SanitizeBoundingBoxes ._find_labels_default_heuristic (d ) is labels
2026+ sample = {key : labels , "another_key" : "whatever" }
2027+ if sample_type is tuple :
2028+ sample = (None , sample , "whatever_again" )
2029+ assert transforms .SanitizeBoundingBoxes ._find_labels_default_heuristic (sample ) is labels
20022030
20032031 if key .lower () != "labels" :
20042032 # If "labels" is in the dict (case-insensitive),
0 commit comments