test/test_prototype_transforms.py (+159 −0)
@@ -10,6 +10,7 @@
from test_prototype_transforms_functional import (
make_bounding_box,
make_bounding_boxes,
make_image,
make_images,
make_label,
make_one_hot_labels,
@@ -1328,3 +1329,161 @@ def test__transform(self, mocker):
transform(inpt_sentinel)

mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)


class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
image_size = (11, 5)

transform = transforms.FixedSizeCrop(size=crop_size)

sample = dict(
image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
bounding_boxes=make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
),
)
params = transform._get_params(sample)

assert params["needs_crop"]
assert params["height"] <= crop_size[0]
assert params["width"] <= crop_size[1]

assert (
isinstance(params["is_valid"], torch.Tensor)
and params["is_valid"].dtype is torch.bool
and params["is_valid"].shape == batch_shape
)

assert params["needs_pad"]
assert any(pad > 0 for pad in params["padding"])

@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock()
padding_mode_sentinel = mocker.MagicMock()

transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

needs_crop, needs_pad = needs
top_sentinel = mocker.MagicMock()
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=needs_crop,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
padding=padding_sentinel,
needs_pad=needs_pad,
),
)

inpt_sentinel = mocker.MagicMock()

mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
transform(inpt_sentinel)

if needs_crop:
mock_crop.assert_called_once_with(
inpt_sentinel,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
)
else:
mock_crop.assert_not_called()

if needs_pad:
# If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
# `MagicMock.assert_called_once_with` and have to perform the checks manually
mock_pad.assert_called_once()
args, kwargs = mock_pad.call_args
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()

def test__transform_culling(self, mocker):
batch_size = 10
image_size = (10, 10)

is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=is_valid,
needs_pad=False,
),
)

bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

output = transform(
dict(
bounding_boxes=bounding_boxes,
segmentation_masks=segmentation_masks,
labels=labels,
)
)

assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
assert_equal(output["labels"], labels[is_valid])

def test__transform_bounding_box_clamping(self, mocker):
batch_size = 3
image_size = (10, 10)

mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)

bounding_box = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")

transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

transform(bounding_box)

mock.assert_called_once()
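
For reference, the (11, 5) image against the (7, 7) crop in test__get_params above is chosen to hit both the crop and the pad branch at once. Below is a minimal sketch of that parameter arithmetic in plain Python — standalone and independent of the prototype API; the variable names are illustrative, not part of the test suite:

    crop_height, crop_width = 7, 7
    height, width = 11, 5  # image taller than the crop, but narrower

    new_height = min(height, crop_height)  # min(11, 7) == 7 -> cropped along the height
    new_width = min(width, crop_width)     # min(5, 7) == 5  -> image is already narrower
    needs_crop = new_height != height or new_width != width  # True, since 7 != 11

    offset_height = max(height - crop_height, 0)  # 4 -> random `top` lands in [0, 4]
    offset_width = max(width - crop_width, 0)     # 0 -> `left` is always 0

    pad_bottom = max(crop_height - new_height, 0)  # 0
    pad_right = max(crop_width - new_width, 0)     # 2
    needs_pad = pad_bottom != 0 or pad_right != 0  # True

    assert needs_crop and needs_pad
    assert [0, 0, pad_right, pad_bottom] == [0, 0, 2, 0]  # [left, top, right, bottom]

Note that the padding only ever grows the right/bottom edges, since the crop window is always anchored inside the image.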
torchvision/prototype/transforms/__init__.py (+1 −0)
@@ -20,6 +20,7 @@
CenterCrop,
ElasticTransform,
FiveCrop,
FixedSizeCrop,
Pad,
RandomAffine,
RandomCrop,
torchvision/prototype/transforms/_geometry.py (+97 −0)
@@ -783,3 +783,100 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation)


class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
        self.fill = fill  # TODO: `fill` is currently only respected for PIL images; add support for tensor images.
self.padding_mode = padding_mode

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, height, width = get_image_dimensions(image)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)

needs_crop = new_height != height or new_width != width

offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)

r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)

        if needs_crop:
            bounding_boxes = query_bounding_box(sample)
            # Compute box validity against the same window that `_transform` will crop,
            # i.e. the clipped `new_height` / `new_width`, not the full image size.
            bounding_boxes = cast(
                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
            )
bounding_boxes = features.BoundingBox.new_like(
bounding_boxes,
F.clamp_bounding_box(
bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
),
)
height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None

pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)

needs_pad = pad_bottom != 0 or pad_right != 0

return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = F.crop(
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.new_like(
inpt,
F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
)

if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)

return inpt

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
)
return super().forward(sample)
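
Not part of the diff, but for context: a rough end-to-end usage sketch of the new transform, assuming the prototype feature wrappers accept raw tensors as in the tests above (the input shapes and size=(224, 224) are illustrative, not from this PR):

    import torch
    from torchvision.prototype import features, transforms

    transform = transforms.FixedSizeCrop(size=(224, 224), fill=0, padding_mode="constant")

    # forward() rejects samples that lack an image, bounding boxes, or (one-hot) labels.
    sample = dict(
        image=features.Image(torch.randint(0, 256, (3, 300, 200), dtype=torch.uint8)),
        bounding_boxes=features.BoundingBox(
            torch.tensor([[10.0, 10.0, 150.0, 180.0]]),
            format=features.BoundingBoxFormat.XYXY,
            image_size=(300, 200),
        ),
        labels=features.Label(torch.tensor([1])),
    )

    output = transform(sample)
    # The 300x200 image is cropped to 224x200, then padded on the right to exactly
    # 224x224; the boxes are shifted and clamped to the crop window, and degenerate
    # boxes (together with their labels) are culled via the `is_valid` mask.

Since _get_params runs once per sample, the same crop window and is_valid mask are applied consistently to the image, the boxes, any segmentation masks, and the labels.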