1212from torchvision .transforms .transforms import _setup_size , _interpolation_modes_from_int
1313from typing_extensions import Literal
1414
15+ from ._transform import _RandomApplyTransform
1516from ._utils import query_image , get_image_dimensions , has_any , is_simple_tensor
1617
1718
18- class RandomHorizontalFlip (Transform ):
19- def __init__ (self , p : float = 0.5 ) -> None :
20- super ().__init__ ()
21- self .p = p
22-
23- def forward (self , * inputs : Any ) -> Any :
24- sample = inputs if len (inputs ) > 1 else inputs [0 ]
25- if torch .rand (1 ) >= self .p :
26- return sample
27-
28- return super ().forward (sample )
29-
19+ class RandomHorizontalFlip (_RandomApplyTransform ):
3020 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
3121 if isinstance (input , features .Image ):
3222 output = F .horizontal_flip_image_tensor (input )
@@ -45,18 +35,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
4535 return input
4636
4737
48- class RandomVerticalFlip (Transform ):
49- def __init__ (self , p : float = 0.5 ) -> None :
50- super ().__init__ ()
51- self .p = p
52-
53- def forward (self , * inputs : Any ) -> Any :
54- sample = inputs if len (inputs ) > 1 else inputs [0 ]
55- if torch .rand (1 ) > self .p :
56- return sample
57-
58- return super ().forward (sample )
59-
38+ class RandomVerticalFlip (_RandomApplyTransform ):
6039 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
6140 if isinstance (input , features .Image ):
6241 output = F .vertical_flip_image_tensor (input )
@@ -371,11 +350,11 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
371350 return input
372351
373352
374- class RandomZoomOut (Transform ):
353+ class RandomZoomOut (_RandomApplyTransform ):
375354 def __init__ (
376355 self , fill : Union [float , Sequence [float ]] = 0.0 , side_range : Tuple [float , float ] = (1.0 , 4.0 ), p : float = 0.5
377356 ) -> None :
378- super ().__init__ ()
357+ super ().__init__ (p = p )
379358
380359 if fill is None :
381360 fill = 0.0
@@ -385,8 +364,6 @@ def __init__(
385364 if side_range [0 ] < 1.0 or side_range [0 ] > side_range [1 ]:
386365 raise ValueError (f"Invalid canvas side range provided { side_range } ." )
387366
388- self .p = p
389-
390367 def _get_params (self , sample : Any ) -> Dict [str , Any ]:
391368 image = query_image (sample )
392369 orig_c , orig_h , orig_w = get_image_dimensions (image )
@@ -411,10 +388,3 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
411388 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
412389 transform = Pad (** params , padding_mode = "constant" )
413390 return transform (input )
414-
415- def forward (self , * inputs : Any ) -> Any :
416- sample = inputs if len (inputs ) > 1 else inputs [0 ]
417- if torch .rand (1 ) >= self .p :
418- return sample
419-
420- return super ().forward (sample )
0 commit comments