diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index fefe5d3defc..55d254a92d4 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -99,10 +99,8 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None: def forward(self, *inpts: Any) -> Any: sample = inpts if len(inpts) > 1 else inpts[0] - if not ( - has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel) - ): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") + if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)): + raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): raise TypeError( f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels." @@ -123,12 +121,16 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: lam = params["lam"] - if isinstance(inpt, features.Image): + if isinstance(inpt, features.Image) or is_simple_tensor(inpt): if inpt.ndim < 4: raise ValueError("Need a batch of images") output = inpt.clone() output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) - return features.Image.new_like(inpt, output) + + if isinstance(inpt, features.Image): + output = features.Image.new_like(inpt, output) + + return output elif isinstance(inpt, features.OneHotLabel): return self._mixup_onehotlabel(inpt, lam) else: @@ -159,7 +161,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features.Image): + if isinstance(inpt, features.Image) or is_simple_tensor(inpt): box = params["box"] if inpt.ndim < 4: raise ValueError("Need a batch of images") @@ -167,7 +169,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: image_rolled = inpt.roll(1, -4) output = inpt.clone() output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return features.Image.new_like(inpt, output) + + if isinstance(inpt, features.Image): + output = features.Image.new_like(inpt, output) + + return output elif isinstance(inpt, features.OneHotLabel): lam_adjusted = params["lam_adjusted"] return self._mixup_onehotlabel(inpt, lam_adjusted)