Skip to content

Commit 6e06820

Browse files
committed
split image_size into height and width in auto augment
1 parent 13c0d08 commit 6e06820

File tree

1 file changed

+72
-60
lines changed

1 file changed

+72
-60
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 72 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -199,25 +199,31 @@ def _apply_image_transform(
199199

200200
class AutoAugment(_AutoAugmentBase):
201201
_AUGMENTATION_SPACE = {
202-
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
203-
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
204-
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
205-
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
206-
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
207-
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
208-
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
209-
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
210-
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
202+
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
203+
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
204+
"TranslateX": (
205+
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
206+
True,
207+
),
208+
"TranslateY": (
209+
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
210+
True,
211+
),
212+
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
213+
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
214+
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
215+
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
216+
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
211217
"Posterize": (
212-
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
218+
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
213219
.round()
214220
.int(),
215221
False,
216222
),
217-
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
218-
"AutoContrast": (lambda num_bins, image_size: None, False),
219-
"Equalize": (lambda num_bins, image_size: None, False),
220-
"Invert": (lambda num_bins, image_size: None, False),
223+
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
224+
"AutoContrast": (lambda num_bins, height, width: None, False),
225+
"Equalize": (lambda num_bins, height, width: None, False),
226+
"Invert": (lambda num_bins, height, width: None, False),
221227
}
222228

223229
def __init__(
@@ -335,7 +341,7 @@ def forward(self, *inputs: Any) -> Any:
335341

336342
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
337343

338-
magnitudes = magnitudes_fn(10, (height, width))
344+
magnitudes = magnitudes_fn(10, height, width)
339345
if magnitudes is not None:
340346
magnitude = float(magnitudes[magnitude_idx])
341347
if signed and torch.rand(()) <= 0.5:
@@ -352,25 +358,31 @@ def forward(self, *inputs: Any) -> Any:
352358

353359
class RandAugment(_AutoAugmentBase):
354360
_AUGMENTATION_SPACE = {
355-
"Identity": (lambda num_bins, image_size: None, False),
356-
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
357-
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
358-
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
359-
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
360-
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
361-
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
362-
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
363-
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
364-
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
361+
"Identity": (lambda num_bins, height, width: None, False),
362+
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
363+
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
364+
"TranslateX": (
365+
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
366+
True,
367+
),
368+
"TranslateY": (
369+
lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
370+
True,
371+
),
372+
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
373+
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
374+
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
375+
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
376+
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
365377
"Posterize": (
366-
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
378+
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
367379
.round()
368380
.int(),
369381
False,
370382
),
371-
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
372-
"AutoContrast": (lambda num_bins, image_size: None, False),
373-
"Equalize": (lambda num_bins, image_size: None, False),
383+
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
384+
"AutoContrast": (lambda num_bins, height, width: None, False),
385+
"Equalize": (lambda num_bins, height, width: None, False),
374386
}
375387

376388
def __init__(
@@ -397,7 +409,7 @@ def forward(self, *inputs: Any) -> Any:
397409
for _ in range(self.num_ops):
398410
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
399411

400-
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
412+
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
401413
if magnitudes is not None:
402414
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
403415
if signed and torch.rand(()) <= 0.5:
@@ -414,25 +426,25 @@ def forward(self, *inputs: Any) -> Any:
414426

415427
class TrivialAugmentWide(_AutoAugmentBase):
416428
_AUGMENTATION_SPACE = {
417-
"Identity": (lambda num_bins, image_size: None, False),
418-
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
419-
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
420-
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
421-
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True),
422-
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True),
423-
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
424-
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
425-
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
426-
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True),
429+
"Identity": (lambda num_bins, height, width: None, False),
430+
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
431+
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
432+
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
433+
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
434+
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
435+
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
436+
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
437+
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
438+
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
427439
"Posterize": (
428-
lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
440+
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
429441
.round()
430442
.int(),
431443
False,
432444
),
433-
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
434-
"AutoContrast": (lambda num_bins, image_size: None, False),
435-
"Equalize": (lambda num_bins, image_size: None, False),
445+
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
446+
"AutoContrast": (lambda num_bins, height, width: None, False),
447+
"Equalize": (lambda num_bins, height, width: None, False),
436448
}
437449

438450
def __init__(
@@ -454,7 +466,7 @@ def forward(self, *inputs: Any) -> Any:
454466

455467
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
456468

457-
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
469+
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
458470
if magnitudes is not None:
459471
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
460472
if signed and torch.rand(()) <= 0.5:
@@ -468,27 +480,27 @@ def forward(self, *inputs: Any) -> Any:
468480

469481
class AugMix(_AutoAugmentBase):
470482
_PARTIAL_AUGMENTATION_SPACE = {
471-
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
472-
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
473-
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
474-
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
475-
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
483+
"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
484+
"ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
485+
"TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
486+
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
487+
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
476488
"Posterize": (
477-
lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
489+
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
478490
.round()
479491
.int(),
480492
False,
481493
),
482-
"Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False),
483-
"AutoContrast": (lambda num_bins, image_size: None, False),
484-
"Equalize": (lambda num_bins, image_size: None, False),
494+
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
495+
"AutoContrast": (lambda num_bins, height, width: None, False),
496+
"Equalize": (lambda num_bins, height, width: None, False),
485497
}
486-
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = {
498+
_AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
487499
**_PARTIAL_AUGMENTATION_SPACE,
488-
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
489-
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
490-
"Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
491-
"Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
500+
"Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
501+
"Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
502+
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
503+
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
492504
}
493505

494506
def __init__(
@@ -550,7 +562,7 @@ def forward(self, *inputs: Any) -> Any:
550562
for _ in range(depth):
551563
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
552564

553-
magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width))
565+
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
554566
if magnitudes is not None:
555567
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
556568
if signed and torch.rand(()) <= 0.5:

0 commit comments

Comments
 (0)