Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ Miscellaneous
v2.RandomErasing
Lambda
v2.Lambda
v2.SanitizeBoundingBox

.. _conversion_transforms:

Expand All @@ -210,6 +211,7 @@ Conversion
ConvertImageDtype
v2.ConvertImageDtype
v2.ConvertDtype
v2.ToDtype

Auto-Augmentation
-----------------
Expand Down
33 changes: 11 additions & 22 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@ class Grayscale(Transform):

.. betastatus:: Grayscale transform

If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions

Args:
num_output_channels (int): (1 or 3) number of channels desired for output image

Returns:
PIL Image: Grayscale version of the input.

- If ``num_output_channels == 1`` : returned image is single channel
- If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
"""

_v1_transform_cls = _transforms.Grayscale
Expand All @@ -50,18 +44,13 @@ class RandomGrayscale(_RandomApplyTransform):

.. betastatus:: RandomGrayscale transform

If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
where ... means an arbitrary number of leading dimensions

The output has the same number of channels as the input.

Args:
p (float): probability that image should be converted to grayscale.

Returns:
PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
with probability (1-p).
- If input image is 1 channel: grayscale version is 1 channel
- If input image is 3 channel: grayscale version is 3 channel with r == g == b

"""

_v1_transform_cls = _transforms.RandomGrayscale
Expand Down Expand Up @@ -89,7 +78,7 @@ class ColorJitter(Transform):

.. betastatus:: ColorJitter transform

If the image is torch Tensor, it is expected
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

Expand Down Expand Up @@ -295,7 +284,7 @@ class RandomEqualize(_RandomApplyTransform):

.. betastatus:: RandomEqualize transform

If the image is torch Tensor, it is expected
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

Expand Down Expand Up @@ -334,7 +323,7 @@ class RandomPosterize(_RandomApplyTransform):

.. betastatus:: RandomPosterize transform

If the image is torch Tensor, it should be of type torch.uint8,
If the input is a :class:`torch.Tensor`, it should be of type torch.uint8,
and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Expand Down Expand Up @@ -383,7 +372,7 @@ class RandomAutocontrast(_RandomApplyTransform):

.. betastatus:: RandomAutocontrast transform

If the image is torch Tensor, it is expected
If the input is a :class:`torch.Tensor`, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Expand All @@ -402,7 +391,7 @@ class RandomAdjustSharpness(_RandomApplyTransform):

.. betastatus:: RandomAdjustSharpness transform

If the image is torch Tensor,
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
Expand Down
51 changes: 41 additions & 10 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from .utils import has_any, is_simple_tensor, query_bounding_box


# TODO: do we want/need to expose this?
class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt


class Lambda(Transform):
"""[BETA] Apply a user-defined lambda as a transform.
"""[BETA] Apply a user-defined function as a transform.

.. betastatus:: Lambda transform

Expand Down Expand Up @@ -52,7 +53,7 @@ def extra_repr(self) -> str:


class LinearTransformation(Transform):
"""[BETA] Transform a tensor image with a square transformation matrix and a mean_vector computed offline.
"""[BETA] Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline.

.. betastatus:: LinearTransformation transform

Expand Down Expand Up @@ -135,7 +136,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class Normalize(Transform):
"""[BETA] Normalize a tensor image with mean and standard deviation.
"""[BETA] Normalize a tensor image or video with mean and standard deviation.

.. betastatus:: Normalize transform

Expand Down Expand Up @@ -179,7 +180,7 @@ class GaussianBlur(Transform):

.. betastatus:: GausssianBlur transform

If the image is torch Tensor, it is expected
If the input is a Tensor, it is expected
to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
Expand All @@ -188,9 +189,6 @@ class GaussianBlur(Transform):
creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
of float (min, max), sigma is chosen uniformly at random to lie in the
given range.

Returns:
PIL Image or Tensor: Gaussian blurred version of the input image.
"""

_v1_transform_cls = _transforms.GaussianBlur
Expand Down Expand Up @@ -225,6 +223,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class ToDtype(Transform):
"""[BETA] Converts the input to a specific dtype.

.. betastatus:: ToDtype transform

Args:
dtype (dtype or dict of Datapoint -> dtype): The dtype to convert to. A dict can be passed to specify
per-datapoint conversions, e.g. ``dtype={datapoints.Image: torch.float32, datapoints.Video: torch.float64}``.
"""

_transformed_types = (torch.Tensor,)

def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
Expand All @@ -247,9 +254,33 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class SanitizeBoundingBox(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
"""[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.

.. betastatus:: SanitizeBoundingBox transform

This transform removes bounding boxes and their associated labels/masks that:

- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals.

It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
If you want to be extra careful, you may call it after all transforms that
may modify bounding boxes but once at the end should be enough in most
cases.

Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies
the key whose value corresponds to the labels. It can also be a callable that takes the same input
as the transform, and returns the labels.
By default, this will try to find a "labels" key in the input, if
the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
"""

def __init__(
self,
Expand Down