Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from torch import jit
from torch.nn.functional import one_hot
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.functional_tensor import _max_value as get_max_value


make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")


Expand Down Expand Up @@ -421,6 +421,14 @@ def center_crop_bounding_box():
)


def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -1337,3 +1345,26 @@ def _compute_expected_bbox(bbox, output_size_):
else:
expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_segmentation_mask(device, output_size):
def _compute_expected_segmentation_mask(mask, output_size):
crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

_, image_height, image_width = mask.shape
if crop_width > image_height or crop_height > image_width:
padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
mask = F.pad_image_tensor(mask, padding, fill=0)

left = round((image_width - crop_width) * 0.5)
top = round((image_height - crop_height) * 0.5)
Comment on lines +1361 to +1362
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kernel has an additional int call:

crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))

@vfdev-5 I recall we had issues with round before. Should we just switch to int in general?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier I do not quite remember what was the issue with round (maybe jit behaves differently to eager mode). For me the code you mention is more like a definition of crop_top and crop_left. For example, for bboxes we could also keep these values as float but let's define that crop_top/left are rounded integers.


return mask[:, top : top + crop_height, left : left + crop_width]

mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
actual = F.center_crop_segmentation_mask(mask, output_size)

expected = _compute_expected_segmentation_mask(mask, output_size)
torch.testing.assert_close(expected, actual)
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
resize_image_pil,
resize_segmentation_mask,
center_crop_bounding_box,
center_crop_segmentation_mask,
center_crop_image_tensor,
center_crop_image_pil,
resized_crop_bounding_box,
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,10 @@ def center_crop_bounding_box(
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)


def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)


def resized_crop_image_tensor(
img: torch.Tensor,
top: int,
Expand Down