
Commit e3238e5

only flatten a pytree once (#6767)
1 parent dc5fd83 commit e3238e5
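
Note: the commit title refers to the pattern applied across the files below. `forward()` now flattens its inputs with `torch.utils._pytree.tree_flatten` a single time, hands the resulting flat list to helpers such as `_get_params` and `_check_inputs`, and unflattens once at the end, instead of each helper re-flattening the sample on its own. A minimal sketch of that pattern follows; the class and helper bodies are illustrative only, not the torchvision API.

from typing import Any, Dict, List

from torch.utils._pytree import tree_flatten, tree_unflatten


class FlattenOnceTransform:
    """Illustrative stand-in for the prototype Transform base class."""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Parameter sampling now receives the already-flattened leaves.
        return {}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # No-op for the sketch; real transforms dispatch on the leaf type.
        return inpt

    def forward(self, *inputs: Any) -> Any:
        # Flatten the pytree exactly once per call ...
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        params = self._get_params(flat_inputs)
        # ... transform the flat leaves ...
        flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]
        # ... and unflatten once at the end.
        return tree_unflatten(flat_outputs, spec)

Calling `forward` on a nested sample (for example a dict with an image tensor and a label) returns the same structure with every leaf passed through `_transform`.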

File tree

10 files changed: +143 -145 lines changed


test/test_prototype_transforms.py

Lines changed: 23 additions & 22 deletions
@@ -437,7 +437,7 @@ def test__get_params(self, fill, side_range, mocker):
         image = mocker.MagicMock(spec=features.Image)
         h, w = image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         assert len(params["padding"]) == 4
         assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
@@ -462,7 +462,7 @@ def test__transform(self, fill, side_range, mocker):
         _ = transform(inpt)
         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill)
@@ -623,7 +623,7 @@ def test__get_params(self, degrees, translate, scale, shear, mocker):
         h, w = image.spatial_size

         transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         if not isinstance(degrees, (list, tuple)):
             assert -degrees <= params["angle"] <= degrees
@@ -690,7 +690,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
         torch.manual_seed(12)
         _ = transform(inpt)
         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
@@ -722,7 +722,7 @@ def test__get_params(self, padding, pad_if_needed, size, mocker):
         h, w = image.spatial_size

         transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         if padding is not None:
             if isinstance(padding, int):
@@ -793,7 +793,7 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
         torch.manual_seed(12)
         _ = transform(inpt)
         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])
         if padding is None and not pad_if_needed:
             fn_crop.assert_called_once_with(
                 inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
@@ -832,7 +832,7 @@ def test_assertions(self):
     @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
     def test__get_params(self, sigma):
         transform = transforms.GaussianBlur(3, sigma=sigma)
-        params = transform._get_params(None)
+        params = transform._get_params([])

         if isinstance(sigma, float):
             assert params["sigma"][0] == params["sigma"][1] == 10
@@ -867,7 +867,7 @@ def test__transform(self, kernel_size, sigma, mocker):
         torch.manual_seed(12)
         _ = transform(inpt)
         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fn.assert_called_once_with(inpt, kernel_size, **params)

@@ -912,7 +912,7 @@ def test__get_params(self, mocker):
         image.num_channels = 3
         image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         h, w = image.spatial_size
         assert "perspective_coeffs" in params
@@ -935,7 +935,7 @@ def test__transform(self, distortion_scale, mocker):
         _ = transform(inpt)
         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -973,7 +973,7 @@ def test__get_params(self, mocker):
         image.num_channels = 3
         image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         h, w = image.spatial_size
         displacement = params["displacement"]
@@ -1006,7 +1006,7 @@ def test__transform(self, alpha, sigma, mocker):
         # Let's mock transform._get_params to control the output:
         transform._get_params = mocker.MagicMock()
         _ = transform(inpt)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])
         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)

@@ -1035,7 +1035,7 @@ def test_assertions(self, mocker):
         transform = transforms.RandomErasing(value=[1, 2, 3, 4])

         with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
-            transform._get_params(image)
+            transform._get_params([image])

     @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
     def test__get_params(self, value, mocker):
@@ -1044,7 +1044,7 @@ def test__get_params(self, value, mocker):
         image.spatial_size = (24, 32)

         transform = transforms.RandomErasing(value=value)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         v = params["v"]
         h, w = params["h"], params["w"]
@@ -1197,6 +1197,7 @@ def test_assertions(self, transform_cls):
         [
             [transforms.Pad(2), transforms.RandomCrop(28)],
             [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
+            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
         ],
     )
     def test_ctor(self, transform_cls, trfms):
@@ -1339,7 +1340,7 @@ def test__get_params(self, mocker):
         n_samples = 5
         for _ in range(n_samples):

-            params = transform._get_params(sample)
+            params = transform._get_params([sample])

             assert "size" in params
             size = params["size"]
@@ -1386,7 +1387,7 @@ def test__get_params(self, mocker):
         transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)

         sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size)
-        params = transform._get_params(sample)
+        params = transform._get_params([sample])

         assert "size" in params
         size = params["size"]
@@ -1554,13 +1555,13 @@ def test__get_params(self, mocker):

         transform = transforms.FixedSizeCrop(size=crop_size)

-        sample = dict(
-            image=make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
-            bounding_boxes=make_bounding_box(
+        flat_inputs = [
+            make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
+            make_bounding_box(
                 format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
             ),
-        )
-        params = transform._get_params(sample)
+        ]
+        params = transform._get_params(flat_inputs)

         assert params["needs_crop"]
         assert params["height"] <= crop_size[0]
@@ -1759,7 +1760,7 @@ def test__get_params(self):
         transform = transforms.RandomResize(min_size=min_size, max_size=max_size)

         for _ in range(10):
-            params = transform._get_params(None)
+            params = transform._get_params([])

             assert isinstance(params["size"], list) and len(params["size"]) == 1
             size = params["size"][0]

test/test_prototype_transforms_consistency.py

Lines changed: 1 addition & 1 deletion
@@ -639,7 +639,7 @@ def test_random_apply(self, p):
         prototype_transform = prototype_transforms.RandomApply(
             [
                 prototype_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
+                prototype_transforms.CenterCrop(224),
             ],
             p=p,
         )

torchvision/prototype/transforms/_augment.py

Lines changed: 13 additions & 13 deletions
@@ -45,8 +45,8 @@ def __init__(

         self._log_ratio = torch.log(torch.tensor(self.ratio))

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        img_c, img_h, img_w = query_chw(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        img_c, img_h, img_w = query_chw(flat_inputs)

         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -107,13 +107,13 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

-    def _check_inputs(self, sample: Any) -> None:
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_any(sample, features.Image, features.Video, features.is_simple_tensor)
-            and has_any(sample, features.OneHotLabel)
+            has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(flat_inputs, features.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(sample, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+        if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )
@@ -127,7 +127,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features


 class RandomMixup(_BaseMixupCutmix):
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -150,10 +150,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


 class RandomCutmix(_BaseMixupCutmix):
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

-        H, W = query_spatial_size(sample)
+        H, W = query_spatial_size(flat_inputs)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -344,9 +344,9 @@ def _insert_outputs(
                 c3 += 1

     def forward(self, *inputs: Any) -> Any:
-        flat_sample, spec = tree_flatten(inputs)
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

-        images, targets = self._extract_image_targets(flat_sample)
+        images, targets = self._extract_image_targets(flat_inputs)

         # images = [t1, t2, ..., tN]
         # Let's define paste_images as shifted list of input images
@@ -384,6 +384,6 @@ def forward(self, *inputs: Any) -> Any:
             output_targets.append(output_target)

         # Insert updated images and targets into input flat_sample
-        self._insert_outputs(flat_sample, output_images, output_targets)
+        self._insert_outputs(flat_inputs, output_images, output_targets)

-        return tree_unflatten(flat_sample, spec)
+        return tree_unflatten(flat_inputs, spec)
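
Because `_get_params` and `_check_inputs` in this file now receive the flat leaf list produced by the single `tree_flatten` call in `forward`, helpers like `query_chw` can scan that list directly instead of flattening again. A simplified, hedged stand-in for such a helper is sketched below; torchvision's real `query_chw` handles more input types, this only shows the idea.

from typing import Any, List, Tuple

import torch


def query_chw_sketch(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    # Collect the (C, H, W) of every image-like leaf in the flat list.
    sizes = {
        tuple(leaf.shape[-3:])
        for leaf in flat_inputs
        if isinstance(leaf, torch.Tensor) and leaf.ndim >= 3
    }
    if len(sizes) != 1:
        raise TypeError(f"Found {len(sizes)} distinct C, H, W sizes: {sorted(sizes)}")
    c, h, w = sizes.pop()
    return c, h, w


print(query_chw_sketch([torch.rand(3, 24, 32), "not an image", torch.rand(3, 24, 32)]))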

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 27 additions & 27 deletions
@@ -4,7 +4,7 @@
 import PIL.Image
 import torch

-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -31,16 +31,17 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _extract_image_or_video(
+    def _flatten_and_extract_image_or_video(
         self,
-        sample: Any,
+        inputs: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[int, Union[features.ImageType, features.VideoType]]:
-        sample_flat, _ = tree_flatten(sample)
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]:
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+
         image_or_videos = []
-        for id, inpt in enumerate(sample_flat):
+        for idx, inpt in enumerate(flat_inputs):
             if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
-                image_or_videos.append((id, inpt))
+                image_or_videos.append((idx, inpt))
             elif isinstance(inpt, unsupported_types):
                 raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

@@ -51,12 +52,18 @@ def _extract_image_or_video(
                 f"Auto augment transformations are only properly defined for a single image or video, "
                 f"but found {len(image_or_videos)}."
             )
-        return image_or_videos[0]

-    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
-        sample_flat, spec = tree_flatten(sample)
-        sample_flat[id] = item
-        return tree_unflatten(sample_flat, spec)
+        idx, image_or_video = image_or_videos[0]
+        return (flat_inputs, spec, idx), image_or_video
+
+    def _unflatten_and_insert_image_or_video(
+        self,
+        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
+        image_or_video: Union[features.ImageType, features.VideoType],
+    ) -> Any:
+        flat_inputs, spec, idx = flat_inputs_with_spec
+        flat_inputs[idx] = image_or_video
+        return tree_unflatten(flat_inputs, spec)

     def _apply_image_or_video_transform(
         self,
@@ -275,9 +282,7 @@ def _get_policies(
             raise ValueError(f"The provided policy {policy} is not recognized.")

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -300,7 +305,7 @@ def forward(self, *inputs: Any) -> Any:
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class RandAugment(_AutoAugmentBase):
@@ -346,9 +351,7 @@ def __init__(
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         for _ in range(self.num_ops):
@@ -364,7 +367,7 @@ def forward(self, *inputs: Any) -> Any:
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class TrivialAugmentWide(_AutoAugmentBase):
@@ -400,9 +403,7 @@ def __init__(
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -418,7 +419,7 @@ def forward(self, *inputs: Any) -> Any:
         image_or_video = self._apply_image_or_video_transform(
             image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
         )
-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class AugMix(_AutoAugmentBase):
@@ -471,8 +472,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
         return torch._sample_dirichlet(params)

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        id, orig_image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(orig_image_or_video)

         if isinstance(orig_image_or_video, torch.Tensor):
@@ -525,4 +525,4 @@ def forward(self, *inputs: Any) -> Any:
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)

-        return self._put_into_sample(sample, id, mix)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
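
As a usage illustration of the refactored helpers above, the sketch below shows how a nested sample round-trips through flatten, extract, insert, and unflatten. The sample structure and the `pick_image`/`put_back` helpers are made up for the example; they mirror, but are not, torchvision's private `_flatten_and_extract_image_or_video` and `_unflatten_and_insert_image_or_video`.

from typing import Any, List, Tuple

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec


def pick_image(inputs: Any) -> Tuple[Tuple[List[Any], TreeSpec, int], torch.Tensor]:
    # Flatten once and remember where the image leaf lives.
    flat_inputs, spec = tree_flatten(inputs)
    idx = next(i for i, leaf in enumerate(flat_inputs) if isinstance(leaf, torch.Tensor))
    return (flat_inputs, spec, idx), flat_inputs[idx]


def put_back(ctx: Tuple[List[Any], TreeSpec, int], image: torch.Tensor) -> Any:
    # Write the transformed image back into the flat list and unflatten once.
    flat_inputs, spec, idx = ctx
    flat_inputs[idx] = image
    return tree_unflatten(flat_inputs, spec)


sample = {"image": torch.rand(3, 8, 8), "label": 7}
ctx, image = pick_image(sample)
out = put_back(ctx, image.flip(-1))  # any tensor-level augmentation
assert out["label"] == 7 and out["image"].shape == (3, 8, 8)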
