Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@ def rotate_segmentation_mask():
)


@register_kernel_info_from_sample_inputs_fn
def crop_bounding_box():
for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
yield SampleInput(
bounding_box,
format=bounding_box.format,
top=top,
left=left,
)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -808,3 +819,44 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width, expected_bboxes",
[
[8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
[-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
],
)
def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):

# Expected bboxes computed using Albumentations:
# import numpy as np
# from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
# expected_bboxes = []
# for in_box in in_boxes:
# n_in_box = normalize_bbox(in_box, *size)
# n_out_box = crop_bbox_by_coords(
# n_in_box, (left, top, left + width, top + height), height, width, *size
# )
# out_box = denormalize_bbox(n_out_box, height, width)
# expected_bboxes.append(out_box)

size = (64, 76)
# xyxy format
in_boxes = [
[10.0, 15.0, 25.0, 35.0],
[50.0, 5.0, 70.0, 22.0],
[45.0, 46.0, 56.0, 62.0],
]
in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device)

output_boxes = F.crop_bounding_box(
in_boxes,
in_boxes.format,
top,
left,
)

torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@
rotate_image_tensor,
rotate_image_pil,
rotate_segmentation_mask,
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_bounding_box,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
perspective_image_tensor,
Expand Down
21 changes: 21 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,27 @@ def pad_bounding_box(
crop_image_pil = _FP.crop


def crop_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
top: int,
left: int,
) -> torch.Tensor:
shape = bounding_box.shape

bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

# Crop or implicit pad if left and/or top have negative values:
bounding_box[:, 0::2] -= left
Copy link

@vadimkantorov vadimkantorov Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For batch support it may be better to instead use bounding_box[..., 0::2]. Then view(-1, 4) and view(shape) would not be needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I was under the impression ... was not supported by torch.jit.script in general, but this seems fine:

@torch.jit.script
def foo(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()
    y[..., 0::2] += 1
    return y

It seems, only the explicit indices are not supported:

@torch.jit.script
def bar(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()
    y[..., [0, 2]] += 1
    return y
RuntimeError: 
Ellipses followed by tensor indexing is currently not supported:
[...]
def bar(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()
    y[..., [0, 2]] += 1
    ~~~~~~~~~~~~~ <--- HERE
    return y

bounding_box[:, 1::2] -= top

return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)


def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
Expand Down