@@ -650,6 +650,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
650650 # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
651651
652652 # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
653+ # Points are shifted due to affine matrix torch convention about
654+ # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
653655 pts = torch .tensor (
654656 [
655657 [- 0.5 * w , - 0.5 * h , 1.0 ],
@@ -658,11 +660,15 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
658660 [0.5 * w , - 0.5 * h , 1.0 ],
659661 ]
660662 )
661- theta = torch .tensor (matrix , dtype = torch .float ).reshape ( 1 , 2 , 3 )
662- new_pts = pts . view ( 1 , 4 , 3 ). bmm ( theta .transpose ( 1 , 2 )). view ( 4 , 2 )
663+ theta = torch .tensor (matrix , dtype = torch .float ).view ( 2 , 3 )
664+ new_pts = torch . matmul ( pts , theta .T )
663665 min_vals , _ = new_pts .min (dim = 0 )
664666 max_vals , _ = new_pts .max (dim = 0 )
665667
668+ # shift points to [0, w] and [0, h] interval to match PIL results
669+ min_vals += torch .tensor ((w * 0.5 , h * 0.5 ))
670+ max_vals += torch .tensor ((w * 0.5 , h * 0.5 ))
671+
666672 # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
667673 tol = 1e-4
668674 cmax = torch .ceil ((max_vals / tol ).trunc_ () * tol )
0 commit comments