diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index b253b7cbf2c..cf861c46d24 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -262,7 +262,7 @@ def _copy_paste(
     # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
     xyxy_boxes[:, 2:] += 1
     boxes = F.convert_format_bounding_box(
-        xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
+        xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
     )
     out_target["boxes"] = torch.cat([boxes, paste_boxes])
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index c5ab38d8418..214296d0350 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -646,7 +646,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
                     continue

                 # check for any valid boxes with centers within the crop area
-                xyxy_bboxes = F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
+                xyxy_bboxes = F.convert_format_bounding_box(
+                    bboxes.as_subclass(torch.Tensor), bboxes.format, features.BoundingBoxFormat.XYXY
+                )
                 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                 is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
@@ -799,7 +801,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         if needs_crop and bounding_boxes is not None:
             format = bounding_boxes.format
             bounding_boxes, spatial_size = F.crop_bounding_box(
-                bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
+                bounding_boxes.as_subclass(torch.Tensor),
+                format=format,
+                top=top,
+                left=left,
+                height=new_height,
+                width=new_width,
             )
             bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
             height_and_width = F.convert_format_bounding_box(
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index aad684bf1a8..1d4b0f6fa1d 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -207,7 +207,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         # format,we need to convert first just to afterwards compute the width and height again, although they were
         # there in the first place for these formats.
         bounding_box = F.convert_format_bounding_box(
-            bounding_box, old_format=bounding_box.format, new_format=features.BoundingBoxFormat.XYXY
+            bounding_box.as_subclass(torch.Tensor),
+            old_format=bounding_box.format,
+            new_format=features.BoundingBoxFormat.XYXY,
         )
         valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 7f709b73b4b..5b71c79d34a 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -38,16 +38,14 @@ def horizontal_flip_bounding_box(
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     #  BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
     ).reshape(-1, 4)

     bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(shape)
@@ -79,16 +77,14 @@ def vertical_flip_bounding_box(
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     #  BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
     ).reshape(-1, 4)

     bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(shape)
@@ -412,7 +408,7 @@ def affine_bounding_box(
     # out_bboxes should be of shape [N boxes, 4]
     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
@@ -594,9 +590,9 @@ def rotate_bounding_box(
     )

     return (
-        convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
-            original_shape
-        ),
+        convert_format_bounding_box(
+            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        ).reshape(original_shape),
         spatial_size,
     )
@@ -815,10 +811,8 @@ def crop_bounding_box(
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     #  BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
     )

     # Crop or implicit pad if left and/or top have negative values:
@@ -826,7 +820,9 @@
     bounding_box[..., 1::2] -= top

     return (
-        convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
+        convert_format_bounding_box(
+            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        ),
         (height, width),
     )
@@ -964,7 +960,7 @@ def perspective_bounding_box(
     # out_bboxes should be of shape [N boxes, 4]
     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
@@ -1085,7 +1081,7 @@ def elastic_bounding_box(
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 5e017848415..81ccd08de5d 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -119,51 +119,60 @@ def get_num_frames(inpt: features.VideoTypeJIT) -> int:
     raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")


-def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
-    xyxy = xywh.clone()
+def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
+    xyxy = xywh if inplace else xywh.clone()
     xyxy[..., 2:] += xyxy[..., :2]
     return xyxy


-def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
-    xywh = xyxy.clone()
+def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
+    xywh = xyxy if inplace else xyxy.clone()
     xywh[..., 2:] -= xywh[..., :2]
     return xywh


-def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
-    cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
-    x1 = cx - 0.5 * w
-    y1 = cy - 0.5 * h
-    x2 = cx + 0.5 * w
-    y2 = cy + 0.5 * h
-    return torch.stack((x1, y1, x2, y2), dim=-1).to(cxcywh.dtype)
+def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
+    if not inplace:
+        cxcywh = cxcywh.clone()
+
+    # Trick to do fast division by 2 and ceil, without casting. It produces the same result as
+    # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
+    half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
+    # (cx - width / 2) = x1, same for y1
+    cxcywh[..., :2].sub_(half_wh)
+    # (x1 + width) = x2, same for y2
+    cxcywh[..., 2:].add_(cxcywh[..., :2])
+
+    return cxcywh


-def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
-    x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
-    cx = (x1 + x2) / 2
-    cy = (y1 + y2) / 2
-    w = x2 - x1
-    h = y2 - y1
-    return torch.stack((cx, cy, w, h), dim=-1).to(xyxy.dtype)
+def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
+    if not inplace:
+        xyxy = xyxy.clone()
+
+    # (x2 - x1) = width, same for height
+    xyxy[..., 2:].sub_(xyxy[..., :2])
+    # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
+    xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
+
+    return xyxy


 def convert_format_bounding_box(
-    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
     if new_format == old_format:
         return bounding_box

     if old_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xywh_to_xyxy(bounding_box)
+        bounding_box = _xywh_to_xyxy(bounding_box, inplace)
     elif old_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _cxcywh_to_xyxy(bounding_box)
+        bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)

     if new_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xyxy_to_xywh(bounding_box)
+        bounding_box = _xyxy_to_xywh(bounding_box, inplace)
     elif new_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _xyxy_to_cxcywh(bounding_box)
+        bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)

     return bounding_box

@@ -173,14 +182,12 @@ def clamp_bounding_box(
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     #  BoundingBoxFormat instead of converting back and forth
-    xyxy_boxes = (
-        bounding_box.clone()
-        if format == BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
+    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


 def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
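For reviewers, a minimal usage sketch of the new `inplace` flag (not part of the diff; it assumes the prototype namespaces used above, `torchvision.prototype.features` and `torchvision.prototype.transforms.functional`, are importable under these names in your checkout). The default `inplace=False` keeps the old copying behavior; `inplace=True` skips the defensive clone and overwrites the input buffer, which is only safe on tensors the caller owns, e.g. right after `.clone()` as in the call sites above. The integer expectations follow the negative-floor-division trick in `_cxcywh_to_xyxy`.

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

BoundingBoxFormat = features.BoundingBoxFormat

# An integer CXCYWH box: center (4, 4), width 5, height 3.
cxcywh = torch.tensor([[4, 4, 5, 3]])

# Default (inplace=False): the helper clones internally, the input is untouched.
xyxy = F.convert_format_bounding_box(
    cxcywh, old_format=BoundingBoxFormat.CXCYWH, new_format=BoundingBoxFormat.XYXY
)
assert torch.equal(cxcywh, torch.tensor([[4, 4, 5, 3]]))
# ceil(5 / 2) == 3 and ceil(3 / 2) == 2, so x1 == 4 - 3 == 1 and y1 == 4 - 2 == 2,
# matching `torchvision.ops._box_convert._box_cxcywh_to_xyxy` after the dtype cast.
assert torch.equal(xyxy, torch.tensor([[1, 2, 6, 5]]))

# inplace=True: no defensive clone, the input storage is reused. Only do this on a
# tensor you own, e.g. a fresh clone, exactly as the updated call sites do.
scratch = cxcywh.clone()
out = F.convert_format_bounding_box(
    scratch, old_format=BoundingBoxFormat.CXCYWH, new_format=BoundingBoxFormat.XYXY, inplace=True
)
assert out.data_ptr() == scratch.data_ptr()  # same storage, no extra allocation
```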