22import numbers
33import warnings
44from enum import Enum
5- from typing import List , Tuple , Any , Optional
5+ from typing import List , Tuple , Any , Optional , Union
66
77import numpy as np
88import torch
@@ -948,12 +948,48 @@ def _get_inverse_affine_matrix(
948948 return matrix
949949
950950
951+ def _get_inverse_affine_matrix_tensor (
952+ center : Tensor , angle : Tensor , translate : Tensor , scale : Tensor , shear : Tensor
953+ ) -> Tensor :
954+ output = torch .zeros (3 , 3 )
955+
956+ rot = angle * torch .pi / 180.0
957+ shear_rad = shear * torch .pi / 180.0
958+
959+ m_center = torch .eye (3 , 3 )
960+ m_center [:2 , 2 ] = center
961+
962+ i_m_center = torch .eye (3 , 3 )
963+ i_m_center [:2 , 2 ] = - center
964+
965+ i_m_translate = torch .eye (3 , 3 )
966+ i_m_translate [:2 , 2 ] = - translate
967+
968+ # RSS without scaling
969+ sx , sy = shear_rad [0 ], shear_rad [1 ]
970+ a = torch .cos (rot - sy ) / torch .cos (sy )
971+ b = torch .cos (rot - sy ) * torch .tan (sx ) / torch .cos (sy ) + torch .sin (rot )
972+ c = - torch .sin (rot - sy ) / torch .cos (sy )
973+ d = - torch .sin (rot - sy ) * torch .tan (sx ) / torch .cos (sy ) + torch .cos (rot )
974+
975+ output [0 , 0 ] = d
976+ output [0 , 1 ] = b
977+ output [1 , 0 ] = c
978+ output [1 , 1 ] = a
979+ output = output / scale
980+ output [2 , 2 ] = 1.0
981+
982+ output = torch .chain_matmul (m_center , output , i_m_center , i_m_translate )
983+ output = output [:2 , :]
984+ return output
985+
986+
951987def rotate (
952988 img : Tensor ,
953- angle : float ,
989+ angle : Union [ float , int , Tensor ] ,
954990 interpolation : InterpolationMode = InterpolationMode .NEAREST ,
955991 expand : bool = False ,
956- center : Optional [List [int ]] = None ,
992+ center : Optional [Union [ List [int ], Tuple [ int , int ], Tensor ]] = None ,
957993 fill : Optional [List [float ]] = None ,
958994 resample : Optional [int ] = None ,
959995) -> Tensor :
@@ -963,7 +999,7 @@ def rotate(
963999
9641000 Args:
9651001 img (PIL Image or Tensor): image to be rotated.
966- angle (number): rotation angle value in degrees, counter-clockwise.
1002+ angle (number or Tensor ): rotation angle value in degrees, counter-clockwise.
9671003 interpolation (InterpolationMode): Desired interpolation enum defined by
9681004 :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
9691005 If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
@@ -972,7 +1008,7 @@ def rotate(
9721008 If true, expands the output image to make it large enough to hold the entire rotated image.
9731009 If false or omitted, make the output image the same size as the input image.
9741010 Note that the expand flag assumes rotation around the center and no translation.
975- center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
1011+ center (sequence or Tensor , optional): Optional center of rotation. Origin is the upper left corner.
9761012 Default is the center of the image.
9771013 fill (sequence or number, optional): Pixel fill value for the area outside the transformed
9781014 image. If given a number, the value is used for all bands respectively.
@@ -1001,28 +1037,48 @@ def rotate(
10011037 )
10021038 interpolation = _interpolation_modes_from_int (interpolation )
10031039
1004- if not isinstance (angle , (int , float )):
1005- raise TypeError ("Argument angle should be int or float" )
1040+ if not isinstance (angle , (int , float , Tensor )):
1041+ raise TypeError ("Argument angle should be int or float or Tensor " )
10061042
1007- if center is not None and not isinstance (center , (list , tuple )):
1008- raise TypeError ("Argument center should be a sequence" )
1043+ if center is not None and not isinstance (center , (list , tuple , Tensor )):
1044+ raise TypeError ("Argument center should be a sequence or a Tensor " )
10091045
10101046 if not isinstance (interpolation , InterpolationMode ):
10111047 raise TypeError ("Argument interpolation should be a InterpolationMode" )
10121048
10131049 if not isinstance (img , torch .Tensor ):
1050+ if not isinstance (angle , (int , float )):
1051+ raise TypeError ("Argument angle should be int or float" )
1052+
1053+ if center is not None and not isinstance (center , (list , tuple )):
1054+ raise TypeError ("Argument center should be a sequence" )
1055+
10141056 pil_interpolation = pil_modes_mapping [interpolation ]
10151057 return F_pil .rotate (img , angle = angle , interpolation = pil_interpolation , expand = expand , center = center , fill = fill )
10161058
1017- center_f = [0.0 , 0.0 ]
1059+ if isinstance (angle , torch .Tensor ) and angle .requires_grad :
1060+ # assert img.dtype is float
1061+ pass
1062+
1063+ center_t = torch .tensor ([0.0 , 0.0 ])
10181064 if center is not None :
1019- img_size = get_image_size (img )
1065+ # ct = torch.tensor([float(c) for c in list(center)]) if not isinstance(center, Tensor) else center
1066+ # THIS DOES NOT PASS JIT as we mix list/tuple of ints but list/tuple of floats are required
1067+ ct = torch .tensor (center ) if not isinstance (center , Tensor ) else center
1068+ img_size = torch .tensor (get_image_size (img ))
10201069 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1021- center_f = [ 1.0 * (c - s * 0.5 ) for c , s in zip ( center , img_size )]
1070+ center_t = 1.0 * (ct - img_size * 0.5 )
10221071
10231072 # due to current incoherence of rotation angle direction between affine and rotate implementations
10241073 # we need to set -angle.
1025- matrix = _get_inverse_affine_matrix (center_f , - angle , [0.0 , 0.0 ], 1.0 , [0.0 , 0.0 ])
1074+ angle_t = torch .tensor (float (angle )) if not isinstance (angle , Tensor ) else angle
1075+ matrix = _get_inverse_affine_matrix_tensor (
1076+ center_t ,
1077+ - angle_t ,
1078+ torch .tensor ([0.0 , 0.0 ]),
1079+ torch .tensor (1.0 ),
1080+ torch .tensor ([0.0 , 0.0 ])
1081+ )
10261082 return F_t .rotate (img , matrix = matrix , interpolation = interpolation .value , expand = expand , fill = fill )
10271083
10281084
0 commit comments