@@ -390,7 +390,7 @@ def _affine_bounding_box_xyxy(
         device=device,
     )
     new_points = torch.matmul(points, transposed_affine_matrix)
-    tr, _ = torch.min(new_points, dim=0, keepdim=True)
+    tr = torch.amin(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
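
Context for the hunk above: `torch.min(t, dim=...)` returns a `(values, indices)` named tuple, so the old code computed argmin indices only to discard them, while `torch.amin` returns the values alone. A minimal sketch of the equivalence (the sample tensor is illustrative, not from the patch):

import torch

points = torch.tensor([[3.0, 1.0], [2.0, 4.0]])
values, _ = torch.min(points, dim=0, keepdim=True)      # old pattern: values plus unused indices
amin_values = torch.amin(points, dim=0, keepdim=True)   # new pattern: values only
assert torch.equal(values, amin_values)  # tensor([[2., 1.]])
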
@@ -701,7 +701,7 @@ def pad_image_tensor(
     # internally.
     torch_padding = _parse_pad_padding(padding)

-    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
         raise ValueError(
             f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
             f"but got `'{padding_mode}'`."
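
The hunk above only swaps the list literal for a tuple in the membership test; the accepted modes and the error are unchanged. A small standalone sketch of the validation behavior (the helper name is hypothetical, not from the library):

def check_padding_mode(padding_mode: str) -> None:
    # Same check as the patched code; CPython can fold a tuple literal
    # into a single constant, unlike a list literal.
    if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
        raise ValueError(f"unsupported padding_mode '{padding_mode}'")

check_padding_mode("reflect")  # ok
# check_padding_mode("wrap")   # raises ValueError
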
@@ -917,17 +917,17 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-
+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
     theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

     d = 0.5
     base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
-    x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
+    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device)
     base_grid[..., 0].copy_(x_grid)
-    y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
+    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
     base_grid[..., 1].copy_(y_grid)
     base_grid[..., 2].fill_(1)
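
Since `d = 0.5` is a float, `ow * 1.0 + d - 1.0` and `ow + d - 1.0` evaluate to the same endpoint; the hunk above just drops the redundant `* 1.0` promotion. A quick sketch of the equivalence (the size is illustrative):

import torch

ow, d = 4, 0.5
old = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow)
new = torch.linspace(d, ow + d - 1.0, steps=ow)
assert torch.equal(old, new)  # tensor([0.5000, 1.5000, 2.5000, 3.5000]), i.e. pixel centers
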
@@ -1059,6 +1059,7 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]

+    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
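
On the `inv_coeffs` above: they appear to be the adjugate of the 3x3 homography [[c0, c1, c2], [c3, c4, c5], [c6, c7, 1]] divided by `denom`, i.e. the matrix inverse renormalized so its bottom-right entry is 1. A hedged numerical check with public ops (the coefficient values are made up):

import torch

coeffs = [1.2, 0.1, -3.0, 0.05, 0.9, 2.0, 1e-3, 2e-3]
h = torch.tensor(
    [[coeffs[0], coeffs[1], coeffs[2]],
     [coeffs[3], coeffs[4], coeffs[5]],
     [coeffs[6], coeffs[7], 1.0]],
    dtype=torch.float64,
)
h_inv = torch.linalg.inv(h)
h_inv = h_inv / h_inv[2, 2]                    # normalize so the bottom-right entry is 1
inv_coeffs = h_inv.flatten().tolist()[:8]      # the 8 coefficients of the inverse transform
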
@@ -1165,14 +1166,17 @@ def elastic_image_tensor(
         return image

     shape = image.shape
+    device = image.device

     if image.ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False

-    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+    image_height, image_width = shape[-2:]
+    grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
+    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)

     if needs_unsquash:
         output = output.reshape(shape)
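
The rewrite above inlines what `_FT.elastic_transform` did internally: build an identity sampling grid on the image's device, add the displacement field, and resample once. `_create_identity_grid` and `_FT._apply_grid_transform` are torchvision internals; here is a hedged standalone sketch of the same idea with public PyTorch ops (the shape and alignment conventions are assumptions):

import torch
import torch.nn.functional as F

def elastic_sketch(image: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
    # image: (N, C, H, W); displacement: (1, H, W, 2) in normalized grid units.
    n, _, h, w = image.shape
    # Identity grid at pixel centers in [-1, 1], matching align_corners=False.
    ys = torch.linspace((-h + 1) / h, (h - 1) / h, h, device=image.device)
    xs = torch.linspace((-w + 1) / w, (w - 1) / w, w, device=image.device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    identity = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # (1, H, W, 2), xy order
    grid = identity + displacement.to(image.device)
    return F.grid_sample(image, grid.expand(n, -1, -1, -1), align_corners=False)
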
@@ -1505,8 +1509,7 @@ def five_crop_image_tensor(
     image_height, image_width = image.shape[-2:]

     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

     tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
     tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
@@ -1525,8 +1528,7 @@ def five_crop_image_pil(
     image_height, image_width = get_spatial_size_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
-        msg = "Requested crop size {} is bigger than input size {}"
-        raise ValueError(msg.format(size, (image_height, image_width)))
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

     tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
     tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
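
Both `five_crop` hunks fold the two-step `str.format` pattern into a single f-string; the rendered message is identical. A quick sketch of that equivalence (the sizes are illustrative):

size = (224, 224)
image_height, image_width = 100, 100
old = "Requested crop size {} is bigger than input size {}".format(size, (image_height, image_width))
new = f"Requested crop size {size} is bigger than input size {(image_height, image_width)}"
assert old == new  # "Requested crop size (224, 224) is bigger than input size (100, 100)"
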