This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Description
Is your feature request related to a problem? Please describe:
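Add a function to check whether an `nn.Module` was created with deferred initialization, i.e. whether it (or any of its submodules) still holds fake parameters or buffers that `materialize_module` would materialize.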
Describe the solution you would like:
```python
from torchdistx.deferred_init import is_deferred
```
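Proposed implementation:
```python
from typing import Mapping, Optional, Union
from torch import Tensor
from torch.nn import Module, Parameter
from torchdistx.fake import is_fake

def is_deferred(module: Module) -> bool:
    """Return True if `module` or any of its submodules holds a fake tensor."""
    def any_fake(tensors: Mapping[str, Optional[Union[Tensor, Parameter]]]) -> bool:
        # Skip None entries (e.g. the bias of a Linear layer created with bias=False).
        return any(is_fake(t) for t in tensors.values() if t is not None)

    # Recurse into child modules, then check this module's own tensors.
    children_deferred = any(is_deferred(m) for m in module.children())
    return children_deferred or any_fake(module._parameters) or any_fake(module._buffers)
```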
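To try the same recursion without installing torchdistx, the sketch below uses plain PyTorch and treats tensors on the `meta` device as a stand-in for fake tensors. That substitution is an assumption for illustration only (`Tensor.is_meta` is not `is_fake`), and the helper name `has_matching_tensor` is hypothetical. The predicate is passed in so the traversal itself stays identical to the proposal.

```python
from typing import Callable, Mapping, Optional

import torch
from torch import nn


def has_matching_tensor(
    module: nn.Module, predicate: Callable[[torch.Tensor], bool]
) -> bool:
    """Hypothetical helper: True if `module` or any submodule owns a parameter
    or buffer for which `predicate` holds (same recursion as the proposal)."""

    def any_match(tensors: Mapping[str, Optional[torch.Tensor]]) -> bool:
        # Entries can be None, e.g. `bias` on nn.Linear(..., bias=False).
        return any(predicate(t) for t in tensors.values() if t is not None)

    # Recurse into child modules, then check this module's own tensors.
    children_match = any(has_matching_tensor(m, predicate) for m in module.children())
    return children_match or any_match(module._parameters) or any_match(module._buffers)


def is_meta(t: torch.Tensor) -> bool:
    # Assumed stand-in for torchdistx.fake.is_fake in this sketch.
    return t.is_meta


eager = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
nested_meta = nn.Sequential(nn.ReLU(), nn.Sequential(nn.Linear(4, 4, device="meta")))

print(has_matching_tensor(eager, is_meta))        # False
print(has_matching_tensor(nested_meta, is_meta))  # True
```

Passing `torchdistx.fake.is_fake` as the predicate instead of `is_meta` should reproduce the check proposed in this issue.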
Describe the alternatives you have considered:
I could copy the materialize_module implementation and loop over all parameters and buffers to check whether is_fake returns True for any of them. But IMO this should be provided by the library itself, so that the recursive logic always stays in sync with materialize_module.
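The alternative mentioned above can be written as a flat loop, because `Module.parameters()` and `Module.buffers()` already recurse into submodules by default. The sketch below is a minimal illustration under the assumption that `meta`-device tensors stand in for fake tensors; `has_meta_tensor` is a hypothetical name, not a torchdistx function.

```python
from itertools import chain

import torch
from torch import nn


def has_meta_tensor(module: nn.Module) -> bool:
    """Flat variant: parameters() and buffers() walk all submodules
    (recurse=True is their default), so no explicit recursion is written here."""
    # Tensor.is_meta stands in for torchdistx's is_fake in this sketch.
    return any(t.is_meta for t in chain(module.parameters(), module.buffers()))


# A model whose only meta tensor is a buffer on a nested child module.
child = nn.Module()
child.register_buffer("running_stat", torch.zeros(3, device="meta"))
model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(child))

print(has_meta_tensor(model))            # True
print(has_meta_tensor(nn.Linear(4, 4)))  # False
```

The trade-off is the one the issue points out: this version relies on PyTorch's traversal rather than on the one inside `materialize_module`, so the two could drift apart if the library's logic changes.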
Additional context:
This would be used in Lightning-AI/pytorch-lightning#13868, where we want to validate the configuration up front, because deferred initialization is not supported with the spawn strategy.