File tree Expand file tree Collapse file tree 2 files changed +18
-11
lines changed
torchvision/prototype/transforms Expand file tree Collapse file tree 2 files changed +18
-11
lines changed Original file line number Diff line number Diff line change @@ -1263,18 +1263,24 @@ def test__get_params(self, mocker):
12631263 scale_range = (0.5 , 1.5 )
12641264
12651265 transform = transforms .ScaleJitter (target_size = target_size , scale_range = scale_range )
1266-
12671266 sample = mocker .MagicMock (spec = features .Image , num_channels = 3 , image_size = image_size )
1268- params = transform ._get_params (sample )
12691267
1270- assert "size" in params
1271- size = params [ "size" ]
1268+ n_samples = 5
1269+ for _ in range ( n_samples ):
12721270
1273- assert isinstance (size , tuple ) and len (size ) == 2
1274- height , width = size
1271+ params = transform ._get_params (sample )
1272+
1273+ assert "size" in params
1274+ size = params ["size" ]
1275+
1276+ assert isinstance (size , tuple ) and len (size ) == 2
1277+ height , width = size
1278+
1279+ r_min = min (target_size [1 ] / image_size [0 ], target_size [0 ] / image_size [1 ]) * scale_range [0 ]
1280+ r_max = min (target_size [1 ] / image_size [0 ], target_size [0 ] / image_size [1 ]) * scale_range [1 ]
12751281
1276- assert int (target_size [0 ] * scale_range [ 0 ] ) <= height <= int (target_size [0 ] * scale_range [ 1 ] )
1277- assert int (target_size [1 ] * scale_range [ 0 ] ) <= width <= int (target_size [1 ] * scale_range [ 1 ] )
1282+ assert int (image_size [0 ] * r_min ) <= height <= int (image_size [0 ] * r_max )
1283+ assert int (image_size [1 ] * r_min ) <= width <= int (image_size [1 ] * r_max )
12781284
12791285 def test__transform (self , mocker ):
12801286 interpolation_sentinel = mocker .MagicMock ()
Original file line number Diff line number Diff line change @@ -727,9 +727,10 @@ def __init__(
727727 def _get_params (self , sample : Any ) -> Dict [str , Any ]:
728728 _ , orig_height , orig_width = query_chw (sample )
729729
730- r = self .scale_range [0 ] + torch .rand (1 ) * (self .scale_range [1 ] - self .scale_range [0 ])
731- new_width = int (self .target_size [1 ] * r )
732- new_height = int (self .target_size [0 ] * r )
730+ scale = self .scale_range [0 ] + torch .rand (1 ) * (self .scale_range [1 ] - self .scale_range [0 ])
731+ r = min (self .target_size [1 ] / orig_height , self .target_size [0 ] / orig_width ) * scale
732+ new_width = int (orig_width * r )
733+ new_height = int (orig_height * r )
733734
734735 return dict (size = (new_height , new_width ))
735736
You can’t perform that action at this time.
0 commit comments