55from collections .abc import Sequence
66from typing import Tuple , List , Optional
77
8- import numpy as np
98import torch
109from PIL import Image
1110from torch import Tensor
@@ -721,9 +720,9 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
721720 raise ValueError ("Please provide only two dimensions (h, w) for size." )
722721 self .size = size
723722
724- if not isinstance (scale , ( tuple , list ) ):
723+ if not isinstance (scale , Sequence ):
725724 raise TypeError ("Scale should be a sequence" )
726- if not isinstance (ratio , ( tuple , list ) ):
725+ if not isinstance (ratio , Sequence ):
727726 raise TypeError ("Ratio should be a sequence" )
728727 if (scale [0 ] > scale [1 ]) or (ratio [0 ] > ratio [1 ]):
729728 warnings .warn ("Scale and ratio should be of kind (min, max)" )
@@ -734,14 +733,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat
734733
735734 @staticmethod
736735 def get_params (
737- img : Tensor , scale : Tuple [float , float ], ratio : Tuple [ float , float ]
736+ img : Tensor , scale : List [float ], ratio : List [ float ]
738737 ) -> Tuple [int , int , int , int ]:
739738 """Get parameters for ``crop`` for a random sized crop.
740739
741740 Args:
742741 img (PIL Image or Tensor): Input image.
743- scale (tuple ): range of scale of the origin size cropped
744- ratio (tuple ): range of aspect ratio of the origin aspect ratio cropped
742+ scale (list ): range of scale of the origin size cropped
743+ ratio (list ): range of aspect ratio of the origin aspect ratio cropped
745744
746745 Returns:
747746 tuple: params (i, j, h, w) to be passed to ``crop`` for a random
@@ -751,7 +750,7 @@ def get_params(
751750 area = height * width
752751
753752 for _ in range (10 ):
754- target_area = area * torch .empty (1 ).uniform_ (* scale ).item ()
753+ target_area = area * torch .empty (1 ).uniform_ (scale [ 0 ], scale [ 1 ] ).item ()
755754 log_ratio = torch .log (torch .tensor (ratio ))
756755 aspect_ratio = torch .exp (
757756 torch .empty (1 ).uniform_ (log_ratio [0 ], log_ratio [1 ])
@@ -1173,8 +1172,10 @@ def __repr__(self):
11731172 return format_string
11741173
11751174
1176- class RandomAffine (object ):
1177- """Random affine transformation of the image keeping center invariant
1175+ class RandomAffine (torch .nn .Module ):
1176+ """Random affine transformation of the image keeping center invariant.
1177+ The image can be a PIL Image or a Tensor, in which case it is expected
1178+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
11781179
11791180 Args:
11801181 degrees (sequence or float or int): Range of degrees to select from.
@@ -1188,41 +1189,51 @@ class RandomAffine(object):
11881189 randomly sampled from the range a <= scale <= b. Will keep original scale by default.
11891190 shear (sequence or float or int, optional): Range of degrees to select from.
11901191 If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
1191- will be apllied . Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
1192+ will be applied . Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
11921193 range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
11931194 a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
1194- Will not apply shear by default
1195- resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
1196- An optional resampling filter. See `filters`_ for more information.
1197- If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
1198- fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
1199- outside the transform in the output image.(Pillow>=5.0.0)
1195+ Will not apply shear by default.
1196+ resample (int, optional): An optional resampling filter. See `filters`_ for more information.
1197+ If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
1198+ If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
1199+ fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
1200+ outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
1201+ input. Fill value for the area outside the transform in the output image is always 0.
12001202
12011203 .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
12021204
12031205 """
12041206
1205- def __init__ (self , degrees , translate = None , scale = None , shear = None , resample = False , fillcolor = 0 ):
1207+ def __init__ (self , degrees , translate = None , scale = None , shear = None , resample = 0 , fillcolor = 0 ):
1208+ super ().__init__ ()
12061209 if isinstance (degrees , numbers .Number ):
12071210 if degrees < 0 :
12081211 raise ValueError ("If degrees is a single number, it must be positive." )
1209- self . degrees = ( - degrees , degrees )
1212+ degrees = [ - degrees , degrees ]
12101213 else :
1211- assert isinstance (degrees , (tuple , list )) and len (degrees ) == 2 , \
1212- "degrees should be a list or tuple and it must be of length 2."
1213- self .degrees = degrees
1214+ if not isinstance (degrees , Sequence ):
1215+ raise TypeError ("degrees should be a sequence of length 2." )
1216+ if len (degrees ) != 2 :
1217+ raise ValueError ("degrees should be sequence of length 2." )
1218+
1219+ self .degrees = [float (d ) for d in degrees ]
12141220
12151221 if translate is not None :
1216- assert isinstance (translate , (tuple , list )) and len (translate ) == 2 , \
1217- "translate should be a list or tuple and it must be of length 2."
1222+ if not isinstance (translate , Sequence ):
1223+ raise TypeError ("translate should be a sequence of length 2." )
1224+ if len (translate ) != 2 :
1225+ raise ValueError ("translate should be sequence of length 2." )
12181226 for t in translate :
12191227 if not (0.0 <= t <= 1.0 ):
12201228 raise ValueError ("translation values should be between 0 and 1" )
12211229 self .translate = translate
12221230
12231231 if scale is not None :
1224- assert isinstance (scale , (tuple , list )) and len (scale ) == 2 , \
1225- "scale should be a list or tuple and it must be of length 2."
1232+ if not isinstance (scale , Sequence ):
1233+ raise TypeError ("scale should be a sequence of length 2." )
1234+ if len (scale ) != 2 :
1235+ raise ValueError ("scale should be sequence of length 2." )
1236+
12261237 for s in scale :
12271238 if s <= 0 :
12281239 raise ValueError ("scale values should be positive" )
@@ -1232,62 +1243,69 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Fal
12321243 if isinstance (shear , numbers .Number ):
12331244 if shear < 0 :
12341245 raise ValueError ("If shear is a single number, it must be positive." )
1235- self . shear = ( - shear , shear )
1246+ shear = [ - shear , shear ]
12361247 else :
1237- assert isinstance (shear , (tuple , list )) and \
1238- (len (shear ) == 2 or len (shear ) == 4 ), \
1239- "shear should be a list or tuple and it must be of length 2 or 4."
1240- # X-Axis shear with [min, max]
1241- if len (shear ) == 2 :
1242- self .shear = [shear [0 ], shear [1 ], 0. , 0. ]
1243- elif len (shear ) == 4 :
1244- self .shear = [s for s in shear ]
1248+ if not isinstance (shear , Sequence ):
1249+ raise TypeError ("shear should be a sequence of length 2 or 4." )
1250+ if len (shear ) not in (2 , 4 ):
1251+ raise ValueError ("shear should be sequence of length 2 or 4." )
1252+
1253+ self .shear = [float (s ) for s in shear ]
12451254 else :
12461255 self .shear = shear
12471256
12481257 self .resample = resample
12491258 self .fillcolor = fillcolor
12501259
12511260 @staticmethod
1252- def get_params (degrees , translate , scale_ranges , shears , img_size ):
1261+ def get_params (
1262+ degrees : List [float ],
1263+ translate : Optional [List [float ]],
1264+ scale_ranges : Optional [List [float ]],
1265+ shears : Optional [List [float ]],
1266+ img_size : List [int ]
1267+ ) -> Tuple [float , Tuple [int , int ], float , Tuple [float , float ]]:
12531268 """Get parameters for affine transformation
12541269
12551270 Returns:
1256- sequence: params to be passed to the affine transformation
1271+ params to be passed to the affine transformation
12571272 """
1258- angle = random . uniform ( degrees [0 ], degrees [1 ])
1273+ angle = float ( torch . empty ( 1 ). uniform_ ( float ( degrees [0 ]), float ( degrees [1 ])). item () )
12591274 if translate is not None :
1260- max_dx = translate [0 ] * img_size [0 ]
1261- max_dy = translate [1 ] * img_size [1 ]
1262- translations = (np .round (random .uniform (- max_dx , max_dx )),
1263- np .round (random .uniform (- max_dy , max_dy )))
1275+ max_dx = float (translate [0 ] * img_size [0 ])
1276+ max_dy = float (translate [1 ] * img_size [1 ])
1277+ tx = int (round (torch .empty (1 ).uniform_ (- max_dx , max_dx ).item ()))
1278+ ty = int (round (torch .empty (1 ).uniform_ (- max_dy , max_dy ).item ()))
1279+ translations = (tx , ty )
12641280 else :
12651281 translations = (0 , 0 )
12661282
12671283 if scale_ranges is not None :
1268- scale = random . uniform ( scale_ranges [0 ], scale_ranges [1 ])
1284+ scale = float ( torch . empty ( 1 ). uniform_ ( scale_ranges [0 ], scale_ranges [1 ]). item () )
12691285 else :
12701286 scale = 1.0
12711287
1288+ shear_x = shear_y = 0.0
12721289 if shears is not None :
1273- if len (shears ) == 2 :
1274- shear = [random .uniform (shears [0 ], shears [1 ]), 0. ]
1275- elif len (shears ) == 4 :
1276- shear = [random .uniform (shears [0 ], shears [1 ]),
1277- random .uniform (shears [2 ], shears [3 ])]
1278- else :
1279- shear = 0.0
1290+ shear_x = float (torch .empty (1 ).uniform_ (shears [0 ], shears [1 ]).item ())
1291+ if len (shears ) == 4 :
1292+ shear_y = float (torch .empty (1 ).uniform_ (shears [2 ], shears [3 ]).item ())
1293+
1294+ shear = (shear_x , shear_y )
12801295
12811296 return angle , translations , scale , shear
12821297
1283- def __call__ (self , img ):
1298+ def forward (self , img ):
12841299 """
1285- img (PIL Image): Image to be transformed.
1300+ img (PIL Image or Tensor ): Image to be transformed.
12861301
12871302 Returns:
1288- PIL Image: Affine transformed image.
1303+ PIL Image or Tensor : Affine transformed image.
12891304 """
1290- ret = self .get_params (self .degrees , self .translate , self .scale , self .shear , img .size )
1305+
1306+ img_size = F ._get_image_size (img )
1307+
1308+ ret = self .get_params (self .degrees , self .translate , self .scale , self .shear , img_size )
12911309 return F .affine (img , * ret , resample = self .resample , fillcolor = self .fillcolor )
12921310
12931311 def __repr__ (self ):
0 commit comments