55import torch
66from torchvision .prototype import features
77from torchvision .prototype .transforms import InterpolationMode
8- from torchvision .prototype .transforms .functional import get_image_dims
98from torchvision .transforms import functional_tensor as _FT , functional_pil as _FP
109from torchvision .transforms .functional import pil_modes_mapping , _get_inverse_affine_matrix
1110
@@ -40,7 +39,7 @@ def resize_image_tensor(
4039 antialias : Optional [bool ] = None ,
4140) -> torch .Tensor :
4241 new_height , new_width = size
43- num_channels , old_height , old_width = image . shape [ - 3 :]
42+ num_channels , old_height , old_width = _FT . get_image_dims ( image )
4443 batch_shape = image .shape [:- 3 ]
4544 return _FT .resize (
4645 image .reshape ((- 1 , num_channels , old_height , old_width )),
@@ -142,7 +141,7 @@ def affine_image_tensor(
142141
143142 center_f = [0.0 , 0.0 ]
144143 if center is not None :
145- _ , height , width = get_image_dims (img )
144+ _ , height , width = _FT . get_image_dims (img )
146145 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
147146 center_f = [1.0 * (c - s * 0.5 ) for c , s in zip (center , (width , height ))]
148147
@@ -168,7 +167,7 @@ def affine_image_pil(
168167 # it is visually better to estimate the center without 0.5 offset
169168 # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
170169 if center is None :
171- _ , height , width = get_image_dims (img )
170+ _ , height , width = _FP . get_image_dims (img )
172171 center = [width * 0.5 , height * 0.5 ]
173172 matrix = _get_inverse_affine_matrix (center , angle , translate , scale , shear )
174173
@@ -185,7 +184,7 @@ def rotate_image_tensor(
185184) -> torch .Tensor :
186185 center_f = [0.0 , 0.0 ]
187186 if center is not None :
188- _ , height , width = get_image_dims (img )
187+ _ , height , width = _FT . get_image_dims (img )
189188 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
190189 center_f = [1.0 * (c - s * 0.5 ) for c , s in zip (center , (width , height ))]
191190
@@ -261,13 +260,13 @@ def _center_crop_compute_crop_anchor(
261260
262261def center_crop_image_tensor (img : torch .Tensor , output_size : List [int ]) -> torch .Tensor :
263262 crop_height , crop_width = _center_crop_parse_output_size (output_size )
264- _ , image_height , image_width = get_image_dims (img )
263+ _ , image_height , image_width = _FT . get_image_dims (img )
265264
266265 if crop_height > image_height or crop_width > image_width :
267266 padding_ltrb = _center_crop_compute_padding (crop_height , crop_width , image_height , image_width )
268267 img = pad_image_tensor (img , padding_ltrb , fill = 0 )
269268
270- _ , image_height , image_width = get_image_dims (img )
269+ _ , image_height , image_width = _FT . get_image_dims (img )
271270 if crop_width == image_width and crop_height == image_height :
272271 return img
273272
@@ -277,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
277276
278277def center_crop_image_pil (img : PIL .Image .Image , output_size : List [int ]) -> PIL .Image .Image :
279278 crop_height , crop_width = _center_crop_parse_output_size (output_size )
280- _ , image_height , image_width = get_image_dims (img )
279+ _ , image_height , image_width = _FP . get_image_dims (img )
281280
282281 if crop_height > image_height or crop_width > image_width :
283282 padding_ltrb = _center_crop_compute_padding (crop_height , crop_width , image_height , image_width )
284283 img = pad_image_pil (img , padding_ltrb , fill = 0 )
285284
286- _ , image_height , image_width = get_image_dims (img )
285+ _ , image_height , image_width = _FP . get_image_dims (img )
287286 if crop_width == image_width and crop_height == image_height :
288287 return img
289288
0 commit comments