diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7b3e9e8a7f3..fed3884ea27 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
     used within the autoclass directive.
     """
 
-    if obj.__name__.endswith(("_Weights", "_QuantizedWeights")):
+    if getattr(obj, "__name__", "").endswith(("_Weights", "_QuantizedWeights")):
 
         if len(obj) == 0:
             lines[:] = ["There are no available pre-trained weights."]
diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst
index 55d3cda4a8c..ea23a7ff7a6 100644
--- a/docs/source/datapoints.rst
+++ b/docs/source/datapoints.rst
@@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     BoundingBoxFormat
     BoundingBoxes
     Mask
+    Datapoint
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 73adb3cf3b5..a1858c6b514 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi
     to_pil_image
     to_tensor
     vflip
+
+Developer tools
+---------------
+
+.. currentmodule:: torchvision.transforms.v2.functional
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    register_kernel
diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py
new file mode 100644
index 00000000000..ea757283e86
--- /dev/null
+++ b/gallery/plot_custom_datapoints.py
@@ -0,0 +1,125 @@
+"""
+=====================================
+How to write your own Datapoint class
+=====================================
+
+This guide is intended for downstream library maintainers. We explain how to
+write your own datapoint class, and how to make it compatible with the built-in
+Torchvision v2 transforms. Before continuing, make sure you have read
+:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
+"""
+
+# %%
+import torch
+import torchvision
+
+# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
+# some APIs may slightly change in the future
+torchvision.disable_beta_transforms_warning()
+
+from torchvision import datapoints
+from torchvision.transforms import v2
+
+# %%
+# We will create a very simple class that just inherits from the base
+# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
+# what you need to know to implement your more elaborate use-cases. If you need
+# to create a class that carries meta-data, take a look at how the
+# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
+# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
+
+
+class MyDatapoint(datapoints.Datapoint):
+    pass
+
+
+my_dp = MyDatapoint([1, 2, 3])
+my_dp
+
+# %%
+# Now that we have defined our custom Datapoint class, we want it to be
+# compatible with the built-in torchvision transforms, and the functional API.
+# For that, we need to implement a kernel which performs the core of the
+# transformation, and then "hook" it to the functional that we want to support
+# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
+#
+# We illustrate this process below: we create a kernel for the "horizontal flip"
+# operation of our MyDatapoint class, and register it to the functional API.
+
+from torchvision.transforms.v2 import functional as F
+
+
+@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+def hflip_my_datapoint(my_dp, *args, **kwargs):
+    print("Flipping!")
+    out = my_dp.flip(-1)
+    return MyDatapoint.wrap_like(my_dp, out)
+
+
+# %%
+# To understand why ``wrap_like`` is used, see
+# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
+# we will explain them below in :ref:`param_forwarding`.
+#
+# .. note::
+#
+#    In our call to ``register_kernel`` above we used a string
+#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    could also have used the functional *itself*, i.e.
+#    ``@register_kernel(dispatcher=F.hflip, ...)``.
+#
+# The functionals that you can hook into are the ones in
+# ``torchvision.transforms.v2.functional`` and they are documented in
+# :ref:`functional_transforms`.
+#
+# Now that we have registered our kernel, we can call the functional API on a
+# ``MyDatapoint`` instance:
+
+my_dp = MyDatapoint(torch.rand(3, 256, 256))
+_ = F.hflip(my_dp)
+
+# %%
+# And we can also use the :class:`~torchvision.transforms.v2.RandomHorizontalFlip`
+# transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip`
+# internally:
+t = v2.RandomHorizontalFlip(p=1)
+_ = t(my_dp)
+
+# %%
+# .. note::
+#
+#    We cannot register a kernel for a transform class, we can only register a
+#    kernel for a **functional**. The reason we can't register a transform
+#    class is because one transform may internally rely on more than one
+#    functional, so in general we can't register a single kernel for a given
+#    class.
+#
+# .. _param_forwarding:
+#
+# Parameter forwarding, and ensuring future compatibility of your kernels
+# -----------------------------------------------------------------------
+#
+# The functional API that you're hooking into is public and therefore
+# **backward** compatible: we guarantee that the parameters of these functionals
+# won't be removed or renamed without a proper deprecation cycle. However, we
+# don't guarantee **forward** compatibility, and we may add new parameters in
+# the future.
+#
+# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
+# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
+# already defined and registered your own kernel as
+
+def hflip_my_datapoint(my_dp):  # noqa
+    print("Flipping!")
+    out = my_dp.flip(-1)
+    return MyDatapoint.wrap_like(my_dp, out)
+
+
+# %%
+# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
+# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
+# accept it.
+#
+# For this reason, we recommend always defining your kernels with
+# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
+# will be able to accept any new parameter that we may add in the future.
+# (Technically, adding ``**kwargs`` alone should be enough.)
diff --git a/gallery/plot_custom_transforms.py b/gallery/plot_custom_transforms.py
new file mode 100644
index 00000000000..eba8e91faf4
--- /dev/null
+++ b/gallery/plot_custom_transforms.py
@@ -0,0 +1,123 @@
+"""
+===================================
+How to write your own v2 transforms
+===================================
+
+This guide explains how to write transforms that are compatible with the
+torchvision transforms V2 API.
+""" + +# %% +import torch +import torchvision + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints +from torchvision.transforms import v2 + + +# %% +# Just create a ``nn.Module`` and override the ``forward`` method +# =============================================================== +# +# In most cases, this is all you're going to need, as long as you already know +# the structure of the input that your transform will expect. For example if +# you're just doing image classification, your transform will typically accept a +# single image as input, or a ``(img, label)`` input. So you can just hard-code +# your ``forward`` method to accept just that, e.g. +# +# .. code:: python +# +# class MyCustomTransform(torch.nn.Module): +# def forward(self, img, label): +# # Do some transformations +# return new_img, new_label +# +# .. note:: +# +# This means that if you have a custom transform that is already compatible +# with the V1 transforms (those in ``torchvision.transforms``), it will +# still work with the V2 transforms without any change! +# +# We will illustrate this more completely below with a typical detection case, +# where our samples are just images, bounding boxes and labels: + +class MyCustomTransform(torch.nn.Module): + def forward(self, img, bboxes, label): # we assume inputs are always structured like this + print( + f"I'm transforming an image of shape {img.shape} " + f"with bboxes = {bboxes}\n{label = }" + ) + # Do some transformations. Here, we're just passing though the input + return img, bboxes, label + + +transforms = v2.Compose([ + MyCustomTransform(), + v2.RandomResizedCrop((224, 224), antialias=True), + v2.RandomHorizontalFlip(p=1), + v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1]) +]) + +H, W = 256, 256 +img = torch.rand(3, H, W) +bboxes = datapoints.BoundingBoxes( + torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]), + format="XYXY", + canvas_size=(H, W) +) +label = 3 + +out_img, out_bboxes, out_label = transforms(img, bboxes, label) +# %% +print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }") +# %% +# .. note:: +# While working with datapoint classes in your code, make sure to +# familiarize yourself with this section: +# :ref:`datapoint_unwrapping_behaviour` +# +# Supporting arbitrary input structures +# ===================================== +# +# In the section above, we have assumed that you already know the structure of +# your inputs and that you're OK with hard-coding this expected structure in +# your code. If you want your custom transforms to be as flexible as possible, +# this can be a bit limitting. +# +# A key feature of the builtin Torchvision V2 transforms is that they can accept +# arbitrary input structure and return the same structure as output (with +# transformed entries). 
+# For example, transforms can accept a single image, or a
+# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
+
+structured_input = {
+    "img": img,
+    "annotations": (bboxes, label),
+    "something_that_will_be_ignored": (1, "hello")
+}
+structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
+
+assert isinstance(structured_output, dict)
+assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
+
+# %%
+# If you want to reproduce this behavior in your own transform, we invite you to
+# look at our `code
+# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
+# and adapt it to your needs.
+#
+# In brief, the core logic is to unpack the input into a flat list using `pytree
+# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
+# then transform only the entries that can be transformed (the decision is made
+# based on the **class** of the entries, as all datapoints are
+# tensor-subclasses) plus some custom logic that is out of scope here - check the
+# code for details. The (potentially transformed) entries are then repacked and
+# returned, in the same structure as the input.
+#
+# We do not provide public dev-facing tools to achieve that at this time, but if
+# this is something that would be valuable to you, please let us know by opening
+# an issue on our `GitHub repo <https://github.com/pytorch/vision>`_.
diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py
index 57e29bd86eb..d87575cdb8e 100644
--- a/gallery/plot_datapoints.py
+++ b/gallery/plot_datapoints.py
@@ -3,13 +3,22 @@
 Datapoints FAQ
 ==============
 
-The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example
-showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need
-to worry about: you do not need to understand the internals of datapoints to efficiently rely on
-``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets,
-transforms, or work directly with the datapoints.
+Datapoints are Tensor subclasses introduced together with
+``torchvision.transforms.v2``. This example showcases what these datapoints are
+and how they behave.
+
+.. warning::
+
+   **Intended Audience** Unless you're writing your own transforms or your own datapoints, you
+   probably do not need to read this guide. This is a fairly low-level topic
+   that most users will not need to worry about: you do not need to understand
+   the internals of datapoints to efficiently rely on
+   ``torchvision.transforms.v2``. It may however be useful for advanced users
+   trying to implement their own datasets, transforms, or work directly with
+   the datapoints.
 """
 
+# %%
 import PIL.Image
 
 import torch
@@ -35,11 +44,20 @@
 assert isinstance(image, torch.Tensor)
 assert image.data_ptr() == tensor.data_ptr()
 
-
 # %%
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
+# What can I do with a datapoint?
+# -------------------------------
+#
+# Datapoints look and feel just like regular tensors - they **are** tensors.
+# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
+# any ``torch.*`` operator will also work on datapoints. See
+# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.
+
+# %%
+#
+# What datapoints are supported?
 # ------------------------------
 #
@@ -50,9 +68,14 @@
 # * :class:`~torchvision.datapoints.BoundingBoxes`
 # * :class:`~torchvision.datapoints.Mask`
 #
+# .. _datapoint_creation:
+#
 # How do I construct a datapoint?
 # -------------------------------
 #
+# Using the constructor
+# ^^^^^^^^^^^^^^^^^^^^^
+#
 # Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
 
 image = datapoints.Image([[[[0, 1], [1, 0]]]])
@@ -68,27 +91,52 @@
 
 # %%
-# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a
+# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a
 # :class:`PIL.Image.Image` directly:
 
 image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg"))
 print(image.shape, image.dtype)
 
 # %%
-# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
-# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
-# corresponding image alongside the actual values:
-
-bounding_box = datapoints.BoundingBoxes(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
+# Some datapoints require additional metadata to be passed in order to be constructed. For example,
+# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the
+# corresponding image (``canvas_size``) alongside the actual values. These
+# metadata are required to properly transform the bounding boxes.
+
+bboxes = datapoints.BoundingBoxes(
+    [[17, 16, 344, 495], [0, 10, 0, 10]],
+    format=datapoints.BoundingBoxFormat.XYXY,
+    canvas_size=image.shape[-2:]
 )
-print(bounding_box)
+print(bboxes)
+
+# %%
+# Using the ``wrap_like()`` class method
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# You can also use the ``wrap_like()`` class method to wrap a tensor object
+# into a datapoint. This is useful when you already have an object of the
+# desired type, which typically happens when writing transforms: you just want
+# to wrap the output like the input. This API is inspired by utils like
+# :func:`torch.zeros_like`:
+
+new_bboxes = torch.tensor([0, 20, 30, 40])
+new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+assert isinstance(new_bboxes, datapoints.BoundingBoxes)
+assert new_bboxes.canvas_size == bboxes.canvas_size
 
 # %%
+# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
+# it as a parameter to override it. Check the
+# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
+# more details.
+#
 # Do I have to wrap the output of the datasets myself?
 # ----------------------------------------------------
 #
+# TODO: Move this in another guide - this is user-facing, not dev-facing.
+#
 # Only if you are using custom datasets. For the built-in ones, you can use
 # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the
 # built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
@@ -105,8 +153,8 @@ class PennFudanDataset(torch.utils.data.Dataset):
     def __getitem__(self, item):
         ...
 
-        target["boxes"] = datapoints.BoundingBoxes(
-            boxes,
+        target["bboxes"] = datapoints.BoundingBoxes(
+            bboxes,
             format=datapoints.BoundingBoxFormat.XYXY,
             canvas_size=F.get_size(img),
         )
@@ -147,7 +195,7 @@ def get_transform(train):
 # %%
 # .. note::
 #
-# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in
 # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
 # at least not wrapping the obsolete parts, can lead to a significant performance boost.
 #
@@ -156,41 +204,66 @@ def get_transform(train):
 # even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
 # generated from the masks.
 #
-# How do the datapoints behave inside a computation?
-# --------------------------------------------------
+# .. _datapoint_unwrapping_behaviour:
 #
-# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor`
-# also works on datapoints.
-# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
-# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):
+# I had a Datapoint but now I have a Tensor. Help!
+# ------------------------------------------------
+#
+# For a lot of operations involving datapoints, we cannot safely infer whether
+# the result should retain the datapoint type, so we choose to return a plain
+# tensor instead of a datapoint (this might change, see note below):
 
-assert isinstance(image, datapoints.Image)
+assert isinstance(bboxes, datapoints.BoundingBoxes)
 
-new_image = image + 0
+# Shift bboxes by 3 pixels in both H and W
+new_bboxes = bboxes + 3
 
-assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
+assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
+
+# %%
+# If you're writing your own custom transforms or code involving datapoints, you
+# can re-wrap the output into a datapoint by just calling their constructor, or
+# by using the ``.wrap_like()`` class method:
+
+new_bboxes = bboxes + 3
+new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
+assert isinstance(new_bboxes, datapoints.BoundingBoxes)
 
 # %%
+# See more details above in :ref:`datapoint_creation`.
+#
+# .. note::
+#
+#    You never need to re-wrap manually if you're using the built-in transforms
+#    or their functional equivalents: this is automatically taken care of for
+#    you.
+#
 # .. note::
 #
 #    This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
 #    have any suggestions on how to better support your use-cases, please reach out to us via this issue:
 #    https://github.com/pytorch/vision/issues/7319
 #
-# There are two exceptions to this rule:
+# There are a few exceptions to this "unwrapping" rule:
 #
-# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
-#    retain the datapoint type.
-# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use
-#    the flow style, the returned value will be unwrapped:
+# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
+#    :meth:`~torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
+#    the datapoint type.
+# 2. Inplace operations on datapoints like ``.add_()`` preserve their type. However,
+#    the **returned** value of inplace operations will be unwrapped into a pure
+#    tensor:
 
 image = datapoints.Image([[[0, 1], [1, 0]]])
 
 new_image = image.add_(1).mul_(2)
 
-assert isinstance(image, torch.Tensor)
+# image got transformed in-place and is still an Image datapoint, but new_image
+# is a Tensor. They share the same underlying data and they're equal, just
+# different classes.
+assert isinstance(image, datapoints.Image)
 print(image)
 
 assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)
 assert (new_image == image).all()
+assert new_image.data_ptr() == image.data_ptr()
diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py
index 912cc3bca08..7477b3652dc 100644
--- a/torchvision/datapoints/_bounding_box.py
+++ b/torchvision/datapoints/_bounding_box.py
@@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint):
     canvas_size: Tuple[int, int]
 
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
         bounding_boxes = tensor.as_subclass(cls)
         bounding_boxes.format = format
         bounding_boxes.canvas_size = canvas_size
diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py
index 384273301de..fae3c18656b 100644
--- a/torchvision/datapoints/_datapoint.py
+++ b/torchvision/datapoints/_datapoint.py
@@ -14,6 +14,13 @@
 
 
 class Datapoint(torch.Tensor):
+    """[Beta] Base class for all datapoints.
+
+    You probably don't want to use this class unless you're defining your own
+    custom Datapoints. See
+    :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details.
+    """
+
     @staticmethod
     def _to_tensor(
         data: Any,
@@ -25,9 +32,13 @@
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
 
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
+
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        raise NotImplementedError
+        return cls._wrap(tensor)
 
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py
index dccfc81a605..9b635e8e034 100644
--- a/torchvision/datapoints/_image.py
+++ b/torchvision/datapoints/_image.py
@@ -22,11 +22,6 @@ class Image(Datapoint):
             ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
     """
 
-    @classmethod
-    def _wrap(cls, tensor: torch.Tensor) -> Image:
-        image = tensor.as_subclass(cls)
-        return image
-
     def __new__(
         cls,
         data: Any,
@@ -48,10 +43,6 @@
 
         return cls._wrap(tensor)
 
-    @classmethod
-    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
-        return cls._wrap(tensor)
-
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py
index 2b95eca72e2..95eda077929 100644
--- a/torchvision/datapoints/_mask.py
+++ b/torchvision/datapoints/_mask.py
@@ -22,10 +22,6 @@ class Mask(Datapoint):
             ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
""" - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Mask: - return tensor.as_subclass(cls) - def __new__( cls, data: Any, @@ -41,11 +37,3 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor) - - @classmethod - def wrap_like( - cls, - other: Mask, - tensor: torch.Tensor, - ) -> Mask: - return cls._wrap(tensor) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 11d6e2a854d..842c05bf7e9 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -20,11 +20,6 @@ class Video(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. """ - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Video: - video = tensor.as_subclass(cls) - return video - def __new__( cls, data: Any, @@ -38,10 +33,6 @@ def __new__( raise ValueError return cls._wrap(tensor) - @classmethod - def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: - return cls._wrap(tensor) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 7ed2f7522b0..ac9b2d8912a 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -15,7 +15,7 @@ class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] label_base = tensor.as_subclass(cls) label_base.categories = categories return label_base diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 1eaa54102a4..bb3d59b551a 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -47,6 +47,11 @@ def _name_to_dispatcher(name): def register_kernel(dispatcher, datapoint_cls): + """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type. + + See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage + details. + """ if isinstance(dispatcher, str): dispatcher = _name_to_dispatcher(name=dispatcher) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)