diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index ddd6f37d083..1dec6bedf15 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -190,6 +190,7 @@ Miscellaneous
     v2.RandomErasing
     Lambda
     v2.Lambda
+    v2.SanitizeBoundingBox
 
 .. _conversion_transforms:
 
@@ -210,6 +211,7 @@ Conversion
     ConvertImageDtype
     v2.ConvertImageDtype
     v2.ConvertDtype
+    v2.ToDtype
 
 Auto-Augmentation
 -----------------
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 2a581bf5640..237e8d6181a 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -15,17 +15,11 @@ class Grayscale(Transform):
 
     .. betastatus:: Grayscale transform
 
-    If the image is torch Tensor, it is expected
-    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
+    If the input is a :class:`torch.Tensor`, it is expected
+    to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Args:
         num_output_channels (int): (1 or 3) number of channels desired for output image
-
-    Returns:
-        PIL Image: Grayscale version of the input.
-
-        - If ``num_output_channels == 1`` : returned image is single channel
-        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
     """
 
     _v1_transform_cls = _transforms.Grayscale
@@ -50,18 +44,13 @@ class RandomGrayscale(_RandomApplyTransform):
 
     .. betastatus:: RandomGrayscale transform
 
-    If the image is torch Tensor, it is expected
-    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
+    If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
+    where ... means an arbitrary number of leading dimensions
+
+    The output has the same number of channels as the input.
 
     Args:
         p (float): probability that image should be converted to grayscale.
-
-    Returns:
-        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
-        with probability (1-p).
-        - If input image is 1 channel: grayscale version is 1 channel
-        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
-
     """
 
     _v1_transform_cls = _transforms.RandomGrayscale
@@ -89,7 +78,7 @@ class ColorJitter(Transform):
 
     .. betastatus:: ColorJitter transform
 
-    If the image is torch Tensor, it is expected
+    If the input is a :class:`torch.Tensor`, it is expected
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
     If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
 
@@ -295,7 +284,7 @@ class RandomEqualize(_RandomApplyTransform):
 
     .. betastatus:: RandomEqualize transform
 
-    If the image is torch Tensor, it is expected
+    If the input is a :class:`torch.Tensor`, it is expected
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
     If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
 
@@ -334,7 +323,7 @@ class RandomPosterize(_RandomApplyTransform):
 
     .. betastatus:: RandomPosterize transform
 
-    If the image is torch Tensor, it should be of type torch.uint8,
+    If the input is a :class:`torch.Tensor`, it should be of type torch.uint8,
     and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
     If img is PIL Image, it is expected to be in mode "L" or "RGB".
 
@@ -383,7 +372,7 @@ class RandomAutocontrast(_RandomApplyTransform):
 
     .. betastatus:: RandomAutocontrast transform
 
-    If the image is torch Tensor, it is expected
+    If the input is a :class:`torch.Tensor`, it is expected
     to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
     If img is PIL Image, it is expected to be in mode "L" or "RGB".
 
@@ -402,7 +391,7 @@ class RandomAdjustSharpness(_RandomApplyTransform):
 
     .. betastatus:: RandomAdjustSharpness transform
 
-    If the image is torch Tensor,
+    If the input is a :class:`torch.Tensor`,
     it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 53975a2ad2a..2237334f7a2 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -15,13 +15,14 @@
 from .utils import has_any, is_simple_tensor, query_bounding_box
 
 
+# TODO: do we want/need to expose this?
 class Identity(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
 
 class Lambda(Transform):
-    """[BETA] Apply a user-defined lambda as a transform.
+    """[BETA] Apply a user-defined function as a transform.
 
     .. betastatus:: Lambda transform
 
@@ -52,7 +53,7 @@ def extra_repr(self) -> str:
 
 
 class LinearTransformation(Transform):
-    """[BETA] Transform a tensor image with a square transformation matrix and a mean_vector computed offline.
+    """[BETA] Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline.
 
     .. betastatus:: LinearTransformation transform
 
@@ -135,7 +136,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class Normalize(Transform):
-    """[BETA] Normalize a tensor image with mean and standard deviation.
+    """[BETA] Normalize a tensor image or video with mean and standard deviation.
 
     .. betastatus:: Normalize transform
 
@@ -179,7 +180,7 @@ class GaussianBlur(Transform):
 
     .. betastatus:: GausssianBlur transform
 
-    If the image is torch Tensor, it is expected
+    If the input is a Tensor, it is expected
     to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.
 
     Args:
@@ -188,9 +189,6 @@ class GaussianBlur(Transform):
         creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
         of float (min, max), sigma is chosen uniformly at random to lie in the
         given range.
-
-    Returns:
-        PIL Image or Tensor: Gaussian blurred version of the input image.
     """
 
     _v1_transform_cls = _transforms.GaussianBlur
@@ -225,6 +223,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class ToDtype(Transform):
+    """[BETA] Converts the input to a specific dtype.
+
+    .. betastatus:: ToDtype transform
+
+    Args:
+        dtype (dtype or dict of Datapoint -> dtype): The dtype to convert to. A dict can be passed to specify
+            per-datapoint conversions, e.g. ``dtype={datapoints.Image: torch.float32, datapoints.Video: torch.float64}``.
+    """
+
     _transformed_types = (torch.Tensor,)
 
     def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
@@ -247,9 +254,33 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class SanitizeBoundingBox(Transform):
-    # This removes boxes and their corresponding labels:
-    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
-    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
+    """[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
+
+    .. betastatus:: SanitizeBoundingBox transform
+
+    This transform removes bounding boxes and their associated labels/masks that:
+
+    - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
+    - have any coordinate outside their corresponding image. You may want to
+      call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the model. It is critical to call this transform if
+    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
+    If you want to be extra careful, you may call it after all transforms that
+    may modify bounding boxes, but once at the end should be enough in most
+    cases.
+
+    Args:
+        min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
+        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
+            It can be a str, in which case the input is expected to be a dict and ``labels_getter`` then specifies
+            the key whose value corresponds to the labels. It can also be a callable that takes the same input
+            as the transform, and returns the labels.
+            By default, this will try to find a "labels" key in the input, if
+            the input is a dict or a tuple whose second element is a dict.
+            This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
+    """
 
     def __init__(
         self,
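
A minimal usage sketch for the ``ToDtype`` docstring added above. It assumes the beta ``torchvision.transforms.v2`` and ``torchvision.datapoints`` namespaces from this branch, and the sample shapes and dtype choices are illustrative only; unlike ``ConvertDtype``, this appears to be a plain cast with no value rescaling::

    import torch

    from torchvision import datapoints
    from torchvision.transforms import v2

    # A single dtype converts every tensor in the sample.
    to_float = v2.ToDtype(torch.float32)
    img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    assert to_float(img).dtype == torch.float32

    # A dict enables per-datapoint conversions, mirroring the docstring example:
    # images become float32, videos float64. Per the Optional[torch.dtype] in
    # __init__, a dict value of None presumably means "leave this type unchanged".
    per_type = v2.ToDtype({datapoints.Image: torch.float32, datapoints.Video: torch.float64})
    video = datapoints.Video(torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8))
    img_out, video_out = per_type(img, video)
    assert img_out.dtype == torch.float32
    assert video_out.dtype == torch.float64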
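
Likewise, a sketch of the ``SanitizeBoundingBox`` behavior described in the new docstring. The ``datapoints.BoundingBox`` constructor (``format=``, ``spatial_size=``) is assumed from the beta datapoints API and is not shown in this diff; the coordinates are made up::

    import torch

    from torchvision import datapoints
    from torchvision.transforms import v2

    boxes = datapoints.BoundingBox(
        [
            [10, 10, 20, 20],    # valid
            [50, 50, 45, 55],    # degenerate: X2 <= X1
            [90, 90, 120, 120],  # coordinates beyond the 100x100 image
        ],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )
    sample = {"boxes": boxes, "labels": torch.tensor([0, 1, 2])}

    # labels_getter="labels" names the dict key holding the labels, so the
    # degenerate and out-of-image boxes are dropped together with their labels
    # (min_size defaults to 1).
    out = v2.SanitizeBoundingBox(labels_getter="labels")(sample)
    assert len(out["boxes"]) == len(out["labels"]) == 1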