@@ -900,6 +900,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
900900def _get_inverse_affine_matrix (
901901 center : List [float ], angle : float , translate : List [float ], scale : float , shear : List [float ]
902902) -> List [float ]:
903+ # TODO: REMOVE THIS METHOD IN FAVOR OF _get_inverse_affine_matrix_tensor
903904 # Helper method to compute inverse matrix for affine transformation
904905
905906 # As it is explained in PIL.Image.rotate
@@ -1056,28 +1057,38 @@ def rotate(
10561057 pil_interpolation = pil_modes_mapping [interpolation ]
10571058 return F_pil .rotate (img , angle = angle , interpolation = pil_interpolation , expand = expand , center = center , fill = fill )
10581059
1059- if isinstance (angle , torch .Tensor ) and angle .requires_grad :
1060- # assert img.dtype is float
1061- pass
1062-
1063- center_t = torch .tensor ([0.0 , 0.0 ])
1064- if center is not None :
1065- # ct = torch.tensor([float(c) for c in list(center)]) if not isinstance(center, Tensor) else center
1066- # THIS DOES NOT PASS JIT as we mix list/tuple of ints but list/tuple of floats are required
1067- ct = torch .tensor (center ) if not isinstance (center , Tensor ) else center
1068- img_size = torch .tensor (get_image_size (img ))
1060+ # TODO: This is a rather generic check for input dtype if args are learnable
1061+ # We can refactor that later
1062+ if not torch .jit .is_scripting ():
1063+ # torch.jit.script crashes with Segmentation fault (core dumped) on the following
1064+ # without if not torch.jit.is_scripting()
1065+ if (isinstance (angle , torch .Tensor ) and angle .requires_grad ) or (
1066+ isinstance (center , torch .Tensor ) and center .requires_grad
1067+ ):
1068+ if not img .is_floating_point ():
1069+ raise ValueError ("If angle is tensor that requires grad, image should be float" )
1070+
1071+ do_recenter = True
1072+ if center is None :
1073+ center = torch .tensor ([0.0 , 0.0 ])
1074+ do_recenter = False
1075+
1076+ if isinstance (center , tuple ):
1077+ center = list (center )
1078+
1079+ if isinstance (center , list ):
1080+ center = torch .tensor ([float (center [0 ]), float (center [1 ])])
1081+
1082+ if do_recenter :
1083+ img_size = torch .tensor (get_image_size (img ), dtype = torch .float )
10691084 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1070- center_t = 1.0 * ( ct - img_size * 0.5 )
1085+ center = center - img_size * 0.5
10711086
10721087 # due to current incoherence of rotation angle direction between affine and rotate implementations
10731088 # we need to set -angle.
1074- angle_t = torch .tensor (float (angle )) if not isinstance (angle , Tensor ) else angle
1089+ angle = torch .tensor (float (angle )) if not isinstance (angle , Tensor ) else angle
10751090 matrix = _get_inverse_affine_matrix_tensor (
1076- center_t ,
1077- - angle_t ,
1078- torch .tensor ([0.0 , 0.0 ]),
1079- torch .tensor (1.0 ),
1080- torch .tensor ([0.0 , 0.0 ])
1091+ center , - angle , torch .tensor ([0.0 , 0.0 ]), torch .tensor (1.0 ), torch .tensor ([0.0 , 0.0 ])
10811092 )
10821093 return F_t .rotate (img , matrix = matrix , interpolation = interpolation .value , expand = expand , fill = fill )
10831094
0 commit comments