
Extended support for wrapping container modules (ModuleList, ModuleDict, Sequential) in Fabric #18427

@awaelchli

Description & Motivation

Fabric's setup methods wrap the user's nn.Module in a FabricModule. The wrapper forwards all calls to the original module, with a few exceptions where additional logic runs before and after the call.
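
A minimal pure-Python sketch of this forwarding pattern may help make it concrete. ForwardingWrapper and Doubler are hypothetical names, not Fabric's actual classes. The "before/after" logic is only a call log standing in for whatever extra work the real wrapper does. Ordinary attribute access is assumed to be forwarded through __getattr__.

```python
from typing import Any, List


class ForwardingWrapper:
    """Hypothetical FabricModule-like wrapper (illustration only, not Fabric's code)."""

    def __init__(self, module: Any) -> None:
        self.module = module  # the original, unwrapped object
        self.calls: List[str] = []  # records the extra before/after steps

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Extra logic runs around the call; a real wrapper would do useful
        # work here instead of just logging.
        self.calls.append("before")
        out = self.module(*args, **kwargs)
        self.calls.append("after")
        return out

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. Guard "module" to avoid
        # infinite recursion if it is not set yet (e.g. during unpickling).
        if name == "module":
            raise AttributeError(name)
        return getattr(self.module, name)


class Doubler:
    """Hypothetical module: multiplies its input by `scale`."""

    scale = 2

    def __call__(self, x: int) -> int:
        return x * self.scale


w = ForwardingWrapper(Doubler())
print(w(3), w.scale, w.calls)  # 6 2 ['before', 'after']
```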

PyTorch also provides three special container module types: Sequential, ModuleList, and ModuleDict. These define additional methods such as __getitem__ for indexing by position or key. The FabricModule does not forward these, so indexing the wrapped module fails:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

module = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU())
fabric_module = fabric.setup(module)
module[0]  # works
fabric_module[0]  # fails (TypeError: not subscriptable)

module = torch.nn.ModuleDict({"module_key": torch.nn.Linear(1, 1)})
fabric_module = fabric.setup(module)
module["module_key"]  # works
fabric_module["module_key"]  # fails (TypeError: not subscriptable)
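
Why indexing fails can be sketched in pure Python, assuming the wrapper forwards attributes through __getattr__ (Container and ForwardingWrapper are hypothetical stand-ins for a PyTorch container and the FabricModule). Python looks up special methods such as __getitem__ on the type of the object rather than through __getattr__, so the subscript syntax never reaches the forwarding logic. An explicit attribute lookup, or going through .module, still works.

```python
from typing import Any


class Container:
    """Hypothetical stand-in for nn.Sequential / nn.ModuleList: supports indexing."""

    def __init__(self, *items: Any) -> None:
        self._items = list(items)

    def __getitem__(self, idx: int) -> Any:
        return self._items[idx]


class ForwardingWrapper:
    """Hypothetical wrapper that forwards attribute access via __getattr__ only."""

    def __init__(self, module: Any) -> None:
        self.module = module

    def __getattr__(self, name: str) -> Any:
        if name == "module":
            raise AttributeError(name)
        return getattr(self.module, name)


wrapped = ForwardingWrapper(Container("linear", "relu"))

# Explicit attribute lookup falls through to __getattr__ and is forwarded.
print(wrapped.__getitem__(0))  # linear

# Accessing the original object directly also works.
print(wrapped.module[1])  # relu

# Subscript syntax looks up __getitem__ on type(wrapped), skipping
# __getattr__, so it fails just like fabric_module[0].
try:
    wrapped[0]
except TypeError as err:
    print(err)  # 'ForwardingWrapper' object is not subscriptable
```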

Pitch

Implement the container dunder methods (or a subset of them, e.g. __getitem__, __len__, and __iter__) on the FabricModule wrapper so that they forward to the wrapped module.
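
A minimal sketch of this pitch, assuming the wrapper keeps the original module as .module. ForwardingWrapper is a hypothetical name, and plain lists and dicts stand in for ModuleList/Sequential and ModuleDict so the example runs without PyTorch. Each container dunder is defined on the wrapper class itself and delegates to the wrapped object.

```python
from typing import Any, Iterator


class ForwardingWrapper:
    """Hypothetical wrapper; not Fabric's actual implementation."""

    def __init__(self, module: Any) -> None:
        self.module = module

    def __getattr__(self, name: str) -> Any:
        # Ordinary attributes are still forwarded lazily.
        if name == "module":
            raise AttributeError(name)
        return getattr(self.module, name)

    # Special methods must live on the class, so each one is spelled out
    # and simply delegates to the wrapped container.
    def __getitem__(self, key: Any) -> Any:
        return self.module[key]

    def __len__(self) -> int:
        return len(self.module)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.module)

    def __contains__(self, key: Any) -> bool:
        return key in self.module


seq = ForwardingWrapper(["linear", "relu"])  # stands in for Sequential
mdict = ForwardingWrapper({"module_key": "linear"})  # stands in for ModuleDict
print(seq[0], len(seq), list(seq))  # linear 2 ['linear', 'relu']
print(mdict["module_key"], "module_key" in mdict)  # linear True
```

One trade-off worth checking: bool() falls back to __len__ when __bool__ is not defined, so with a forwarded __len__, truth-testing a wrapper around a non-container module would raise TypeError. This may be a reason to implement only a subset, or to also define __bool__.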

Alternatives

Keep the current behavior, but document that the container functionality can be accessed through fabric_module.module.

Additional context

No response

cc @Borda

Metadata

Assignees: No one assigned
Labels: fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement)
Type: No type
Projects: No projects
Relationships: None yet
Development: No branches or pull requests
