Skip to content

Commit eb826c3

Browse files
committed
JIT should be happy now
1 parent 745ee27 commit eb826c3

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

torchvision/transforms/functional.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
900900
def _get_inverse_affine_matrix(
901901
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
902902
) -> List[float]:
903+
# TODO: REMOVE THIS METHOD IN FAVOR OF _get_inverse_affine_matrix_tensor
903904
# Helper method to compute inverse matrix for affine transformation
904905

905906
# As it is explained in PIL.Image.rotate
@@ -1056,28 +1057,38 @@ def rotate(
10561057
pil_interpolation = pil_modes_mapping[interpolation]
10571058
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
10581059

1059-
if isinstance(angle, torch.Tensor) and angle.requires_grad:
1060-
# assert img.dtype is float
1061-
pass
1062-
1063-
center_t = torch.tensor([0.0, 0.0])
1064-
if center is not None:
1065-
# ct = torch.tensor([float(c) for c in list(center)]) if not isinstance(center, Tensor) else center
1066-
# THIS DOES NOT PASS JIT as we mix list/tuple of ints but list/tuple of floats are required
1067-
ct = torch.tensor(center) if not isinstance(center, Tensor) else center
1068-
img_size = torch.tensor(get_image_size(img))
1060+
# TODO: This is a rather generic check for input dtype if args are learnable
1061+
# We can refactor that later
1062+
if not torch.jit.is_scripting():
1063+
# torch.jit.script crashes with Segmentation fault (core dumped) on the following
1064+
# without if not torch.jit.is_scripting()
1065+
if (isinstance(angle, torch.Tensor) and angle.requires_grad) or (
1066+
isinstance(center, torch.Tensor) and center.requires_grad
1067+
):
1068+
if not img.is_floating_point():
1069+
raise ValueError("If angle is tensor that requires grad, image should be float")
1070+
1071+
do_recenter = True
1072+
if center is None:
1073+
center = torch.tensor([0.0, 0.0])
1074+
do_recenter = False
1075+
1076+
if isinstance(center, tuple):
1077+
center = list(center)
1078+
1079+
if isinstance(center, list):
1080+
center = torch.tensor([float(center[0]), float(center[1])])
1081+
1082+
if do_recenter:
1083+
img_size = torch.tensor(get_image_size(img), dtype=torch.float)
10691084
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
1070-
center_t = 1.0 * (ct - img_size * 0.5)
1085+
center = center - img_size * 0.5
10711086

10721087
# due to current incoherence of rotation angle direction between affine and rotate implementations
10731088
# we need to set -angle.
1074-
angle_t = torch.tensor(float(angle)) if not isinstance(angle, Tensor) else angle
1089+
angle = torch.tensor(float(angle)) if not isinstance(angle, Tensor) else angle
10751090
matrix = _get_inverse_affine_matrix_tensor(
1076-
center_t,
1077-
-angle_t,
1078-
torch.tensor([0.0, 0.0]),
1079-
torch.tensor(1.0),
1080-
torch.tensor([0.0, 0.0])
1091+
center, -angle, torch.tensor([0.0, 0.0]), torch.tensor(1.0), torch.tensor([0.0, 0.0])
10811092
)
10821093
return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
10831094

0 commit comments

Comments
 (0)