diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 4987b1991fc..3015c3b7e9b 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -86,7 +86,7 @@ def test_to_datapoint_reference(): tensor_to = tensor.to(image) - assert type(tensor_to) is torch.Tensor + assert type(tensor_to) is datapoints.Image assert tensor_to.dtype is torch.float64 @@ -145,7 +145,7 @@ def test_other_op_no_wrapping(): # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here output = image * 2 - assert type(output) is torch.Tensor + assert type(output) is datapoints.Image @pytest.mark.parametrize( @@ -169,7 +169,7 @@ def test_inplace_op_no_wrapping(): output = image.add_(0) - assert type(output) is torch.Tensor + assert type(output) is datapoints.Image assert type(image) is datapoints.Image diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d395c224785..d665ad44fad 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -362,7 +362,7 @@ def test_call(self, dims, inverse_dims): if check_type(value, (Image, is_simple_tensor, Video)): if transform.dims.get(value_type) is not None: assert transformed_value.permute(inverse_dims[value_type]).equal(value) - assert type(transformed_value) == torch.Tensor + assert type(transformed_value) == value_type else: assert transformed_value is value @@ -407,7 +407,7 @@ def test_call(self, dims): if check_type(value, (Image, is_simple_tensor, Video)): if transposed_dims is not None: assert transformed_value.transpose(*transposed_dims).equal(value) - assert type(transformed_value) == torch.Tensor + assert type(transformed_value) == value_type else: assert transformed_value is value diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5f4a9b62898..fd2b6d09e97 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -36,7 +36,7 @@ def make_vanilla_tensor_images(*args, **kwargs): for image in make_images(*args, **kwargs): if image.ndim > 3: continue - yield image.data + yield image.as_subclass(torch.Tensor) def make_pil_images(*args, **kwargs): diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index bf447c8ce71..680f9827ca8 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -687,7 +687,9 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): if format != datapoints.BoundingBoxFormat.XYXY: in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) - output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size) + output_boxes, output_canvas_size = F.resized_crop_bounding_boxes( + in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size + ) if format != datapoints.BoundingBoxFormat.XYXY: output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) @@ -742,13 +744,16 @@ def _compute_expected_canvas_size(bbox, padding_): bboxes_canvas_size = bboxes.canvas_size output_boxes, output_canvas_size = F.pad_bounding_boxes( - bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding + bboxes.as_subclass(torch.Tensor), format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding ) torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) expected_bboxes = torch.stack( - [_compute_expected_bbox(b, bboxes_format, 
padding) for b in bboxes.reshape(-1, 4).unbind()] + [ + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, padding) + for b in bboxes.reshape(-1, 4).unbind() + ] ).reshape(bboxes.shape) torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) @@ -836,7 +841,7 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_): expected_bboxes = torch.stack( [ - _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs) + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes.format, bboxes.canvas_size, inv_pcoeffs) for b in bboxes.reshape(-1, 4).unbind() ] ).reshape(bboxes.shape) @@ -876,12 +881,12 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_): bboxes_canvas_size = bboxes.canvas_size output_boxes, output_canvas_size = F.center_crop_bounding_boxes( - bboxes, bboxes_format, bboxes_canvas_size, output_size + bboxes.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size ) expected_bboxes = torch.stack( [ - _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size) + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size) for b in bboxes.reshape(-1, 4).unbind() ] ).reshape(bboxes.shape) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c1a21b6346e..ffc4f78dc40 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2270,3 +2270,86 @@ def test_image_correctness(self, permutation, batch_dims): expected = self.reference_image_correctness(image, permutation=permutation) torch.testing.assert_close(actual, expected) + + +def test_operations(): + + img = datapoints.Image(torch.rand(3, 10, 10)) + t = torch.rand(3, 10, 10) + mask = datapoints.Mask(torch.rand(1, 10, 10)) + + for out in ( + [ + img + t, + t + img, + img * t, + t * img, + img + 3, + 3 + img, + img * 3, + 3 * img, + img + img, + img.sum(), + img.reshape(-1), + img.float(), + torch.stack([img, img]), + ] + + list(torch.chunk(img, 2)) + + list(torch.unbind(img)) + ): + assert isinstance(out, datapoints.Image) + + for out in ( + [ + mask + t, + t + mask, + mask * t, + t * mask, + mask + 3, + 3 + mask, + mask * 3, + 3 * mask, + mask + mask, + mask.sum(), + mask.reshape(-1), + mask.float(), + torch.stack([mask, mask]), + ] + + list(torch.chunk(mask, 2)) + + list(torch.unbind(mask)) + ): + assert isinstance(out, datapoints.Mask) + + with pytest.raises(TypeError, match="unsupported operand type"): + img + mask + + with pytest.raises(TypeError, match="unsupported operand type"): + img * mask + + bboxes = datapoints.BoundingBoxes( + [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000) + ) + t = torch.rand(2, 4) + + for out in ( + [ + bboxes + t, + t + bboxes, + bboxes * t, + t * bboxes, + bboxes + 3, + 3 + bboxes, + bboxes * 3, + 3 * bboxes, + bboxes + bboxes, + bboxes.sum(), + bboxes.reshape(-1), + bboxes.float(), + torch.stack([bboxes, bboxes]), + ] + + list(torch.chunk(bboxes, 2)) + + list(torch.unbind(bboxes)) + ): + assert isinstance(out, datapoints.BoundingBoxes) + assert hasattr(out, "format") + assert hasattr(out, "canvas_size") diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index ac5651d3217..e96a549e87b 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -539,7 +539,10 @@ def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding width = 
canvas_size[1] + left + right expected_bboxes = reference_affine_bounding_boxes_helper( - bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix + bounding_boxes, + format=format, + canvas_size=(height, width), + affine_matrix=affine_matrix, ) return expected_bboxes, (height, width) diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index d459a55448a..fcf3088dc40 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -1,9 +1,10 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union import torch +from torch.utils._pytree import tree_flatten from ._datapoint import Datapoint @@ -99,5 +100,27 @@ def wrap_like( canvas_size=canvas_size if canvas_size is not None else other.canvas_size, ) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, canvas_size=self.canvas_size) + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + out = super().__torch_function__(func, types, args, kwargs) + + # If there are BoundingBoxes instances in the output, their metadata got lost when we called + # super().__torch_function__. We need to restore the metadata somehow, so we choose to take + # the metadata from the first bbox in the parameters. + # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g. + # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases. + first_bbox_from_args = None + for obj in tree_flatten(out)[0]: + if isinstance(obj, BoundingBoxes): + if first_bbox_from_args is None: + flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] + first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes)) + obj.format = first_bbox_from_args.format + obj.canvas_size = first_bbox_from_args.canvas_size + return out diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 7032d518fe4..0b38f979c6e 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import torch -from torch._C import DisableTorchFunctionSubclass -from torch.types import _device, _dtype, _size D = TypeVar("D", bound="Datapoint") @@ -33,88 +31,12 @@ def _to_tensor( def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) - # The ops in this set are those that should *preserve* the Datapoint type, - # i.e. they are exceptions to the "no wrapping" rule. 
- _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_} - - @classmethod - def __torch_function__( - cls, - func: Callable[..., torch.Tensor], - types: Tuple[Type[torch.Tensor], ...], - args: Sequence[Any] = (), - kwargs: Optional[Mapping[str, Any]] = None, - ) -> torch.Tensor: - """For general information about how the __torch_function__ protocol works, - see https://pytorch.org/docs/stable/notes/extending.html#extending-torch - - TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the - ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the - ``args`` and ``kwargs`` of the original call. - - The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint` - use case, this has two downsides: - - 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e. - ``return cls(func(*args, **kwargs))``, will fail for them. - 2. For most operations, there is no way of knowing if the input type is still valid for the output. - - For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are - listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` - """ - # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we - # need to reimplement the functionality. - - if not all(issubclass(cls, t) for t in types): - return NotImplemented - - with DisableTorchFunctionSubclass(): - output = func(*args, **kwargs or dict()) - - if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls): - # We also require the primary operand, i.e. `args[0]`, to be - # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will - # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, - # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with - # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would - # be wrapped into a `datapoints.Image`. - return cls.wrap_like(args[0], output) - - if isinstance(output, cls): - # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, - # so for those, the output is still a Datapoint. Thus, we need to manually unwrap. - return output.as_subclass(torch.Tensor) - - return output - def _make_repr(self, **kwargs: Any) -> str: # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class. 
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items()) return f"{super().__repr__()[:-1]}, {extra_repr})" - # Add properties for common attributes like shape, dtype, device, ndim etc - # this way we return the result without passing into __torch_function__ - @property - def shape(self) -> _size: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().shape - - @property - def ndim(self) -> int: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().ndim - - @property - def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().device - - @property - def dtype(self) -> _dtype: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().dtype - def __deepcopy__(self: D, memo: Dict[int, Any]) -> D: # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad` diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 53f3f801303..2bc2fef4eb3 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -66,7 +66,7 @@ def _copy_paste( # Copy-paste masks: masks = masks * inverse_paste_alpha_mask - non_all_zero_masks = masks.sum((-1, -2)) > 0 + non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor) masks = masks[non_all_zero_masks] # Do a shallow copy of the target dict @@ -92,7 +92,9 @@ def _copy_paste( # Check for degenerated boxes and remove them boxes = F.convert_format_bounding_boxes( - out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY + out_target["boxes"].as_subclass(torch.Tensor), + old_format=bbox_format, + new_format=datapoints.BoundingBoxFormat.XYXY, ) degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] if degenerate_boxes.any(): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 1a2802db0ac..b3d113322e4 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -115,7 +115,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: elif isinstance(inpt, datapoints.BoundingBoxes): inpt = datapoints.BoundingBoxes.wrap_like( inpt, - F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size), + F.clamp_bounding_boxes( + inpt[params["is_valid"]].as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + ), ) if params["needs_pad"]: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index f1b859aac03..9672002eac0 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -42,7 +42,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt.as_subclass(torch.Tensor) + return inpt return inpt.permute(*dims) @@ -64,5 +64,5 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt.as_subclass(torch.Tensor) + return inpt return 
inpt.transpose(*dims) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index da71cebb416..f469bb8edc3 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -404,7 +404,7 @@ def forward(self, *inputs: Any) -> Any: valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) - params = dict(valid=valid, labels=labels) + params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: # _transform() will only care about BoundingBoxeses and the labels diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ce1c320a745..da913a7d9e2 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -20,6 +20,8 @@ def is_simple_tensor(inpt: Any) -> bool: def _kernel_datapoint_wrapper(kernel): @functools.wraps(kernel) def wrapper(inpt, *args, **kwargs): + # We always pass datapoints as pure tensors to the kernels to avoid going through the + # Tensor.__torch_function__ logic, which is costly. output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) return type(inpt).wrap_like(inpt, output)
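
Note (not part of the patch): a minimal, self-contained sketch of the wrapping behavior this diff relies on, assuming the stock Tensor.__torch_function__ re-wraps tensor outputs into the calling subclass. `TaggedTensor` and its `label` attribute are hypothetical names for illustration only; they mirror how the BoundingBoxes.__torch_function__ added above restores `format`/`canvas_size` from the first bbox among the parameters, and why the kernel wrapper and several call sites unwrap with `.as_subclass(torch.Tensor)` to skip the subclass dispatch.

import torch
from torch.utils._pytree import tree_flatten


class TaggedTensor(torch.Tensor):
    # Hypothetical metadata attribute, standing in for BoundingBoxes.format / canvas_size.
    label: str = ""

    @classmethod
    def wrap(cls, tensor, *, label):
        out = tensor.as_subclass(cls)
        out.label = label
        return out

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # The default implementation re-wraps tensor outputs into `cls` ...
        out = super().__torch_function__(func, types, args, kwargs)
        # ... but plain Python attributes are lost on the re-wrapped outputs, so copy
        # them back from the first TaggedTensor found among the (flattened) parameters.
        flat_params, _ = tree_flatten((args, kwargs or {}))
        first = next((x for x in flat_params if isinstance(x, TaggedTensor)), None)
        if first is not None:
            for obj in tree_flatten(out)[0]:
                if isinstance(obj, TaggedTensor):
                    obj.label = first.label
        return out


t = TaggedTensor.wrap(torch.rand(3, 10, 10), label="image")
out = t * 2
assert type(out) is TaggedTensor and out.label == "image"  # type and metadata survive
assert type(torch.stack([t, t])) is TaggedTensor
# Unwrapping opts out of the subclass dispatch entirely; this is what
# _kernel_datapoint_wrapper does before calling into the low-level kernel,
# so the ops inside a kernel run on plain tensors.
assert type(t.as_subclass(torch.Tensor) * 2) is torch.Tensor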