From 8f8f93663d7a50bf8ce1769f7e406144d01caedc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Aug 2023 15:55:58 +0100 Subject: [PATCH 1/8] move stuff out of CM --- torchvision/datapoints/_datapoint.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 9b1c648648d..bf78c230504 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -37,6 +37,8 @@ def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) _NO_WRAPPING_EXCEPTIONS = { + # The ops in this dict are those that should *preserve* the Datapoint + # type, i.e. they are exceptions to the "no wrapping" rule. torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output), @@ -79,22 +81,22 @@ def __torch_function__( with DisableTorchFunctionSubclass(): output = func(*args, **kwargs or dict()) - wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) + wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) + if wrapper and isinstance(args[0], cls): # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would # be wrapped into a `datapoints.Image`. - if wrapper and isinstance(args[0], cls): - return wrapper(cls, args[0], output) + return wrapper(cls, args[0], output) - # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`, - # will retain the input type. Thus, we need to unwrap here. - if isinstance(output, cls): - return output.as_subclass(torch.Tensor) + if isinstance(output, cls): + # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, + # so for those, the output is still a Datapoint. Thus, we need to manually unwrap. + return output.as_subclass(torch.Tensor) - return output + return output def _make_repr(self, **kwargs: Any) -> str: # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. 
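As context for this patch and the next one, here is a minimal, self-contained sketch of the wrapping policy being reorganized: only an explicit allow-list of ops re-wraps the output into the subclass, and everything else, including in-place ops, comes back as a plain torch.Tensor. The class name `MyDatapoint` and the snippet itself are illustrative only, not torchvision code, and it elides details such as the `requires_grad_` special case.

    import torch
    from torch._C import DisableTorchFunctionSubclass

    class MyDatapoint(torch.Tensor):
        # Ops that should *preserve* the subclass; every other op falls back to torch.Tensor.
        _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach}

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # Defer if an unrelated tensor subclass is also involved.
            if not all(issubclass(cls, t) for t in types):
                return NotImplemented
            with DisableTorchFunctionSubclass():
                output = func(*args, **(kwargs or {}))
            # Re-wrap only for allow-listed ops, and only when the primary operand is
            # already an instance of this subclass (guards plain_tensor.to(my_datapoint)).
            if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
                return output.as_subclass(cls)
            # In-place ops ignore DisableTorchFunctionSubclass and keep the input type,
            # so unwrap their result back to a plain tensor.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)
            return output

    img = torch.rand(3, 4, 4).as_subclass(MyDatapoint)
    print(type(img.clone()))  # MyDatapoint: clone is allow-listed
    print(type(img * 2))      # torch.Tensor: not allow-listed
    print(type(img.add_(0)))  # torch.Tensor: in-place result is unwrapped; img itself stays MyDatapoint
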
From b1018a98fed3f4efde10dc7ac1620b318d002328 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Aug 2023 16:19:03 +0100 Subject: [PATCH 2/8] Call wrap_like for all exceptions --- test/test_datapoints.py | 1 - torchvision/datapoints/_datapoint.py | 20 ++++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 25a2182e050..ded9a771e14 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad): assert type(datapoint_deepcopied) is type(datapoint) assert datapoint_deepcopied.requires_grad is requires_grad - assert datapoint_deepcopied.is_leaf diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index bf78c230504..2faa4e3716a 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -36,16 +36,9 @@ def _to_tensor( def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) - _NO_WRAPPING_EXCEPTIONS = { - # The ops in this dict are those that should *preserve* the Datapoint - # type, i.e. they are exceptions to the "no wrapping" rule. - torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), - torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), - torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output), - # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus - # retains the type automatically - torch.Tensor.requires_grad_: lambda cls, input, output: output, - } + # The ops in this set are those that should *preserve* the Datapoint type, + # i.e. they are exceptions to the "no wrapping" rule. + _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_} @classmethod def __torch_function__( @@ -81,15 +74,14 @@ def __torch_function__( with DisableTorchFunctionSubclass(): output = func(*args, **kwargs or dict()) - wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) - if wrapper and isinstance(args[0], cls): - # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be + if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls): + # We also require the primary operand, i.e. `args[0]`, to be # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would # be wrapped into a `datapoints.Image`. 
- return wrapper(cls, args[0], output) + return cls.wrap_like(args[0], output) if isinstance(output, cls): # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, From ddd88cd82a044058aeb07c34c4a2f0b03f52c38a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Aug 2023 17:33:17 +0100 Subject: [PATCH 3/8] Get rid of __torchfunction__ and the whole wrapping/unwrapping logic --- test/test_datapoints.py | 6 +- test/test_transforms_v2_functional.py | 9 ++- test/test_transforms_v2_refactored.py | 19 +++++ test/transforms_v2_kernel_infos.py | 8 +- torchvision/datapoints/_datapoint.py | 81 +++---------------- .../transforms/v2/functional/_utils.py | 11 +-- 6 files changed, 49 insertions(+), 85 deletions(-) diff --git a/test/test_datapoints.py b/test/test_datapoints.py index ded9a771e14..d8492badbaa 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -80,7 +80,7 @@ def test_to_datapoint_reference(): tensor_to = tensor.to(image) - assert type(tensor_to) is torch.Tensor + assert type(tensor_to) is datapoints.Image assert tensor_to.dtype is torch.float64 @@ -139,7 +139,7 @@ def test_other_op_no_wrapping(): # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here output = image * 2 - assert type(output) is torch.Tensor + assert type(output) is datapoints.Image @pytest.mark.parametrize( @@ -163,7 +163,7 @@ def test_inplace_op_no_wrapping(): output = image.add_(0) - assert type(output) is torch.Tensor + assert type(output) is datapoints.Image assert type(image) is datapoints.Image diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 713737abbff..bcea8b3dfbf 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -687,7 +687,9 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): if format != datapoints.BoundingBoxFormat.XYXY: in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format) - output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size) + output_boxes, output_canvas_size = F.resized_crop_bounding_boxes( + in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size + ) if format != datapoints.BoundingBoxFormat.XYXY: output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) @@ -743,7 +745,7 @@ def _compute_expected_canvas_size(bbox, padding_): bboxes_canvas_size = bboxes.canvas_size output_boxes, output_canvas_size = F.pad_bounding_boxes( - bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding + bboxes.as_subclass(torch.Tensor), format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding ) torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) @@ -753,6 +755,9 @@ def _compute_expected_canvas_size(bbox, padding_): expected_bboxes = [] for bbox in bboxes: + print() + print(type(bbox)) + print(hasattr(bbox, "format")) bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) expected_bboxes.append(_compute_expected_bbox(bbox, padding)) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c910882f9fd..8be3988bc63 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2280,3 +2280,22 @@ def resize_my_datapoint(): _register_kernel_internal(F.resize, MyDatapoint, 
datapoint_wrapper=False)(resize_my_datapoint) assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint + + +def test_operations(): + + img = datapoints.Image(torch.rand(3, 10, 10)) + t = torch.rand(3, 10, 10) + mask = datapoints.Mask(torch.rand(1, 10, 10)) + + for out in [img + t, t + img, img * t, t * img, img + 3, 3 + img, img * 3, 3 * img, img + img]: + assert isinstance(out, datapoints.Image) + + for out in [mask + t, t + mask, mask * t, t * mask, mask + 3, 3 + mask, mask * 3, 3 * mask, mask + mask]: + assert isinstance(out, datapoints.Mask) + + with pytest.raises(TypeError, match="unsupported operand type"): + img + mask + + with pytest.raises(TypeError, match="unsupported operand type"): + img * mask diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index 01605f696b4..b7fe90817bf 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -241,9 +241,11 @@ def sample_inputs_convert_format_bounding_boxes(): def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format): - return torchvision.ops.box_convert( - bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() - ).to(bounding_boxes.dtype) + return ( + torchvision.ops.box_convert(bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()) + .as_subclass(torch.Tensor) + .to(bounding_boxes.dtype) + ) def reference_inputs_convert_format_bounding_boxes(): diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 2faa4e3716a..8d4e8bf648a 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -32,63 +32,22 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) + # We have to override a few method to make sure the meta-data is preserved on them. @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) - # The ops in this set are those that should *preserve* the Datapoint type, - # i.e. they are exceptions to the "no wrapping" rule. - _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_} + def clone(self): + return type(self).wrap_like(self, super().clone()) - @classmethod - def __torch_function__( - cls, - func: Callable[..., torch.Tensor], - types: Tuple[Type[torch.Tensor], ...], - args: Sequence[Any] = (), - kwargs: Optional[Mapping[str, Any]] = None, - ) -> torch.Tensor: - """For general information about how the __torch_function__ protocol works, - see https://pytorch.org/docs/stable/notes/extending.html#extending-torch - - TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the - ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the - ``args`` and ``kwargs`` of the original call. - - The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint` - use case, this has two downsides: - - 1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e. - ``return cls(func(*args, **kwargs))``, will fail for them. - 2. For most operations, there is no way of knowing if the input type is still valid for the output. 
+ def to(self, *args, **kwargs): + return type(self).wrap_like(self, super().to(*args, **kwargs)) - For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are - listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS` - """ - # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we - # need to reimplement the functionality. + def detach(self): + return type(self).wrap_like(self, super().detach()) - if not all(issubclass(cls, t) for t in types): - return NotImplemented - - with DisableTorchFunctionSubclass(): - output = func(*args, **kwargs or dict()) - - if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls): - # We also require the primary operand, i.e. `args[0]`, to be - # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will - # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, - # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with - # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would - # be wrapped into a `datapoints.Image`. - return cls.wrap_like(args[0], output) - - if isinstance(output, cls): - # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, - # so for those, the output is still a Datapoint. Thus, we need to manually unwrap. - return output.as_subclass(torch.Tensor) - - return output + def requires_grad_(self, requires_grad: bool = True): + return type(self).wrap_like(self, super().requires_grad_(requires_grad)) def _make_repr(self, **kwargs: Any) -> str: # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. @@ -96,28 +55,6 @@ def _make_repr(self, **kwargs: Any) -> str: extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items()) return f"{super().__repr__()[:-1]}, {extra_repr})" - # Add properties for common attributes like shape, dtype, device, ndim etc - # this way we return the result without passing into __torch_function__ - @property - def shape(self) -> _size: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().shape - - @property - def ndim(self) -> int: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().ndim - - @property - def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().device - - @property - def dtype(self) -> _dtype: # type: ignore[override] - with DisableTorchFunctionSubclass(): - return super().dtype - def __deepcopy__(self: D, memo: Dict[int, Any]) -> D: # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does # *not* happen for `deepcopy(Tensor)`. 
A side-effect from detaching is that the `Tensor.requires_grad` diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 576a2b99dbf..e3c519221c3 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -29,11 +29,12 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True) raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.") def decorator(kernel): - registry[input_type] = ( - _kernel_datapoint_wrapper(kernel) - if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper - else kernel - ) + # registry[input_type] = ( + # _kernel_datapoint_wrapper(kernel) + # if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper + # else kernel + # ) + registry[input_type] = kernel return kernel return decorator From 4e8b53de60dffaaa136ae282b5261e9411e45b65 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Aug 2023 17:41:53 +0100 Subject: [PATCH 4/8] bbox tests --- test/test_transforms_v2_refactored.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 8be3988bc63..4e3ea8a1e03 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2299,3 +2299,15 @@ def test_operations(): with pytest.raises(TypeError, match="unsupported operand type"): img * mask + + bboxes = datapoints.BoundingBoxes( + [[17, 16, 344, 495], [0, 10, 0, 10]], + format=datapoints.BoundingBoxFormat.XYXY, + canvas_size=(1000, 1000) + ) + t = torch.rand(2, 4) + + for out in [bboxes + t, t + bboxes, bboxes * t, t * bboxes, bboxes + 3, 3 + bboxes, bboxes * 3, 3 * bboxes, bboxes + bboxes]: + assert isinstance(out, datapoints.BoundingBoxes) + assert not(hasattr(out, "format")) + assert not(hasattr(out, "canvas_size")) \ No newline at end of file From 74712717fd210df2b80b177ff3ab5f456acd8e80 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Aug 2023 11:05:08 +0100 Subject: [PATCH 5/8] Put back wrapping / unwrapping in kernels --- torchvision/transforms/v2/functional/_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index e3c519221c3..20330b04dfe 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -17,6 +17,8 @@ def is_simple_tensor(inpt: Any) -> bool: def _kernel_datapoint_wrapper(kernel): @functools.wraps(kernel) def wrapper(inpt, *args, **kwargs): + # We always pass datapoints as pure tensors to the kernels to avoid going through the + # Tensor.__torch_function__ logic, which is costly. 
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) return type(inpt).wrap_like(inpt, output) @@ -29,12 +31,11 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True) raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.") def decorator(kernel): - # registry[input_type] = ( - # _kernel_datapoint_wrapper(kernel) - # if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper - # else kernel - # ) - registry[input_type] = kernel + registry[input_type] = ( + _kernel_datapoint_wrapper(kernel) + if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper + else kernel + ) return kernel return decorator From 23b9704170966b9e93812ac25143e97fae16f415 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Aug 2023 12:45:28 +0100 Subject: [PATCH 6/8] Fix tests --- test/test_prototype_transforms.py | 4 ++-- test/test_transforms_v2.py | 2 +- test/test_transforms_v2_functional.py | 11 ++++++---- test/test_transforms_v2_refactored.py | 20 +++++++++++++------ torchvision/datapoints/_datapoint.py | 4 +--- torchvision/prototype/transforms/_augment.py | 6 ++++-- torchvision/prototype/transforms/_geometry.py | 6 +++++- torchvision/prototype/transforms/_misc.py | 4 ++-- torchvision/transforms/v2/_misc.py | 2 +- 9 files changed, 37 insertions(+), 22 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d395c224785..d665ad44fad 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -362,7 +362,7 @@ def test_call(self, dims, inverse_dims): if check_type(value, (Image, is_simple_tensor, Video)): if transform.dims.get(value_type) is not None: assert transformed_value.permute(inverse_dims[value_type]).equal(value) - assert type(transformed_value) == torch.Tensor + assert type(transformed_value) == value_type else: assert transformed_value is value @@ -407,7 +407,7 @@ def test_call(self, dims): if check_type(value, (Image, is_simple_tensor, Video)): if transposed_dims is not None: assert transformed_value.transpose(*transposed_dims).equal(value) - assert type(transformed_value) == torch.Tensor + assert type(transformed_value) == value_type else: assert transformed_value is value diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 353cc846bed..1589165b334 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -36,7 +36,7 @@ def make_vanilla_tensor_images(*args, **kwargs): for image in make_images(*args, **kwargs): if image.ndim > 3: continue - yield image.data + yield image.as_subclass(torch.Tensor) def make_pil_images(*args, **kwargs): diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index e09385aca63..680f9827ca8 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -750,7 +750,10 @@ def _compute_expected_canvas_size(bbox, padding_): torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) expected_bboxes = torch.stack( - [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()] + [ + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, padding) + for b in bboxes.reshape(-1, 4).unbind() + ] ).reshape(bboxes.shape) torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) @@ -838,7 +841,7 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_): expected_bboxes = torch.stack( [ - 
_compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs) + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes.format, bboxes.canvas_size, inv_pcoeffs) for b in bboxes.reshape(-1, 4).unbind() ] ).reshape(bboxes.shape) @@ -878,12 +881,12 @@ def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_): bboxes_canvas_size = bboxes.canvas_size output_boxes, output_canvas_size = F.center_crop_bounding_boxes( - bboxes, bboxes_format, bboxes_canvas_size, output_size + bboxes.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size ) expected_bboxes = torch.stack( [ - _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size) + _compute_expected_bbox(b.as_subclass(torch.Tensor), bboxes_format, bboxes_canvas_size, output_size) for b in bboxes.reshape(-1, 4).unbind() ] ).reshape(bboxes.shape) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 4e3ea8a1e03..50ff53c7edd 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2301,13 +2301,21 @@ def test_operations(): img * mask bboxes = datapoints.BoundingBoxes( - [[17, 16, 344, 495], [0, 10, 0, 10]], - format=datapoints.BoundingBoxFormat.XYXY, - canvas_size=(1000, 1000) + [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000) ) t = torch.rand(2, 4) - for out in [bboxes + t, t + bboxes, bboxes * t, t * bboxes, bboxes + 3, 3 + bboxes, bboxes * 3, 3 * bboxes, bboxes + bboxes]: + for out in [ + bboxes + t, + t + bboxes, + bboxes * t, + t * bboxes, + bboxes + 3, + 3 + bboxes, + bboxes * 3, + 3 * bboxes, + bboxes + bboxes, + ]: assert isinstance(out, datapoints.BoundingBoxes) - assert not(hasattr(out, "format")) - assert not(hasattr(out, "canvas_size")) \ No newline at end of file + assert not (hasattr(out, "format")) + assert not (hasattr(out, "canvas_size")) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 8d4e8bf648a..44a4d4e4040 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -1,11 +1,9 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar, Union import PIL.Image import torch -from torch._C import DisableTorchFunctionSubclass -from torch.types import _device, _dtype, _size D = TypeVar("D", bound="Datapoint") diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 95585fe287c..251186a92d0 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -66,7 +66,7 @@ def _copy_paste( # Copy-paste masks: masks = masks * inverse_paste_alpha_mask - non_all_zero_masks = masks.sum((-1, -2)) > 0 + non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor) masks = masks[non_all_zero_masks] # Do a shallow copy of the target dict @@ -92,7 +92,9 @@ def _copy_paste( # Check for degenerated boxes and remove them boxes = F.convert_format_bounding_boxes( - out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY + out_target["boxes"].as_subclass(torch.Tensor), + old_format=bbox_format, + new_format=datapoints.BoundingBoxFormat.XYXY, ) degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] if degenerate_boxes.any(): diff --git a/torchvision/prototype/transforms/_geometry.py 
b/torchvision/prototype/transforms/_geometry.py index e3819554d0b..f8e5dc27352 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -115,7 +115,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: elif isinstance(inpt, datapoints.BoundingBoxes): inpt = datapoints.BoundingBoxes.wrap_like( inpt, - F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size), + F.clamp_bounding_boxes( + inpt[params["is_valid"]].as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + ), ) if params["needs_pad"]: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 51a2ea9074a..d16f31c7782 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -44,7 +44,7 @@ def _transform( ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt.as_subclass(torch.Tensor) + return inpt return inpt.permute(*dims) @@ -68,5 +68,5 @@ def _transform( ) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: - return inpt.as_subclass(torch.Tensor) + return inpt return inpt.transpose(*dims) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index d2dddd96d5c..e67f184a42f 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -406,7 +406,7 @@ def forward(self, *inputs: Any) -> Any: valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) - params = dict(valid=valid, labels=labels) + params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: # _transform() will only care about BoundingBoxeses and the labels From f12fee1ef3d1673dcc979331527079716693e668 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Aug 2023 17:03:25 +0100 Subject: [PATCH 7/8] preserve metadata on bboxes --- test/test_transforms_v2_refactored.py | 74 ++++++++++++++++++++----- test/transforms_v2_kernel_infos.py | 13 +++-- torchvision/datapoints/_bounding_box.py | 27 ++++++++- torchvision/datapoints/_datapoint.py | 13 ----- 4 files changed, 91 insertions(+), 36 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 50ff53c7edd..101f9301e34 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2288,10 +2288,46 @@ def test_operations(): t = torch.rand(3, 10, 10) mask = datapoints.Mask(torch.rand(1, 10, 10)) - for out in [img + t, t + img, img * t, t * img, img + 3, 3 + img, img * 3, 3 * img, img + img]: + for out in ( + [ + img + t, + t + img, + img * t, + t * img, + img + 3, + 3 + img, + img * 3, + 3 * img, + img + img, + img.sum(), + img.reshape(-1), + img.float(), + torch.stack([img, img]), + ] + + list(torch.chunk(img, 2)) + + list(torch.unbind(img)) + ): assert isinstance(out, datapoints.Image) - for out in [mask + t, t + mask, mask * t, t * mask, mask + 3, 3 + mask, mask * 3, 3 * mask, mask + mask]: + for out in ( + [ + mask + t, + t + mask, + mask * t, + t * mask, + mask + 3, + 3 + mask, + mask * 3, + 3 * mask, + mask + mask, + mask.sum(), + mask.reshape(-1), + mask.float(), + torch.stack([mask, mask]), + ] + + list(torch.chunk(mask, 2)) + + list(torch.unbind(mask)) + ): assert isinstance(out, datapoints.Mask) with pytest.raises(TypeError, match="unsupported 
operand type"): @@ -2305,17 +2341,25 @@ def test_operations(): ) t = torch.rand(2, 4) - for out in [ - bboxes + t, - t + bboxes, - bboxes * t, - t * bboxes, - bboxes + 3, - 3 + bboxes, - bboxes * 3, - 3 * bboxes, - bboxes + bboxes, - ]: + for out in ( + [ + bboxes + t, + t + bboxes, + bboxes * t, + t * bboxes, + bboxes + 3, + 3 + bboxes, + bboxes * 3, + 3 * bboxes, + bboxes + bboxes, + bboxes.sum(), + bboxes.reshape(-1), + bboxes.float(), + torch.stack([bboxes, bboxes]), + ] + + list(torch.chunk(bboxes, 2)) + + list(torch.unbind(bboxes)) + ): assert isinstance(out, datapoints.BoundingBoxes) - assert not (hasattr(out, "format")) - assert not (hasattr(out, "canvas_size")) + assert hasattr(out, "format") + assert hasattr(out, "canvas_size") diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index eda325b57ad..e96a549e87b 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -234,11 +234,9 @@ def sample_inputs_convert_format_bounding_boxes(): def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format): - return ( - torchvision.ops.box_convert(bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()) - .as_subclass(torch.Tensor) - .to(bounding_boxes.dtype) - ) + return torchvision.ops.box_convert( + bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower() + ).to(bounding_boxes.dtype) def reference_inputs_convert_format_bounding_boxes(): @@ -541,7 +539,10 @@ def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding width = canvas_size[1] + left + right expected_bboxes = reference_affine_bounding_boxes_helper( - bounding_boxes, format=format, canvas_size=(height, width), affine_matrix=affine_matrix + bounding_boxes, + format=format, + canvas_size=(height, width), + affine_matrix=affine_matrix, ) return expected_bboxes, (height, width) diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index d459a55448a..662309d13a0 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -4,6 +4,7 @@ from typing import Any, Optional, Tuple, Union import torch +from torch.utils._pytree import tree_flatten from ._datapoint import Datapoint @@ -99,5 +100,27 @@ def wrap_like( canvas_size=canvas_size if canvas_size is not None else other.canvas_size, ) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, canvas_size=self.canvas_size) + @classmethod + def __torch_function__( + cls, + func, + types, + args, + kwargs=None, + ) -> torch.Tensor: + out = super().__torch_function__(func, types, args, kwargs) + + # If there are BoundingBoxes instances in the output, their metadata got lost when we called + # super().__torch_function__. We need to restore the metadata somehow, so we choose to take + # the metadata from the first bbox in the parameters. + # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g. + # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases. 
+ first_bbox_from_args = None + for obj in tree_flatten(out)[0]: + if isinstance(obj, BoundingBoxes): + if first_bbox_from_args is None: + flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) + first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes)) + obj.format = first_bbox_from_args.format + obj.canvas_size = first_bbox_from_args.canvas_size + return out diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 44a4d4e4040..7ee1f2854dd 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -30,23 +30,10 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - # We have to override a few method to make sure the meta-data is preserved on them. @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: return tensor.as_subclass(cls) - def clone(self): - return type(self).wrap_like(self, super().clone()) - - def to(self, *args, **kwargs): - return type(self).wrap_like(self, super().to(*args, **kwargs)) - - def detach(self): - return type(self).wrap_like(self, super().detach()) - - def requires_grad_(self, requires_grad: bool = True): - return type(self).wrap_like(self, super().requires_grad_(requires_grad)) - def _make_repr(self, **kwargs: Any) -> str: # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class. From 854b01c6757e5e2b750021d7a987040414483252 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Aug 2023 10:25:12 +0100 Subject: [PATCH 8/8] mypy --- torchvision/datapoints/_bounding_box.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 662309d13a0..fcf3088dc40 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union import torch from torch.utils._pytree import tree_flatten @@ -103,10 +103,10 @@ def wrap_like( @classmethod def __torch_function__( cls, - func, - types, - args, - kwargs=None, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, ) -> torch.Tensor: out = super().__torch_function__(func, types, args, kwargs) @@ -119,7 +119,7 @@ def __torch_function__( for obj in tree_flatten(out)[0]: if isinstance(obj, BoundingBoxes): if first_bbox_from_args is None: - flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) + flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes)) obj.format = first_bbox_from_args.format obj.canvas_size = first_bbox_from_args.canvas_size
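
The last two patches restore `format` and `canvas_size` after arbitrary ops by overriding `BoundingBoxes.__torch_function__`. Below is a standalone sketch of that pattern, reduced to a generic subclass with a made-up `tag` attribute standing in for the real metadata (the names `TaggedTensor` and `tag` are illustrative, not torchvision API): let the default `torch.Tensor.__torch_function__` handle subclass propagation, then copy the metadata from the first subclass operand onto every subclass instance found in the (possibly nested) output.

    import torch
    from torch.utils._pytree import tree_flatten

    class TaggedTensor(torch.Tensor):
        # Stand-in for real metadata such as `format` and `canvas_size`.
        tag: str = "unset"

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # The default implementation already returns TaggedTensor outputs;
            # only the Python-level metadata is lost and needs to be restored.
            out = super().__torch_function__(func, types, args, kwargs)
            flat_inputs, _ = tree_flatten((args, kwargs or {}))
            source = next((a for a in flat_inputs if isinstance(a, cls)), None)
            if source is not None:
                for obj in tree_flatten(out)[0]:
                    if isinstance(obj, cls):
                        obj.tag = source.tag
            return out

    boxes = torch.rand(2, 4).as_subclass(TaggedTensor)
    boxes.tag = "XYXY"
    out = boxes + torch.rand(2, 4)
    print(type(out).__name__, out.tag)                 # TaggedTensor XYXY
    print(type(torch.stack([boxes, boxes])).__name__)  # TaggedTensor

As the comment in the BoundingBoxes override notes, taking the metadata from the first box operand is a deliberate simplification: mixing boxes that carry different formats or canvas sizes in one op is not guarded against, the first operand simply wins.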