@@ -199,25 +199,31 @@ def _apply_image_transform(
199199
200200class AutoAugment (_AutoAugmentBase ):
201201 _AUGMENTATION_SPACE = {
202- "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
203- "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
204- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
205- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
206- "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 30.0 , num_bins ), True ),
207- "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
208- "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
209- "Contrast" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
210- "Sharpness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
202+ "ShearX" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
203+ "ShearY" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
204+ "TranslateX" : (
205+ lambda num_bins , height , width : torch .linspace (0.0 , 150.0 / 331.0 * width , num_bins ),
206+ True ,
207+ ),
208+ "TranslateY" : (
209+ lambda num_bins , height , width : torch .linspace (0.0 , 150.0 / 331.0 * height , num_bins ),
210+ True ,
211+ ),
212+ "Rotate" : (lambda num_bins , height , width : torch .linspace (0.0 , 30.0 , num_bins ), True ),
213+ "Brightness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
214+ "Color" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
215+ "Contrast" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
216+ "Sharpness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
211217 "Posterize" : (
212- lambda num_bins , image_size : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
218+ lambda num_bins , height , width : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
213219 .round ()
214220 .int (),
215221 False ,
216222 ),
217- "Solarize" : (lambda num_bins , image_size : torch .linspace (255.0 , 0.0 , num_bins ), False ),
218- "AutoContrast" : (lambda num_bins , image_size : None , False ),
219- "Equalize" : (lambda num_bins , image_size : None , False ),
220- "Invert" : (lambda num_bins , image_size : None , False ),
223+ "Solarize" : (lambda num_bins , height , width : torch .linspace (255.0 , 0.0 , num_bins ), False ),
224+ "AutoContrast" : (lambda num_bins , height , width : None , False ),
225+ "Equalize" : (lambda num_bins , height , width : None , False ),
226+ "Invert" : (lambda num_bins , height , width : None , False ),
221227 }
222228
223229 def __init__ (
@@ -335,7 +341,7 @@ def forward(self, *inputs: Any) -> Any:
335341
336342 magnitudes_fn , signed = self ._AUGMENTATION_SPACE [transform_id ]
337343
338- magnitudes = magnitudes_fn (10 , ( height , width ) )
344+ magnitudes = magnitudes_fn (10 , height , width )
339345 if magnitudes is not None :
340346 magnitude = float (magnitudes [magnitude_idx ])
341347 if signed and torch .rand (()) <= 0.5 :
@@ -352,25 +358,31 @@ def forward(self, *inputs: Any) -> Any:
352358
353359class RandAugment (_AutoAugmentBase ):
354360 _AUGMENTATION_SPACE = {
355- "Identity" : (lambda num_bins , image_size : None , False ),
356- "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
357- "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
358- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
359- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
360- "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 30.0 , num_bins ), True ),
361- "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
362- "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
363- "Contrast" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
364- "Sharpness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
361+ "Identity" : (lambda num_bins , height , width : None , False ),
362+ "ShearX" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
363+ "ShearY" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
364+ "TranslateX" : (
365+ lambda num_bins , height , width : torch .linspace (0.0 , 150.0 / 331.0 * width , num_bins ),
366+ True ,
367+ ),
368+ "TranslateY" : (
369+ lambda num_bins , height , width : torch .linspace (0.0 , 150.0 / 331.0 * height , num_bins ),
370+ True ,
371+ ),
372+ "Rotate" : (lambda num_bins , height , width : torch .linspace (0.0 , 30.0 , num_bins ), True ),
373+ "Brightness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
374+ "Color" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
375+ "Contrast" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
376+ "Sharpness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
365377 "Posterize" : (
366- lambda num_bins , image_size : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
378+ lambda num_bins , height , width : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
367379 .round ()
368380 .int (),
369381 False ,
370382 ),
371- "Solarize" : (lambda num_bins , image_size : torch .linspace (255.0 , 0.0 , num_bins ), False ),
372- "AutoContrast" : (lambda num_bins , image_size : None , False ),
373- "Equalize" : (lambda num_bins , image_size : None , False ),
383+ "Solarize" : (lambda num_bins , height , width : torch .linspace (255.0 , 0.0 , num_bins ), False ),
384+ "AutoContrast" : (lambda num_bins , height , width : None , False ),
385+ "Equalize" : (lambda num_bins , height , width : None , False ),
374386 }
375387
376388 def __init__ (
@@ -397,7 +409,7 @@ def forward(self, *inputs: Any) -> Any:
397409 for _ in range (self .num_ops ):
398410 transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
399411
400- magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
412+ magnitudes = magnitudes_fn (self .num_magnitude_bins , height , width )
401413 if magnitudes is not None :
402414 magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
403415 if signed and torch .rand (()) <= 0.5 :
@@ -414,25 +426,25 @@ def forward(self, *inputs: Any) -> Any:
414426
415427class TrivialAugmentWide (_AutoAugmentBase ):
416428 _AUGMENTATION_SPACE = {
417- "Identity" : (lambda num_bins , image_size : None , False ),
418- "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
419- "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
420- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , 32.0 , num_bins ), True ),
421- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , 32.0 , num_bins ), True ),
422- "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 135.0 , num_bins ), True ),
423- "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
424- "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
425- "Contrast" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
426- "Sharpness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.99 , num_bins ), True ),
429+ "Identity" : (lambda num_bins , height , width : None , False ),
430+ "ShearX" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
431+ "ShearY" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
432+ "TranslateX" : (lambda num_bins , height , width : torch .linspace (0.0 , 32.0 , num_bins ), True ),
433+ "TranslateY" : (lambda num_bins , height , width : torch .linspace (0.0 , 32.0 , num_bins ), True ),
434+ "Rotate" : (lambda num_bins , height , width : torch .linspace (0.0 , 135.0 , num_bins ), True ),
435+ "Brightness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
436+ "Color" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
437+ "Contrast" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
438+ "Sharpness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.99 , num_bins ), True ),
427439 "Posterize" : (
428- lambda num_bins , image_size : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 6 )))
440+ lambda num_bins , height , width : cast (torch .Tensor , 8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 6 )))
429441 .round ()
430442 .int (),
431443 False ,
432444 ),
433- "Solarize" : (lambda num_bins , image_size : torch .linspace (255.0 , 0.0 , num_bins ), False ),
434- "AutoContrast" : (lambda num_bins , image_size : None , False ),
435- "Equalize" : (lambda num_bins , image_size : None , False ),
445+ "Solarize" : (lambda num_bins , height , width : torch .linspace (255.0 , 0.0 , num_bins ), False ),
446+ "AutoContrast" : (lambda num_bins , height , width : None , False ),
447+ "Equalize" : (lambda num_bins , height , width : None , False ),
436448 }
437449
438450 def __init__ (
@@ -454,7 +466,7 @@ def forward(self, *inputs: Any) -> Any:
454466
455467 transform_id , (magnitudes_fn , signed ) = self ._get_random_item (self ._AUGMENTATION_SPACE )
456468
457- magnitudes = magnitudes_fn (self .num_magnitude_bins , ( height , width ) )
469+ magnitudes = magnitudes_fn (self .num_magnitude_bins , height , width )
458470 if magnitudes is not None :
459471 magnitude = float (magnitudes [int (torch .randint (self .num_magnitude_bins , ()))])
460472 if signed and torch .rand (()) <= 0.5 :
@@ -468,27 +480,27 @@ def forward(self, *inputs: Any) -> Any:
468480
469481class AugMix (_AutoAugmentBase ):
470482 _PARTIAL_AUGMENTATION_SPACE = {
471- "ShearX" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
472- "ShearY" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.3 , num_bins ), True ),
473- "TranslateX" : (lambda num_bins , image_size : torch .linspace (0.0 , image_size [ 1 ] / 3.0 , num_bins ), True ),
474- "TranslateY" : (lambda num_bins , image_size : torch .linspace (0.0 , image_size [ 0 ] / 3.0 , num_bins ), True ),
475- "Rotate" : (lambda num_bins , image_size : torch .linspace (0.0 , 30.0 , num_bins ), True ),
483+ "ShearX" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
484+ "ShearY" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.3 , num_bins ), True ),
485+ "TranslateX" : (lambda num_bins , height , width : torch .linspace (0.0 , width / 3.0 , num_bins ), True ),
486+ "TranslateY" : (lambda num_bins , height , width : torch .linspace (0.0 , height / 3.0 , num_bins ), True ),
487+ "Rotate" : (lambda num_bins , height , width : torch .linspace (0.0 , 30.0 , num_bins ), True ),
476488 "Posterize" : (
477- lambda num_bins , image_size : cast (torch .Tensor , 4 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
489+ lambda num_bins , height , width : cast (torch .Tensor , 4 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )))
478490 .round ()
479491 .int (),
480492 False ,
481493 ),
482- "Solarize" : (lambda num_bins , image_size : torch .linspace (255.0 , 0.0 , num_bins ), False ),
483- "AutoContrast" : (lambda num_bins , image_size : None , False ),
484- "Equalize" : (lambda num_bins , image_size : None , False ),
494+ "Solarize" : (lambda num_bins , height , width : torch .linspace (255.0 , 0.0 , num_bins ), False ),
495+ "AutoContrast" : (lambda num_bins , height , width : None , False ),
496+ "Equalize" : (lambda num_bins , height , width : None , False ),
485497 }
486- _AUGMENTATION_SPACE : Dict [str , Tuple [Callable [[int , Tuple [ int , int ] ], Optional [torch .Tensor ]], bool ]] = {
498+ _AUGMENTATION_SPACE : Dict [str , Tuple [Callable [[int , int , int ], Optional [torch .Tensor ]], bool ]] = {
487499 ** _PARTIAL_AUGMENTATION_SPACE ,
488- "Brightness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
489- "Color" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
490- "Contrast" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
491- "Sharpness" : (lambda num_bins , image_size : torch .linspace (0.0 , 0.9 , num_bins ), True ),
500+ "Brightness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
501+ "Color" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
502+ "Contrast" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
503+ "Sharpness" : (lambda num_bins , height , width : torch .linspace (0.0 , 0.9 , num_bins ), True ),
492504 }
493505
494506 def __init__ (
@@ -550,7 +562,7 @@ def forward(self, *inputs: Any) -> Any:
550562 for _ in range (depth ):
551563 transform_id , (magnitudes_fn , signed ) = self ._get_random_item (augmentation_space )
552564
553- magnitudes = magnitudes_fn (self ._PARAMETER_MAX , ( height , width ) )
565+ magnitudes = magnitudes_fn (self ._PARAMETER_MAX , height , width )
554566 if magnitudes is not None :
555567 magnitude = float (magnitudes [int (torch .randint (self .severity , ()))])
556568 if signed and torch .rand (()) <= 0.5 :
0 commit comments