Commit 4383337

Add support of inplace on convert_format_bounding_box
1 parent 6af796a commit 4383337

File tree: 5 files changed, +32 -36 lines changed

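The commit threads a new inplace flag (default False) through convert_format_bounding_box and down into its private per-format helpers, so call sites that already hold a freshly allocated tensor can skip one clone per conversion. A minimal sketch of the resulting call pattern, assuming the prototype API as it stands at this commit (box values are illustrative):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

xywh = torch.tensor([[10.0, 20.0, 5.0, 5.0]])  # one box in XYWH format

# Default behavior is unchanged: the input tensor is left untouched.
xyxy = F.convert_format_bounding_box(
    xywh, old_format=features.BoundingBoxFormat.XYWH, new_format=features.BoundingBoxFormat.XYXY
)

# With inplace=True the kernel may write into (and return) the input tensor.
xyxy = F.convert_format_bounding_box(
    xywh, old_format=features.BoundingBoxFormat.XYWH, new_format=features.BoundingBoxFormat.XYXY, inplace=True
)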

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ def _copy_paste(
         # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
         xyxy_boxes[:, 2:] += 1
         boxes = F.convert_format_bounding_box(
-            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
+            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
         )
         out_target["boxes"] = torch.cat([boxes, paste_boxes])
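
Here xyxy_boxes is mutated just above (xyxy_boxes[:, 2:] += 1), so it is already a private working copy; converting it back to bbox_format in place saves one allocation without changing the visible behavior of _copy_paste.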

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 1 deletion
@@ -618,7 +618,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_h, orig_w = query_spatial_size(flat_inputs)
-        bboxes = query_bounding_box(flat_inputs)
+        bboxes = query_bounding_box(flat_inputs).as_subclass(torch.Tensor)

        while True:
            # sample an option

@@ -798,6 +798,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

        if needs_crop and bounding_boxes is not None:
            format = bounding_boxes.format
+            bounding_boxes = bounding_boxes.as_subclass(torch.Tensor)
            bounding_boxes, spatial_size = F.crop_bounding_box(
                bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
            )
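
Both hunks unwrap the BoundingBox feature to a plain torch.Tensor before handing it to code that now operates in place, so the subsequent tensor ops skip the feature subclass dispatch and _get_params works on raw coordinates only. A sketch of what the unwrapping does, assuming the prototype BoundingBox constructor of this era:

import torch
from torchvision.prototype import features

bbox = features.BoundingBox(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)
plain = bbox.as_subclass(torch.Tensor)  # same storage, plain Tensor type
assert plain.data_ptr() == bbox.data_ptr()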

torchvision/prototype/transforms/_misc.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def __init__(self, min_size: float = 1.0) -> None:
        self.min_size = min_size

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(flat_inputs)
+        bounding_box = query_bounding_box(flat_inputs).as_subclass(torch.Tensor)

        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
        # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
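
Same unwrapping as in _geometry.py: this _get_params only inspects raw box coordinates, so it drops to a plain torch.Tensor immediately after query_bounding_box instead of carrying the feature wrapper through the follow-up computations.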

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 17 additions & 21 deletions
@@ -38,16 +38,14 @@ def horizontal_flip_bounding_box(

    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
    ).reshape(-1, 4)

    bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

    return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(shape)

@@ -79,16 +77,14 @@ def vertical_flip_bounding_box(

    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
    ).reshape(-1, 4)

    bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

    return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(shape)

@@ -412,7 +408,7 @@ def affine_bounding_box(
    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

@@ -594,9 +590,9 @@ def rotate_bounding_box(
    )

    return (
-        convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
-            original_shape
-        ),
+        convert_format_bounding_box(
+            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        ).reshape(original_shape),
        spatial_size,
    )

@@ -815,18 +811,18 @@ def crop_bounding_box(
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = (
-        bounding_box.clone()
-        if format == features.BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+    bounding_box = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
    )

    # Crop or implicit pad if left and/or top have negative values:
    bounding_box[..., 0::2] -= left
    bounding_box[..., 1::2] -= top

    return (
-        convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
+        convert_format_bounding_box(
+            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        ),
        (height, width),
    )

@@ -964,7 +960,7 @@ def perspective_bounding_box(
    # out_bboxes should be of shape [N boxes, 4]

    return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

@@ -1085,7 +1081,7 @@ def elastic_bounding_box(
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

    return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
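
Every kernel above previously special-cased XYXY input: clone when the format was already XYXY, otherwise call the converting kernel, which cloned internally. With the new flag the pattern collapses to one shape: clone exactly once, then convert in place. That still costs a single allocation in every branch, because convert_format_bounding_box returns its input untouched when the old and new formats match. A condensed sketch of the pattern, using a hypothetical helper name that is not part of the commit:

def _to_xyxy_working_copy(bounding_box, format):
    # One explicit clone; the conversion then reuses that buffer and is a
    # no-op when the boxes are already XYXY.
    return convert_format_bounding_box(
        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
    )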

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 11 additions & 12 deletions
@@ -119,8 +119,8 @@ def get_num_frames(inpt: features.VideoTypeJIT) -> int:
    raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")


-def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
-    xyxy = xywh.clone()
+def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
+    xyxy = xywh if inplace else xywh.clone()
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy

@@ -150,20 +151,20 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:


 def convert_format_bounding_box(
-    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
    if new_format == old_format:
        return bounding_box

    if old_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xywh_to_xyxy(bounding_box)
+        bounding_box = _xywh_to_xyxy(bounding_box, inplace)
    elif old_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _cxcywh_to_xyxy(bounding_box)
+        bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)

    if new_format == BoundingBoxFormat.XYWH:
-        bounding_box = _xyxy_to_xywh(bounding_box)
+        bounding_box = _xyxy_to_xywh(bounding_box, inplace)
    elif new_format == BoundingBoxFormat.CXCYWH:
-        bounding_box = _xyxy_to_cxcywh(bounding_box)
+        bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)

    return bounding_box

@@ -173,14 +174,12 @@ def clamp_bounding_box(
 ) -> torch.Tensor:
    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
    # BoundingBoxFormat instead of converting back and forth
-    xyxy_boxes = (
-        bounding_box.clone()
-        if format == BoundingBoxFormat.XYXY
-        else convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes = convert_format_bounding_box(
+        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
    )
    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
+    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


 def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
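
Per the new signature, inplace defaults to False, so existing callers keep copy semantics. A quick illustration of the in-place path (box values made up; imports follow the prototype layout used in this diff):

import torch
from torchvision.prototype.features import BoundingBoxFormat
from torchvision.prototype.transforms.functional import convert_format_bounding_box

xywh = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
out = convert_format_bounding_box(
    xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY, inplace=True
)
assert out is xywh  # converted in place, no copy made
assert out.tolist() == [[1.0, 2.0, 4.0, 6.0]]  # x2 = x1 + w, y2 = y1 + h

Note that when old_format == new_format the function returns its input unchanged regardless of inplace, so callers that need a private copy must clone first, exactly as the kernels above do.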
