## Description
By design, our transforms v2 can handle arbitrary input structures. Internally, we are using `torch.utils._pytree` for this.
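For reference, a minimal sketch of what `tree_flatten` / `tree_unflatten` do (the sample structure below is made up for illustration):

```python
from torch.utils._pytree import tree_flatten, tree_unflatten

# A hypothetical (image, target) sample with nested structure.
sample = {"image": "image_tensor", "target": {"boxes": "bbox", "labels": "labels"}}

# tree_flatten extracts the leaves into a flat list and records the
# original structure in a spec object.
flat_sample, spec = tree_flatten(sample)
print(flat_sample)  # ['image_tensor', 'bbox', 'labels']

# tree_unflatten rebuilds the original structure from the (possibly
# transformed) leaves.
assert tree_unflatten(flat_sample, spec) == sample
```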
## Status quo
The simplest transforms flatten / unflatten only once:
vision/torchvision/prototype/transforms/_transform.py, lines 37 to 42 in 54a2d4e:

```python
flat_inputs, spec = tree_flatten(sample)
flat_outputs = [
    self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
    for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
```
However, albeit hidden from the user, most transforms flatten at least twice, if not more often:
- If a transform needs to know the spatial size to compute its params, e.g.

  ```python
  height, width = query_spatial_size(sample)
  ```

  the extraction logic flattens again:

  vision/torchvision/prototype/transforms/_utils.py, lines 101 to 102 in 54a2d4e:

  ```python
  def query_spatial_size(sample: Any) -> Tuple[int, int]:
      flat_sample, _ = tree_flatten(sample)
  ```
- If a transform performs some checks on the sample before transforming, e.g.

  vision/torchvision/prototype/transforms/_geometry.py, lines 188 to 190 in 54a2d4e:

  ```python
  if has_any(inputs, features.BoundingBox, features.Mask):
      raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
  return super().forward(*inputs)
  ```

  the checking utility flattens again:

  vision/torchvision/prototype/transforms/_utils.py, lines 124 to 125 in 54a2d4e:

  ```python
  def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
      flat_sample, _ = tree_flatten(sample)
  ```
This repeated flattening of course has performance costs that can be avoided.
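To get a rough sense of the overhead, one can time `tree_flatten` in isolation. A minimal, hypothetical measurement sketch (the sample shape and iteration count are arbitrary):

```python
import timeit

from torch.utils._pytree import tree_flatten

# Hypothetical detection-style sample; object() stands in for tensors.
sample = {"image": object(), "target": {"boxes": object(), "labels": object()}}

# One flatten vs. three flattens per transform call, 100k calls each.
t1 = timeit.timeit(lambda: tree_flatten(sample), number=100_000)
t3 = timeit.timeit(lambda: [tree_flatten(sample) for _ in range(3)], number=100_000)
print(f"1x tree_flatten: {t1:.2f}s, 3x tree_flatten: {t3:.2f}s")
```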
## Proposal
All of the cases where we perform the extra flattening happen in internal and thus non-user-facing methods. Thus, instead of keeping the option to operate on arbitrary input structures in our utilities, we could have them work on already flattened inputs. This would avoid the repeated `tree_flatten` calls inside of them.
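A minimal sketch of what flattened-input variants of the utilities above could look like. The `flat_sample` parameter and the simplified lookup logic are assumptions of this proposal, not the current API:

```python
from typing import Any, Callable, List, Tuple, Type, Union


def query_spatial_size(flat_sample: List[Any]) -> Tuple[int, int]:
    # The caller flattens once; this utility only scans the leaves. The
    # actual spatial size lookup is heavily simplified here.
    for item in flat_sample:
        if hasattr(item, "spatial_size"):
            return item.spatial_size
    raise TypeError("No image or video was found in the sample")


def has_any(flat_sample: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    # Same idea: no tree_flatten call inside the utility anymore.
    for item in flat_sample:
        for type_or_check in types_or_checks:
            matched = (
                isinstance(item, type_or_check)
                if isinstance(type_or_check, type)
                else type_or_check(item)
            )
            if matched:
                return True
    return False
```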
For our transforms that means two changes:
- The extra calls in `_get_params` can be avoided by simply flattening before its call (see the first sketch after this list):

  vision/torchvision/prototype/transforms/_transform.py, lines 35 to 37 in 54a2d4e:

  ```python
  params = self._get_params(sample)

  flat_inputs, spec = tree_flatten(sample)
  ```
- The extra calls in overridden `forward`'s require more boilerplate code. Basically, each transform that overrides `forward` needs to perform the flattening / unflattening itself, since the check utilities are called before the `super().forward(...)` call. This also means that we are technically still flattening twice, although the second flatten inside `super().forward(...)` does no useful work.

  However, there is #6503 (introduce `_check` method for type checks on prototype transforms), which introduces a common interface for the checks. IIRC, we never followed up on it, since we eliminated some boilerplate in the overridden `forward`'s in #6504 ([proto] Simplified code in overridden transform forward methods). Since this proposal would re-add some boilerplate for performance gains, we could pick #6503 up again. If we do, this leaves very few, objectively "outlier" transforms that would need to have this boilerplate.

  If we go for the common check interface, it could receive the already flattened sample as well, similar to what was proposed in 1. (see the second sketch after this list).
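A minimal sketch of the reordering from the first change, assuming `_get_params` is changed to take the already flattened leaves (a hypothetical signature change; `_isinstance` is the existing helper from the snippet above):

```python
def forward(self, *inputs: Any) -> Any:
    sample = inputs if len(inputs) > 1 else inputs[0]

    # Flatten exactly once, up front, and reuse the leaves everywhere.
    flat_inputs, spec = tree_flatten(sample)
    params = self._get_params(flat_inputs)  # hypothetical: receives leaves, not the sample

    flat_outputs = [
        self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
        for inpt in flat_inputs
    ]
    return tree_unflatten(flat_outputs, spec)
```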
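And a sketch of how the common check interface from #6503 could receive the flattened sample as well. The `_check` hook name comes from that PR; everything else here is an assumption of this proposal:

```python
from typing import Any, List

import torch.nn as nn
from torch.utils._pytree import tree_flatten


class Transform(nn.Module):
    def _check(self, flat_inputs: List[Any]) -> None:
        # Default: no checks. Called by forward() after the single flatten.
        pass

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(sample)
        self._check(flat_inputs)
        ...  # transform and unflatten as in the sketch above


class MyGeometryTransform(Transform):  # hypothetical example transform
    def _check(self, flat_inputs: List[Any]) -> None:
        # Reuses the leaves from forward(); has_any no longer flattens.
        if has_any(flat_inputs, features.BoundingBox, features.Mask):
            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
```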
## Conclusion
Flattening an input sample multiple times inside a single transformation has no benefits while slowing down execution. This issue proposes a way to avoid this while keeping the user-facing API as convenient as it is.