Open
Labels: fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement)
Description & Motivation
Fabric's setup methods wrap the user's nn.Module in a FabricModule. The wrapper forwards all calls to the original module, except for a few where it adds extra logic before and after the call.
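A minimal sketch of this forwarding pattern, assuming plain-Python stand-ins: LinearStub and ForwardingWrapper are hypothetical names, not Lightning's implementation, and the "before"/"after" log entries are placeholders for whatever extra logic the real wrapper runs around forward.

```python
class LinearStub:
    """Hypothetical stand-in for a user's nn.Module."""

    def __init__(self):
        self.weight = 2.0

    def forward(self, x):
        return self.weight * x

    def __call__(self, x):
        return self.forward(x)


class ForwardingWrapper:
    """Hypothetical wrapper: adds logic around calls, forwards everything else."""

    def __init__(self, module):
        self.module = module
        self.log = []

    def __call__(self, *args, **kwargs):
        self.log.append("before forward")  # placeholder for pre-forward logic
        out = self.module(*args, **kwargs)
        self.log.append("after forward")  # placeholder for post-forward logic
        return out

    def __getattr__(self, name):
        # Only invoked when normal lookup fails: delegate to the wrapped module.
        return getattr(self.module, name)


wrapped = ForwardingWrapper(LinearStub())
print(wrapped(3.0))    # 6.0
print(wrapped.weight)  # 2.0, resolved through __getattr__
print(wrapped.log)     # ['before forward', 'after forward']
```

Attribute access such as wrapped.weight reaches the inner object only because __getattr__ is consulted after normal lookup fails.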
PyTorch has three special container module types: Sequential, ModuleList, and ModuleDict. They implement additional methods such as __getitem__ for indexing. FabricModule does not support them, so indexing the wrapper fails:
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

module = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU())
fabric_module = fabric.setup(module)
module[0]  # works
fabric_module[0]  # fails: FabricModule does not define __getitem__

module = torch.nn.ModuleDict({"module_key": torch.nn.Linear(1, 1)})
fabric_module = fabric.setup(module)
module["module_key"]  # works
fabric_module["module_key"]  # fails for the same reason
Pitch
Implement these dunder methods (or a subset of them) on the FabricModule wrapper so that they are forwarded to the wrapped module.
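Python looks up special methods such as __getitem__ and __len__ on the class, bypassing __getattr__, so attribute forwarding alone cannot make indexing work; each dunder method has to be defined on the wrapper class itself. A minimal sketch of that idea, using hypothetical plain-Python stand-ins (SequentialStub, GetattrOnlyWrapper, ContainerAwareWrapper) rather than torch or Lightning's actual FabricModule:

```python
class SequentialStub:
    """Hypothetical stand-in for torch.nn.Sequential: an indexable container."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def __getitem__(self, idx):
        return self.layers[idx]

    def __len__(self):
        return len(self.layers)

    def __iter__(self):
        return iter(self.layers)


class GetattrOnlyWrapper:
    """Forwards attribute access only, like a wrapper without dunder support."""

    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        return getattr(self.module, name)


class ContainerAwareWrapper(GetattrOnlyWrapper):
    """Additionally forwards the container dunder methods to the wrapped module."""

    def __getitem__(self, idx):
        return self.module[idx]

    def __len__(self):
        return len(self.module)

    def __iter__(self):
        return iter(self.module)

    def __contains__(self, item):
        return item in self.module


inner = SequentialStub("linear", "relu")
try:
    GetattrOnlyWrapper(inner)[0]  # implicit [] lookup skips __getattr__
except TypeError as err:
    print(err)  # 'GetattrOnlyWrapper' object is not subscriptable

wrapped = ContainerAwareWrapper(inner)
print(wrapped[0], len(wrapped), list(wrapped))  # linear 2 ['linear', 'relu']
```

This sketch forwards unconditionally, so calling len() on a wrapper around a non-container raises a TypeError from the inner object; a real implementation might choose to forward only the methods the wrapped module defines.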
Alternatives
Keep the current behavior, but document that indexing and the other container methods remain available through fabric_module.module.
Additional context
No response
cc @Borda