Avoid multiple pytree flattenings inside the prototype transforms #6760

@pmeier

By design, our transforms v2 can handle arbitrary input structures. Internally, we use torch.utils._pytree to do so.

Status quo

The simplest transforms flatten / unflatten only once:

    flat_inputs, spec = tree_flatten(sample)
    flat_outputs = [
        self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
        for inpt in flat_inputs
    ]
    return tree_unflatten(flat_outputs, spec)

However, although it is hidden, most transforms flatten at least twice, because internal utilities called from _get_params or from an overridden forward flatten the sample again on their own.

This has a performance cost that can be avoided.

Proposal

In all of these cases, the extra flattening happens in internal and thus not user-facing methods. Thus, instead of keeping the option to operate on arbitrary input structures in our utilities, we could have them work only on already flattened inputs. This would avoid the repeated tree_flatten calls inside of them.

For our transforms that means two changes:

  1. The extra calls in _get_params can be avoided by simply flattening before its call and passing the flattened inputs:

    flat_inputs, spec = tree_flatten(sample)
    params = self._get_params(flat_inputs)

  2. The extra calls in overridden forward methods require more boilerplate code. Each transform that overrides forward needs to perform the flattening / unflattening itself, since the check utilities are called before the super().forward(...) call. This also means that we are technically still flattening twice, although the second pass inside super().forward(...) receives an already flat input and does essentially nothing.

    However, #6503 ("introduce _check method for type checks on prototype transforms") introduces a common interface for the checks. IIRC, we never followed up on it, since we eliminated some boilerplate in the overridden forward in #6504 ("[proto] Simplified code in overridden transform forward methods"). Since this proposal would re-add some boilerplate for performance gains, we could pick #6503 up again. If we do, this leaves very few, objectively "outlier" transforms that would need this boilerplate.

    If we go for the common check interface, it could also receive the already flattened sample, similar to what is proposed in 1.
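The two changes can be sketched together as follows. This is only an illustration under stated assumptions, not the torchvision implementation: `tree_flatten` and `tree_unflatten` are simplified pure-Python stand-ins for the torch.utils._pytree functions, `contains_type` and `ToyTransform` are hypothetical, and the `_check` hook only imitates the idea of a common check interface without reproducing what #6503 actually contains.

```python
CALLS = {"flatten": 0}  # counts how often the stand-in tree_flatten runs


def _flatten(tree):
    if isinstance(tree, dict):
        children, kind, keys = list(tree.values()), dict, list(tree.keys())
    elif isinstance(tree, (list, tuple)):
        children, kind, keys = list(tree), type(tree), None
    else:
        return [tree], None  # a leaf has no spec
    leaves, specs = [], []
    for child in children:
        child_leaves, child_spec = _flatten(child)
        leaves.extend(child_leaves)
        specs.append((len(child_leaves), child_spec))
    return leaves, (kind, keys, specs)


def tree_flatten(tree):
    """Simplified stand-in for a pytree flatten; NOT the torch implementation."""
    CALLS["flatten"] += 1
    return _flatten(tree)


def tree_unflatten(leaves, spec):
    if spec is None:
        return leaves[0]
    kind, keys, specs = spec
    children, pos = [], 0
    for n, child_spec in specs:
        children.append(tree_unflatten(leaves[pos:pos + n], child_spec))
        pos += n
    return dict(zip(keys, children)) if kind is dict else kind(children)


def contains_type(flat_inputs, *types):
    # Hypothetical utility under the proposal: expects an already flat list.
    return any(isinstance(x, types) for x in flat_inputs)


class ToyTransform:
    _transformed_types = (int,)

    def _check(self, flat_inputs):
        # Hypothetical common check hook; it also receives the flat list.
        if contains_type(flat_inputs, complex):
            raise TypeError("complex inputs are not supported by this toy transform")

    def _get_params(self, flat_inputs):
        return {"offset": 10 if contains_type(flat_inputs, float) else 1}

    def _transform(self, inpt, params):
        return inpt + params["offset"]

    def forward(self, sample):
        flat_inputs, spec = tree_flatten(sample)  # the only flattening
        self._check(flat_inputs)
        params = self._get_params(flat_inputs)
        flat_outputs = [
            self._transform(inpt, params) if isinstance(inpt, self._transformed_types) else inpt
            for inpt in flat_inputs
        ]
        return tree_unflatten(flat_outputs, spec)


sample = {"image": 1, "meta": ["label", (2, 0.5)]}
output = ToyTransform().forward(sample)
print(output, CALLS["flatten"])
```

Because the checks and the parameter sampling both consume `flat_inputs`, the counter stays at 1 per sample while the caller still passes in an arbitrarily nested structure.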

Conclusion

Flattening an input sample multiple times inside a single transformation has no benefits and slows down execution. This issue proposes a way to avoid this while keeping the user-facing interface as convenient as it is.

cc @vfdev-5 @datumbox @bjuncek
