6 changes: 3 additions & 3 deletions test/test_datapoints.py
@@ -86,7 +86,7 @@ def test_to_datapoint_reference():

tensor_to = tensor.to(image)

- assert type(tensor_to) is torch.Tensor
+ assert type(tensor_to) is datapoints.Image
assert tensor_to.dtype is torch.float64


@@ -145,7 +145,7 @@ def test_other_op_no_wrapping():
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
output = image * 2

- assert type(output) is torch.Tensor
+ assert type(output) is datapoints.Image


@pytest.mark.parametrize(
@@ -169,7 +169,7 @@ def test_inplace_op_no_wrapping():

output = image.add_(0)

- assert type(output) is torch.Tensor
+ assert type(output) is datapoints.Image
assert type(image) is datapoints.Image


4 changes: 2 additions & 2 deletions test/test_prototype_transforms.py
@@ -362,7 +362,7 @@ def test_call(self, dims, inverse_dims):
if check_type(value, (Image, is_simple_tensor, Video)):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
- assert type(transformed_value) == torch.Tensor
+ assert type(transformed_value) == value_type
else:
assert transformed_value is value

@@ -407,7 +407,7 @@ def test_call(self, dims):
if check_type(value, (Image, is_simple_tensor, Video)):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
- assert type(transformed_value) == torch.Tensor
+ assert type(transformed_value) == value_type
else:
assert transformed_value is value

2 changes: 1 addition & 1 deletion test/test_transforms_v2.py
@@ -36,7 +36,7 @@ def make_vanilla_tensor_images(*args, **kwargs):
for image in make_images(*args, **kwargs):
if image.ndim > 3:
continue
- yield image.data
+ yield image.as_subclass(torch.Tensor)


def make_pil_images(*args, **kwargs):
17 changes: 11 additions & 6 deletions test/test_transforms_v2_functional.py
Expand Up @@ -687,7 +687,9 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

- output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
+ output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(
+     in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size
+ )

Member Author (NicolasHug): This (and similar changes below) was needed because `in_boxes` is now still a BBox instance, and `resized_crop_bounding_boxes` expects a tensor (there is an error saying something like "if you pass a bbox, don't pass the format").

if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
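A minimal sketch of the unwrap pattern that comment describes; the functional import path and the concrete box values are assumptions for illustration, not taken from the PR:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F  # assumed import path

boxes = datapoints.BoundingBoxes(
    [[10, 10, 50, 50]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)

# With this PR, ops keep the BoundingBoxes type, so a kernel that takes an
# explicit `format` argument is handed the unwrapped tensor instead:
plain = boxes.as_subclass(torch.Tensor)  # plain torch.Tensor, metadata dropped
output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(
    plain, datapoints.BoundingBoxFormat.XYXY, 0, 0, 100, 100, [64, 64]
)
```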
@@ -742,13 +744,16 @@ def _compute_expected_canvas_size(bbox, padding_):
bboxes_canvas_size = bboxes.canvas_size

output_boxes, output_canvas_size = F.pad_bounding_boxes(
-     bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
+     bboxes.as_subclass(torch.Tensor), format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
)

torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

expected_bboxes = torch.stack(
-     [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
+     [
+         _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, padding)
+         for b in bboxes.reshape(-1, 4).unbind()
+     ]
).reshape(bboxes.shape)

torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
@@ -836,7 +841,7 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):

expected_bboxes = torch.stack(
[
-         _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
+         _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes.format, bboxes.canvas_size, inv_pcoeffs)
for b in bboxes.reshape(-1, 4).unbind()
]
).reshape(bboxes.shape)
@@ -876,12 +881,12 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
bboxes_canvas_size = bboxes.canvas_size

output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
-     bboxes, bboxes_format, bboxes_canvas_size, output_size
+     bboxes.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size
)

expected_bboxes = torch.stack(
[
-         _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
+         _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size)
for b in bboxes.reshape(-1, 4).unbind()
]
).reshape(bboxes.shape)
83 changes: 83 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -2270,3 +2270,86 @@ def test_image_correctness(self, permutation, batch_dims):
expected = self.reference_image_correctness(image, permutation=permutation)

torch.testing.assert_close(actual, expected)


def test_operations():
Member Author (NicolasHug): This test is mostly for illustrating the new behaviour. If we're OK with it, I'll refactor this test into something a little more polished.


img = datapoints.Image(torch.rand(3, 10, 10))
t = torch.rand(3, 10, 10)
mask = datapoints.Mask(torch.rand(1, 10, 10))

for out in (
[
img + t,
t + img,
img * t,
t * img,
img + 3,
3 + img,
img * 3,
3 * img,
img + img,
img.sum(),
img.reshape(-1),
img.float(),
torch.stack([img, img]),
]
+ list(torch.chunk(img, 2))
+ list(torch.unbind(img))
):
assert isinstance(out, datapoints.Image)

for out in (
[
mask + t,
t + mask,
mask * t,
t * mask,
mask + 3,
3 + mask,
mask * 3,
3 * mask,
mask + mask,
mask.sum(),
mask.reshape(-1),
mask.float(),
torch.stack([mask, mask]),
]
+ list(torch.chunk(mask, 2))
+ list(torch.unbind(mask))
):
assert isinstance(out, datapoints.Mask)

with pytest.raises(TypeError, match="unsupported operand type"):
img + mask
Member Author (NicolasHug), Aug 7, 2023, on lines +2323 to +2324:

Users want to do that? Perfect, they'll need to explicitly say what type they want as output by converting one of those operands to a tensor. We don't have to assume anything on their behalf and (surprisingly) return a pure tensor.

EDIT: as @pmeier pointed out offline, this is in fact the same behaviour as on main - nothing new.

with pytest.raises(TypeError, match="unsupported operand type"):
img * mask

bboxes = datapoints.BoundingBoxes(
[[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
)
t = torch.rand(2, 4)

for out in (
[
bboxes + t,
t + bboxes,
bboxes * t,
t * bboxes,
bboxes + 3,
3 + bboxes,
bboxes * 3,
3 * bboxes,
bboxes + bboxes,
bboxes.sum(),
bboxes.reshape(-1),
bboxes.float(),
torch.stack([bboxes, bboxes]),
]
+ list(torch.chunk(bboxes, 2))
+ list(torch.unbind(bboxes))
):
assert isinstance(out, datapoints.BoundingBoxes)
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")
5 changes: 4 additions & 1 deletion test/transforms_v2_kernel_infos.py
@@ -539,7 +539,10 @@ def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding
width = canvas_size[1] + left + right

expected_bboxes = reference_affine_bounding_boxes_helper(
-     bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix
+     bounding_boxes,
+     format=format,
+     canvas_size=(height, width),
+     affine_matrix=affine_matrix,
)
return expected_bboxes, (height, width)

29 changes: 26 additions & 3 deletions torchvision/datapoints/_bounding_box.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from enum import Enum
- from typing import Any, Optional, Tuple, Union
+ from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
+ from torch.utils._pytree import tree_flatten

from ._datapoint import Datapoint

@@ -99,5 +100,27 @@ def wrap_like(
canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
)

-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(format=self.format, canvas_size=self.canvas_size)
+    @classmethod
+    def __torch_function__(
+        cls,
+        func: Callable[..., torch.Tensor],
+        types: Tuple[Type[torch.Tensor], ...],
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> torch.Tensor:
+        out = super().__torch_function__(func, types, args, kwargs)
+
+        # If there are BoundingBoxes instances in the output, their metadata got lost when we called
+        # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
+        # the metadata from the first bbox in the parameters.
+        # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
+        # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
+        first_bbox_from_args = None
+        for obj in tree_flatten(out)[0]:
+            if isinstance(obj, BoundingBoxes):
+                if first_bbox_from_args is None:
+                    flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
+                    first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
+                obj.format = first_bbox_from_args.format
+                obj.canvas_size = first_bbox_from_args.canvas_size
+        return out
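A hedged usage sketch of the metadata restoration implemented above (the box values are illustrative):

```python
import torch
from torchvision import datapoints

bboxes = datapoints.BoundingBoxes(
    [[0, 10, 20, 30]],
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)

# super().__torch_function__ keeps the BoundingBoxes subclass for the output;
# the override then copies format/canvas_size from the first bbox operand:
stacked = torch.stack([bboxes, bboxes])
assert isinstance(stacked, datapoints.BoundingBoxes)
assert stacked.format is datapoints.BoundingBoxFormat.XYXY
assert stacked.canvas_size == (100, 100)
```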
80 changes: 1 addition & 79 deletions torchvision/datapoints/_datapoint.py
@@ -1,10 +1,8 @@
from __future__ import annotations

- from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+ from typing import Any, Dict, Optional, Type, TypeVar, Union

import torch
- from torch._C import DisableTorchFunctionSubclass
- from torch.types import _device, _dtype, _size


D = TypeVar("D", bound="Datapoint")
@@ -33,88 +31,12 @@ def _to_tensor(
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)

-    # The ops in this set are those that should *preserve* the Datapoint type,
-    # i.e. they are exceptions to the "no wrapping" rule.
-    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
-
-    @classmethod
-    def __torch_function__(
-        cls,
-        func: Callable[..., torch.Tensor],
-        types: Tuple[Type[torch.Tensor], ...],
-        args: Sequence[Any] = (),
-        kwargs: Optional[Mapping[str, Any]] = None,
-    ) -> torch.Tensor:
-        """For general information about how the __torch_function__ protocol works,
-        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
-
-        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
-        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
-        ``args`` and ``kwargs`` of the original call.
-
-        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
-        use case, this has two downsides:
-
-        1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
-           ``return cls(func(*args, **kwargs))``, will fail for them.
-        2. For most operations, there is no way of knowing if the input type is still valid for the output.
-
-        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
-        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
-        """
-        # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
-        # need to reimplement the functionality.
-
-        if not all(issubclass(cls, t) for t in types):
-            return NotImplemented
-
-        with DisableTorchFunctionSubclass():
-            output = func(*args, **kwargs or dict())
-
-        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
-            # We also require the primary operand, i.e. `args[0]`, to be
-            # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
-            # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
-            # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
-            # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor`
-            # would be wrapped into a `datapoints.Image`.
-            return cls.wrap_like(args[0], output)
-
-        if isinstance(output, cls):
-            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
-            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
-            return output.as_subclass(torch.Tensor)
-
-        return output

def _make_repr(self, **kwargs: Any) -> str:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
# If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
return f"{super().__repr__()[:-1]}, {extra_repr})"

-    # Add properties for common attributes like shape, dtype, device, ndim etc
-    # this way we return the result without passing into __torch_function__
-    @property
-    def shape(self) -> _size:  # type: ignore[override]
-        with DisableTorchFunctionSubclass():
-            return super().shape
-
-    @property
-    def ndim(self) -> int:  # type: ignore[override]
-        with DisableTorchFunctionSubclass():
-            return super().ndim
-
-    @property
-    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
-        with DisableTorchFunctionSubclass():
-            return super().device
-
-    @property
-    def dtype(self) -> _dtype:  # type: ignore[override]
-        with DisableTorchFunctionSubclass():
-            return super().dtype

def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
# We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
# *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
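For context, a minimal sketch of the behavioural change that deleting this override implies, mirroring the test updates above:

```python
import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 8, 8))

# Before this PR, `img * 2` was unwrapped to a plain torch.Tensor; only clone,
# to, detach and requires_grad_ (the _NO_WRAPPING_EXCEPTIONS) kept the type.
# With the override removed, the default torch.Tensor subclass behaviour
# applies, so ordinary ops preserve the Image type:
out = img * 2
assert type(out) is datapoints.Image
```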
6 changes: 4 additions & 2 deletions torchvision/prototype/transforms/_augment.py
@@ -66,7 +66,7 @@ def _copy_paste(

# Copy-paste masks:
masks = masks * inverse_paste_alpha_mask
- non_all_zero_masks = masks.sum((-1, -2)) > 0
+ non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor)

Member Author (NicolasHug): There were two other similar failures (below). The reason for the error is that `masks.sum((-1, -2)) > 0` is still a `Mask` object, and we can't use Masks as indices (line below).

This is the only kind of instance that I identified as potentially weird / confusing. But the error message is good enough to figure out the fix. (In contrast, unwrapping all the time is likely to cause a lot more surprises and forces users to re-wrap all the time.)

masks = masks[non_all_zero_masks]

# Do a shallow copy of the target dict
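A minimal sketch of the failure mode that comment describes; the mask shapes are assumptions:

```python
import torch
from torchvision import datapoints

masks = datapoints.Mask(torch.rand(4, 10, 10) > 0.5)

keep = masks.sum((-1, -2)) > 0  # with this PR, still a Mask, not a plain tensor
# `masks[keep]` errors because a Mask cannot be used as an index; unwrapping
# restores ordinary boolean indexing:
masks = masks[keep.as_subclass(torch.Tensor)]
```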
@@ -92,7 +92,9 @@ def _copy_paste(

# Check for degenerated boxes and remove them
boxes = F.convert_format_bounding_boxes(
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
out_target["boxes"].as_subclass(torch.Tensor),
old_format=bbox_format,
new_format=datapoints.BoundingBoxFormat.XYXY,
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
6 changes: 5 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -115,7 +115,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
elif isinstance(inpt, datapoints.BoundingBoxes):
inpt = datapoints.BoundingBoxes.wrap_like(
inpt,
-     F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
+     F.clamp_bounding_boxes(
+         inpt[params["is_valid"]].as_subclass(torch.Tensor),
+         format=inpt.format,
+         canvas_size=inpt.canvas_size,
+     ),
)

if params["needs_pad"]:
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_misc.py
@@ -42,7 +42,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
- return inpt.as_subclass(torch.Tensor)
+ return inpt
return inpt.permute(*dims)


@@ -64,5 +64,5 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
- return inpt.as_subclass(torch.Tensor)
+ return inpt
return inpt.transpose(*dims)
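A hedged sketch of the `dims is None` path after this change; the `PermuteDimensions` class name and constructor usage are assumed from the tests above:

```python
import torch
from torchvision import datapoints
from torchvision.prototype import transforms

img = datapoints.Image(torch.rand(3, 8, 8))

# A None entry now returns the input untouched instead of unwrapping it:
transform = transforms.PermuteDimensions(dims={datapoints.Image: None})
out = transform(img)
assert out is img  # still a datapoints.Image
```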