This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Utility to check if a module needs to be materialized #49

@carmocca

Description


Is your feature request related to a problem? Please describe:
Add a function that checks whether an nn.Module needs to be materialized, that is, whether it or any of its submodules still holds fake parameters or buffers.

Describe the solution you would like:

from torchdistx.deferred_init import is_deferred

implementation:

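from typing import Mapping, Optional, Union

from torch import Tensor
from torch.nn import Module, Parameter
from torchdistx.fake import is_fake

def is_deferred(module: Module) -> bool:
    """Return True if `module` or any of its submodules holds a fake tensor."""
    def any_fake(tensors: Mapping[str, Optional[Union[Tensor, Parameter]]]) -> bool:
        # Skip entries registered as None (e.g. the bias of Linear(..., bias=False))
        return any(is_fake(t) for t in tensors.values() if t is not None)

    # Recurse into child modules. The local name must differ from the function
    # name, otherwise the recursive call would hit an unbound local variable.
    children_deferred = any(is_deferred(m) for m in module.children())
    return children_deferred or any_fake(module._parameters) or any_fake(module._buffers)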
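To exercise the recursion without torch or torchdistx installed, here is a duck-typed sketch. StubTensor and StubModule are hypothetical stand-ins: StubTensor.fake plays the role of what is_fake() would report, and StubModule only mimics the _parameters, _buffers and children() parts of nn.Module that the check relies on.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Mapping, Optional


@dataclass
class StubTensor:
    # Hypothetical stand-in for a tensor; `fake` mimics is_fake()'s answer
    fake: bool = False


@dataclass
class StubModule:
    # Hypothetical stand-in for the parts of nn.Module used by the check
    _parameters: Dict[str, Optional[StubTensor]] = field(default_factory=dict)
    _buffers: Dict[str, Optional[StubTensor]] = field(default_factory=dict)
    _modules: List["StubModule"] = field(default_factory=list)

    def children(self) -> Iterator["StubModule"]:
        return iter(self._modules)


def is_fake(t: StubTensor) -> bool:
    return t.fake


def is_deferred(module: StubModule) -> bool:
    def any_fake(tensors: Mapping[str, Optional[StubTensor]]) -> bool:
        # None entries (e.g. a missing bias) are skipped, as in the proposal
        return any(is_fake(t) for t in tensors.values() if t is not None)

    # Distinct local name so the recursive call still refers to the function
    children_deferred = any(is_deferred(m) for m in module.children())
    return children_deferred or any_fake(module._parameters) or any_fake(module._buffers)


# A fully real leaf, and a tree whose grandchild has a fake buffer
real_leaf = StubModule(_parameters={"weight": StubTensor(), "bias": None})
grandchild = StubModule(_buffers={"running_mean": StubTensor(fake=True)})
deep = StubModule(_modules=[StubModule(_modules=[grandchild])])
print(is_deferred(real_leaf), is_deferred(deep))  # False True
```

The fake buffer two levels down is still detected, which is the property that has to stay in sync with the recursion used by materialize_module.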

Describe the alternatives you have considered:
I could copy the materialize_module implementation and loop over all parameters and buffers to check whether is_fake returns True for any of them. However, IMO this should be provided by this library, so that its recursive logic always matches materialize_module.

Additional context:
This would be used in Lightning-AI/pytorch-lightning#13868, where we want to validate the configuration, because deferred initialization is not supported with the spawn strategy.
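A minimal sketch of the kind of validation this would enable. validate_deferred_init is a hypothetical helper, not part of Lightning or torchdistx. It assumes the strategy is identified by a plain name string (for example "ddp_spawn") and treats any name containing "spawn" as a spawn-based strategy, which is a simplification. The deferred-init check is passed in as a callable so the sketch runs on its own.

```python
from typing import Callable, TypeVar

M = TypeVar("M")


def validate_deferred_init(module: M, strategy: str, is_deferred: Callable[[M], bool]) -> None:
    """Raise if a module with deferred (fake) tensors is used with a spawn strategy.

    Hypothetical helper: `is_deferred` would be the proposed
    torchdistx.deferred_init.is_deferred in practice.
    """
    # Only spawn-based strategies are rejected; the substring test is an assumption
    if "spawn" in strategy and is_deferred(module):
        raise ValueError(
            f"Deferred initialization is not supported with the {strategy!r} strategy."
        )
```

The spawn check runs first, so the (recursive) module scan is skipped entirely for other strategies.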

Metadata


Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
