Outline & Motivation
Currently, FSDP uses the default implementation from Strategy to get the model state dict: Strategy.lightning_module_state_dict. But, similar to the issue with optimizer creation, this returns an empty ordered dict. This affects, for example, the ModelCheckpoint callback: there are no model weights in the saved checkpoints.
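For reference, here is a minimal sketch of how the symptom can show up; the demo model, the two-GPU setup, and the use of the checkpoint callback's `best_model_path` are my own assumptions (and the exact import path of the strategy class may depend on the Lightning version), not part of the original report:

```python
# Hypothetical repro sketch: train a demo model with the FSDP strategy and
# inspect the checkpoint written by the default ModelCheckpoint callback.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import FSDPStrategy

trainer = Trainer(strategy=FSDPStrategy(), accelerator="gpu", devices=2, max_epochs=1)
trainer.fit(BoringModel())

if trainer.is_global_zero:
    ckpt = torch.load(trainer.checkpoint_callback.best_model_path)
    print(ckpt["state_dict"])  # per the report above: an empty OrderedDict, no model weights
```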
Pitch
We need a custom implementation to get the model state dict when the user uses the FSDP strategy (add a lightning_module_state_dict method to the FSDPStrategy class). I found two workarounds. The first one is pretty straightforward and replicates the optimizer initialization behavior:
```python
from typing import Any, Dict


def lightning_module_state_dict(self) -> Dict[str, Any]:
    """Returns model state."""
    assert self.lightning_module.trainer.model is not None
    wrapped_state_dict = self.lightning_module.trainer.model.state_dict()
    # Strip the `_forward_module.` prefix added by `_LightningModuleWrapperBase`
    return {k.replace("_forward_module.", ""): v for k, v in wrapped_state_dict.items()}
```

But I'm not sure this is the correct way to gather the sharded state dict. Following the PyTorch documentation, we need to use the state_dict_type context manager to unshard the model. That implementation would look like this:
```python
from typing import Any, Dict

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def lightning_module_state_dict(self) -> Dict[str, Any]:
    """Returns model state."""
    assert self.lightning_module.trainer.model is not None
    model = self.lightning_module.trainer.model
    # Gather the full (unsharded) state dict, materialized on rank 0 only
    full_state_dict_config = FullStateDictConfig(rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        wrapped_state_dict = model.state_dict()
    # Strip the `_forward_module.` prefix added by `_LightningModuleWrapperBase`
    return {k.replace("_forward_module.", ""): v for k, v in wrapped_state_dict.items()}
```

Additional context
- For both suggested implementations, I remove the prefix `_forward_module.` from the state dict keys. This is due to the user model being wrapped in `_LightningModuleWrapperBase`.
- For the second option, there is also the possibility to offload the gathered state dict to the CPU. This is useful when the model is too large to fit on a single GPU (see the sketch below).
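As an illustration of that second point, here is a minimal sketch of the CPU-offload variant using the documented `FullStateDictConfig` options; the helper name `full_state_dict_on_cpu` is mine and not part of Lightning or PyTorch:

```python
from typing import Any, Dict

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def full_state_dict_on_cpu(model: FSDP) -> Dict[str, Any]:
    """Gather a full state dict on rank 0 and move it to CPU memory."""
    # offload_to_cpu avoids materializing the unsharded weights on a single GPU
    config = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
        return model.state_dict()
```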