Skip to content

Commit 60c78f2

Browse files
authored
make RandomErasing scriptable for integer value (#7134)
1 parent 5dd9594 commit 60c78f2

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

test/test_transforms_tensor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,17 @@ def shear(pil_img, level, mode, resample):
672672
@pytest.mark.parametrize("device", cpu_and_gpu())
673673
@pytest.mark.parametrize(
674674
"config",
675-
[{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}],
675+
[
676+
{},
677+
{"value": 1},
678+
{"value": 0.2},
679+
{"value": "random"},
680+
{"value": (1, 1, 1)},
681+
{"value": (0.2, 0.2, 0.2)},
682+
{"value": [1, 1, 1]},
683+
{"value": [0.2, 0.2, 0.2]},
684+
{"value": "random", "ratio": (0.1, 0.2)},
685+
],
676686
)
677687
def test_random_erasing(device, config):
678688
tensor, _ = _create_data(24, 32, channels=3, device=device)

torchvision/prototype/transforms/_augment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
p: float = 0.5,
2424
scale: Tuple[float, float] = (0.02, 0.33),
2525
ratio: Tuple[float, float] = (0.3, 3.3),
26-
value: float = 0,
26+
value: float = 0.0,
2727
inplace: bool = False,
2828
):
2929
super().__init__(p=p)
@@ -42,11 +42,11 @@ def __init__(
4242
self.scale = scale
4343
self.ratio = ratio
4444
if isinstance(value, (int, float)):
45-
self.value = [value]
45+
self.value = [float(value)]
4646
elif isinstance(value, str):
4747
self.value = None
48-
elif isinstance(value, tuple):
49-
self.value = list(value)
48+
elif isinstance(value, (list, tuple)):
49+
self.value = [float(v) for v in value]
5050
else:
5151
self.value = value
5252
self.inplace = inplace

torchvision/transforms/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,11 +1713,11 @@ def forward(self, img):
17131713

17141714
# cast self.value to script acceptable type
17151715
if isinstance(self.value, (int, float)):
1716-
value = [self.value]
1716+
value = [float(self.value)]
17171717
elif isinstance(self.value, str):
17181718
value = None
1719-
elif isinstance(self.value, tuple):
1720-
value = list(self.value)
1719+
elif isinstance(self.value, (list, tuple)):
1720+
value = [float(v) for v in self.value]
17211721
else:
17221722
value = self.value
17231723

0 commit comments

Comments
 (0)