77from torchvision .prototype .transforms import Transform , InterpolationMode , AutoAugmentPolicy , functional as F
88from torchvision .prototype .utils ._internal import apply_recursively
99
10- from ._utils import query_image
10+ from ._utils import query_image , get_image_dimensions
1111
1212K = TypeVar ("K" )
1313V = TypeVar ("V" )
@@ -47,7 +47,7 @@ def dispatch(
4747 return input
4848
4949 image = query_image (sample )
50- num_channels = F . get_image_num_channels (image )
50+ num_channels , * _ = get_image_dimensions (image )
5151
5252 fill = self .fill
5353 if isinstance (fill , (int , float )):
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
160160 _AUGMENTATION_SPACE = {
161161 "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
162162 "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
163- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
164- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
163+ "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
164+ "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
165165 "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 30.0 , num_bins ), True ),
166166 "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
167167 "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
278278 sample = inputs if len (inputs ) > 1 else inputs [0 ]
279279
280280 image = query_image (sample )
281- image_size = F . get_image_size (image )
281+ _ , height , width = get_image_dimensions (image )
282282
283283 policy = self ._policies [int (torch .randint (len (self ._policies ), ()))]
284284
@@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any:
288288
289289 magnitudes_fn , signed = self ._AUGMENTATION_SPACE [transform_id ]
290290
291- magnitudes = magnitudes_fn (10 , image_size )
291+ magnitudes = magnitudes_fn (10 , ( height , width ) )
292292 if magnitudes is not None :
293293 magnitude = float (magnitudes [magnitude_idx ])
294294 if signed and torch .rand (()) <= 0.5 :
@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
306306 "Identity" : (lambda num_bins , image_size : None , False ),
307307 "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
308308 "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
309- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
310- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
309+ "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
310+ "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
311311 "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 30.0 , num_bins ), True ),
312312 "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
313313 "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
334334 sample = inputs if len (inputs ) > 1 else inputs [0 ]
335335
336336 image = query_image (sample )
337- image_size = F . get_image_size (image )
337+ _ , height , width = get_image_dimensions (image )
338338
339339 for _ in range (self .num_ops ):
340340 transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
341341
342- magnitudes = magnitudes_fn (self .num_magnitude_bins , image_size )
342+ magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
343343 if magnitudes is not None :
344344 magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
345345 if signed and torch .rand (()) <= 0.5 :
@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
383383 sample = inputs if len (inputs ) > 1 else inputs [0 ]
384384
385385 image = query_image (sample )
386- image_size = F . get_image_size (image )
386+ _ , height , width = get_image_dimensions (image )
387387
388388 transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
389389
390- magnitudes = magnitudes_fn (self .num_magnitude_bins , image_size )
390+ magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
391391 if magnitudes is not None :
392392 magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
393393 if signed and torch .rand (()) <= 0.5 :
0 commit comments