diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d4f1fadb0bf..65673203941 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -401,18 +401,16 @@ def pad_bounding_box( ) -> torch.Tensor: left, _, top, _ = _FT._parse_pad_padding(padding) - shape = bounding_box.shape - bounding_box = convert_bounding_box_format( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) + ) - bounding_box[:, 0::2] += left - bounding_box[:, 1::2] += top + bounding_box[..., 0::2] += left + bounding_box[..., 1::2] += top return convert_bounding_box_format( bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ).view(shape) + ) crop_image_tensor = _FT.crop @@ -425,19 +423,17 @@ def crop_bounding_box( top: int, left: int, ) -> torch.Tensor: - shape = bounding_box.shape - bounding_box = convert_bounding_box_format( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) + ) # Crop or implicit pad if left and/or top have negative values: - bounding_box[:, 0::2] -= left - bounding_box[:, 1::2] -= top + bounding_box[..., 0::2] -= left + bounding_box[..., 1::2] -= top return convert_bounding_box_format( bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ).view(shape) + ) def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: