@@ -35,7 +35,8 @@ def __init__(
3535 antialias : Optional [bool ] = None ,
3636 ) -> None :
3737 super ().__init__ ()
38- self .size = [size ] if isinstance (size , int ) else list (size )
38+
39+ self .size = _setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." )
3940 self .interpolation = interpolation
4041 self .max_size = max_size
4142 self .antialias = antialias
@@ -80,7 +81,6 @@ def __init__(
8081 if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
8182 warnings .warn ("Scale and ratio should be of kind (min, max)" )
8283
83- self .size = size
8484 self .scale = scale
8585 self .ratio = ratio
8686 self .interpolation = interpolation
@@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
225225 raise TypeError ("Got inappropriate fill arg" )
226226
227227
228+ def _check_padding_arg (padding : Union [int , Sequence [int ]]) -> None :
229+ if not isinstance (padding , (numbers .Number , tuple , list )):
230+ raise TypeError ("Got inappropriate padding arg" )
231+
232+ if isinstance (padding , (tuple , list )) and len (padding ) not in [1 , 2 , 4 ]:
233+ raise ValueError (f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple" )
234+
235+
236+ def _check_padding_mode_arg (padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ]) -> None :
237+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
238+ raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric" )
239+
240+
228241class Pad (Transform ):
229242 def __init__ (
230243 self ,
@@ -233,18 +246,10 @@ def __init__(
233246 padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
234247 ) -> None :
235248 super ().__init__ ()
236- if not isinstance (padding , (numbers .Number , tuple , list )):
237- raise TypeError ("Got inappropriate padding arg" )
238-
239- if isinstance (padding , (tuple , list )) and len (padding ) not in [1 , 2 , 4 ]:
240- raise ValueError (
241- f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple"
242- )
243249
250+ _check_padding_arg (padding )
244251 _check_fill_arg (fill )
245-
246- if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
247- raise ValueError ("Padding mode should be either constant, edge, reflect or symmetric" )
252+ _check_padding_mode_arg (padding_mode )
248253
249254 self .padding = padding
250255 self .fill = fill
@@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
416421 fill = self .fill ,
417422 center = self .center ,
418423 )
424+
425+
426+ class RandomCrop (Transform ):
427+ def __init__ (
428+ self ,
429+ size : Union [int , Sequence [int ]],
430+ padding : Optional [Union [int , Sequence [int ]]] = None ,
431+ pad_if_needed : bool = False ,
432+ fill : Union [int , float , Sequence [int ], Sequence [float ]] = 0 ,
433+ padding_mode : Literal ["constant" , "edge" , "reflect" , "symmetric" ] = "constant" ,
434+ ) -> None :
435+ super ().__init__ ()
436+
437+ self .size = _setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." )
438+
439+ if padding is not None :
440+ _check_padding_arg (padding )
441+
442+ if (padding is not None ) or pad_if_needed :
443+ _check_padding_mode_arg (padding_mode )
444+ _check_fill_arg (fill )
445+
446+ self .padding = padding
447+ self .pad_if_needed = pad_if_needed
448+ self .fill = fill
449+ self .padding_mode = padding_mode
450+
451+ def _get_params (self , sample : Any ) -> Dict [str , Any ]:
452+ image = query_image (sample )
453+ _ , height , width = get_image_dimensions (image )
454+ output_height , output_width = self .size
455+
456+ if height + 1 < output_height or width + 1 < output_width :
457+ raise ValueError (
458+ f"Required crop size { (output_height , output_width )} is larger then input image size { (height , width )} "
459+ )
460+
461+ if width == output_width and height == output_height :
462+ return dict (top = 0 , left = 0 , height = height , width = width )
463+
464+ top = torch .randint (0 , height - output_height + 1 , size = (1 ,)).item ()
465+ left = torch .randint (0 , width - output_width + 1 , size = (1 ,)).item ()
466+ return dict (top = top , left = left , height = output_height , width = output_width )
467+
468+ def _forward (self , flat_inputs : List [Any ]) -> List [Any ]:
469+ if self .padding is not None :
470+ flat_inputs = [F .pad (flat_input , self .padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
471+
472+ image = query_image (flat_inputs )
473+ _ , height , width = get_image_dimensions (image )
474+
475+ # pad the width if needed
476+ if self .pad_if_needed and width < self .size [1 ]:
477+ padding = [self .size [1 ] - width , 0 ]
478+ flat_inputs = [F .pad (flat_input , padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
479+ # pad the height if needed
480+ if self .pad_if_needed and height < self .size [0 ]:
481+ padding = [0 , self .size [0 ] - height ]
482+ flat_inputs = [F .pad (flat_input , padding , self .fill , self .padding_mode ) for flat_input in flat_inputs ]
483+
484+ params = self ._get_params (flat_inputs )
485+
486+ return [F .crop (flat_input , ** params ) for flat_input in flat_inputs ]
487+
488+ def forward (self , * inputs : Any ) -> Any :
489+ from torch .utils ._pytree import tree_flatten , tree_unflatten
490+
491+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
492+
493+ flat_inputs , spec = tree_flatten (sample )
494+ out_flat_inputs = self ._forward (flat_inputs )
495+ return tree_unflatten (out_flat_inputs , spec )
0 commit comments