@@ -437,7 +437,7 @@ def test__get_params(self, fill, side_range, mocker):
437437 image = mocker .MagicMock (spec = features .Image )
438438 h , w = image .spatial_size = (24 , 32 )
439439
440- params = transform ._get_params (image )
440+ params = transform ._get_params ([ image ] )
441441
442442 assert len (params ["padding" ]) == 4
443443 assert 0 <= params ["padding" ][0 ] <= (side_range [1 ] - 1 ) * w
@@ -462,7 +462,7 @@ def test__transform(self, fill, side_range, mocker):
462462 _ = transform (inpt )
463463 torch .manual_seed (12 )
464464 torch .rand (1 ) # random apply changes random state
465- params = transform ._get_params (inpt )
465+ params = transform ._get_params ([ inpt ] )
466466
467467 fill = transforms .functional ._geometry ._convert_fill_arg (fill )
468468 fn .assert_called_once_with (inpt , ** params , fill = fill )
@@ -623,7 +623,7 @@ def test__get_params(self, degrees, translate, scale, shear, mocker):
623623 h , w = image .spatial_size
624624
625625 transform = transforms .RandomAffine (degrees , translate = translate , scale = scale , shear = shear )
626- params = transform ._get_params (image )
626+ params = transform ._get_params ([ image ] )
627627
628628 if not isinstance (degrees , (list , tuple )):
629629 assert - degrees <= params ["angle" ] <= degrees
@@ -690,7 +690,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
690690 torch .manual_seed (12 )
691691 _ = transform (inpt )
692692 torch .manual_seed (12 )
693- params = transform ._get_params (inpt )
693+ params = transform ._get_params ([ inpt ] )
694694
695695 fill = transforms .functional ._geometry ._convert_fill_arg (fill )
696696 fn .assert_called_once_with (inpt , ** params , interpolation = interpolation , fill = fill , center = center )
@@ -722,7 +722,7 @@ def test__get_params(self, padding, pad_if_needed, size, mocker):
722722 h , w = image .spatial_size
723723
724724 transform = transforms .RandomCrop (size , padding = padding , pad_if_needed = pad_if_needed )
725- params = transform ._get_params (image )
725+ params = transform ._get_params ([ image ] )
726726
727727 if padding is not None :
728728 if isinstance (padding , int ):
@@ -793,7 +793,7 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
793793 torch .manual_seed (12 )
794794 _ = transform (inpt )
795795 torch .manual_seed (12 )
796- params = transform ._get_params (inpt )
796+ params = transform ._get_params ([ inpt ] )
797797 if padding is None and not pad_if_needed :
798798 fn_crop .assert_called_once_with (
799799 inpt , top = params ["top" ], left = params ["left" ], height = output_size [0 ], width = output_size [1 ]
@@ -832,7 +832,7 @@ def test_assertions(self):
832832 @pytest .mark .parametrize ("sigma" , [10.0 , [10.0 , 12.0 ]])
833833 def test__get_params (self , sigma ):
834834 transform = transforms .GaussianBlur (3 , sigma = sigma )
835- params = transform ._get_params (None )
835+ params = transform ._get_params ([] )
836836
837837 if isinstance (sigma , float ):
838838 assert params ["sigma" ][0 ] == params ["sigma" ][1 ] == 10
@@ -867,7 +867,7 @@ def test__transform(self, kernel_size, sigma, mocker):
867867 torch .manual_seed (12 )
868868 _ = transform (inpt )
869869 torch .manual_seed (12 )
870- params = transform ._get_params (inpt )
870+ params = transform ._get_params ([ inpt ] )
871871
872872 fn .assert_called_once_with (inpt , kernel_size , ** params )
873873
@@ -912,7 +912,7 @@ def test__get_params(self, mocker):
912912 image .num_channels = 3
913913 image .spatial_size = (24 , 32 )
914914
915- params = transform ._get_params (image )
915+ params = transform ._get_params ([ image ] )
916916
917917 h , w = image .spatial_size
918918 assert "perspective_coeffs" in params
@@ -935,7 +935,7 @@ def test__transform(self, distortion_scale, mocker):
935935 _ = transform (inpt )
936936 torch .manual_seed (12 )
937937 torch .rand (1 ) # random apply changes random state
938- params = transform ._get_params (inpt )
938+ params = transform ._get_params ([ inpt ] )
939939
940940 fill = transforms .functional ._geometry ._convert_fill_arg (fill )
941941 fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
@@ -973,7 +973,7 @@ def test__get_params(self, mocker):
973973 image .num_channels = 3
974974 image .spatial_size = (24 , 32 )
975975
976- params = transform ._get_params (image )
976+ params = transform ._get_params ([ image ] )
977977
978978 h , w = image .spatial_size
979979 displacement = params ["displacement" ]
@@ -1006,7 +1006,7 @@ def test__transform(self, alpha, sigma, mocker):
10061006 # Let's mock transform._get_params to control the output:
10071007 transform ._get_params = mocker .MagicMock ()
10081008 _ = transform (inpt )
1009- params = transform ._get_params (inpt )
1009+ params = transform ._get_params ([ inpt ] )
10101010 fill = transforms .functional ._geometry ._convert_fill_arg (fill )
10111011 fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
10121012
@@ -1035,7 +1035,7 @@ def test_assertions(self, mocker):
10351035 transform = transforms .RandomErasing (value = [1 , 2 , 3 , 4 ])
10361036
10371037 with pytest .raises (ValueError , match = "If value is a sequence, it should have either a single value" ):
1038- transform ._get_params (image )
1038+ transform ._get_params ([ image ] )
10391039
10401040 @pytest .mark .parametrize ("value" , [5.0 , [1 , 2 , 3 ], "random" ])
10411041 def test__get_params (self , value , mocker ):
@@ -1044,7 +1044,7 @@ def test__get_params(self, value, mocker):
10441044 image .spatial_size = (24 , 32 )
10451045
10461046 transform = transforms .RandomErasing (value = value )
1047- params = transform ._get_params (image )
1047+ params = transform ._get_params ([ image ] )
10481048
10491049 v = params ["v" ]
10501050 h , w = params ["h" ], params ["w" ]
@@ -1197,6 +1197,7 @@ def test_assertions(self, transform_cls):
11971197 [
11981198 [transforms .Pad (2 ), transforms .RandomCrop (28 )],
11991199 [lambda x : 2.0 * x , transforms .Pad (2 ), transforms .RandomCrop (28 )],
1200+ [transforms .Pad (2 ), lambda x : 2.0 * x , transforms .RandomCrop (28 )],
12001201 ],
12011202 )
12021203 def test_ctor (self , transform_cls , trfms ):
@@ -1339,7 +1340,7 @@ def test__get_params(self, mocker):
13391340 n_samples = 5
13401341 for _ in range (n_samples ):
13411342
1342- params = transform ._get_params (sample )
1343+ params = transform ._get_params ([ sample ] )
13431344
13441345 assert "size" in params
13451346 size = params ["size" ]
@@ -1386,7 +1387,7 @@ def test__get_params(self, mocker):
13861387 transform = transforms .RandomShortestSize (min_size = min_size , max_size = max_size )
13871388
13881389 sample = mocker .MagicMock (spec = features .Image , num_channels = 3 , spatial_size = spatial_size )
1389- params = transform ._get_params (sample )
1390+ params = transform ._get_params ([ sample ] )
13901391
13911392 assert "size" in params
13921393 size = params ["size" ]
@@ -1554,13 +1555,13 @@ def test__get_params(self, mocker):
15541555
15551556 transform = transforms .FixedSizeCrop (size = crop_size )
15561557
1557- sample = dict (
1558- image = make_image (size = spatial_size , color_space = features .ColorSpace .RGB ),
1559- bounding_boxes = make_bounding_box (
1558+ flat_inputs = [
1559+ make_image (size = spatial_size , color_space = features .ColorSpace .RGB ),
1560+ make_bounding_box (
15601561 format = features .BoundingBoxFormat .XYXY , spatial_size = spatial_size , extra_dims = batch_shape
15611562 ),
1562- )
1563- params = transform ._get_params (sample )
1563+ ]
1564+ params = transform ._get_params (flat_inputs )
15641565
15651566 assert params ["needs_crop" ]
15661567 assert params ["height" ] <= crop_size [0 ]
@@ -1759,7 +1760,7 @@ def test__get_params(self):
17591760 transform = transforms .RandomResize (min_size = min_size , max_size = max_size )
17601761
17611762 for _ in range (10 ):
1762- params = transform ._get_params (None )
1763+ params = transform ._get_params ([] )
17631764
17641765 assert isinstance (params ["size" ], list ) and len (params ["size" ]) == 1
17651766 size = params ["size" ][0 ]
0 commit comments