Skip to content
28 changes: 21 additions & 7 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,18 @@ def elastic_image_pil(
return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device) -> torch.Tensor:
sy, sx = size
base_grid = torch.empty(1, sy, sx, 2, device=device)
x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device)
base_grid[..., 0].copy_(x_grid)

y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)

return base_grid


def elastic_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
Expand All @@ -1125,22 +1137,24 @@ def elastic_bounding_box(
# Or add spatial_size arg and check displacement shape
spatial_size = displacement.shape[-3], displacement.shape[-2]

id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
id_grid = _create_identity_grid(spatial_size, bounding_box.device)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid - displacement
inv_grid = id_grid.sub_(displacement)

# Get points from bboxes
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
if points.is_floating_point():
points = points.ceil_()
index_xy = points.to(dtype=torch.long)
index_x, index_y = index_xy[:, 0], index_xy[:, 1]

# Transform points:
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

return convert_format_bounding_box(
Expand Down