From d93c7591c6bab8b1b03a122ea317a3a337842523 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 22 Aug 2022 16:40:37 +0200
Subject: [PATCH 1/2] fix MixUp and CutMix

---
 torchvision/prototype/transforms/_augment.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index bb884a6cb77..d33454c4cc6 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -105,9 +105,7 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:
 
     def forward(self, *inpts: Any) -> Any:
         sample = inpts if len(inpts) > 1 else inpts[0]
-        if not (
-            has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
-        ):
+        if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
             raise TypeError(
@@ -129,12 +127,16 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         lam = params["lam"]
-        if isinstance(inpt, features.Image):
+        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
             if inpt.ndim < 4:
                 raise ValueError("Need a batch of images")
             output = inpt.clone()
             output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
-            return features.Image.new_like(inpt, output)
+
+            if isinstance(inpt, features.Image):
+                output = features.Image.new_like(inpt, output)
+
+            return output
         elif isinstance(inpt, features.OneHotLabel):
             return self._mixup_onehotlabel(inpt, lam)
         else:
@@ -166,7 +168,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(box=box, lam_adjusted=lam_adjusted)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
+        if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
             box = params["box"]
             if inpt.ndim < 4:
                 raise ValueError("Need a batch of images")
@@ -174,7 +176,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             image_rolled = inpt.roll(1, -4)
             output = inpt.clone()
             output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-            return features.Image.new_like(inpt, output)
+
+            if isinstance(inpt, features.Image):
+                output = features.Image.new_like(inpt, output)
+
+            return output
         elif isinstance(inpt, features.OneHotLabel):
             lam_adjusted = params["lam_adjusted"]
             return self._mixup_onehotlabel(inpt, lam_adjusted)

From 096c3b077862792a552f7be6b3ac555e1ceae0fe Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 23 Aug 2022 09:02:54 +0200
Subject: [PATCH 2/2] improve error message

---
 torchvision/prototype/transforms/_augment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index d33454c4cc6..fbe3e3b6b68 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -106,7 +106,7 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:
     def forward(self, *inpts: Any) -> Any:
         sample = inpts if len(inpts) > 1 else inpts[0]
         if not (has_any(sample, features.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)):
-            raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
+            raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
         if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."