Proper extraction of state_dict for fsdp strategy #16526

@SpirinEgor

Description

Outline & Motivation

For now, FSDP uses the default implementation from Strategy to get the model state dict: Strategy.lightning_module_state_dict. But, just as with optimizer creation, this returns an empty ordered dict. This affects, for example, the ModelCheckpoint callback: there are no model weights in the saved checkpoints.
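For illustration, a minimal sketch of how the problem shows up (the model, the checkpoint path, and the dataloader are hypothetical, and the strategy alias may differ between Lightning versions):

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="fsdp",  # selects the FSDP strategy; the alias may vary by version
    max_epochs=1,
    callbacks=[ModelCheckpoint(dirpath="checkpoints/")],
)
# trainer.fit(ToyModel(), train_dataloaders=...)

# After fit, the checkpoint written by ModelCheckpoint has an empty "state_dict":
# ckpt = torch.load("checkpoints/<name>.ckpt")
# ckpt["state_dict"]  # -> OrderedDict() instead of the model weights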

Pitch

We need a custom implementation for getting the model state dict when the user uses the FSDP strategy (i.e., add a lightning_module_state_dict method to the FSDPStrategy class). I found two workarounds. The first one is pretty straightforward and replicates the optimizer initialization behavior:

def lightning_module_state_dict(self) -> Dict[str, Any]:
    """Returns model state."""
    assert self.lightning_module.trainer.model is not None
    wrapped_state_dict = self.lightning_module.trainer.model.state_dict()
    # Strip the wrapper prefix so the keys match the user's LightningModule.
    return {k.replace("_forward_module.", ""): v for k, v in wrapped_state_dict.items()}

But I'm not sure this is the correct way to gather the sharded state dict. Following the PyTorch documentation, we need to use the state_dict_type context manager to unshard the model. That implementation would look like this:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def lightning_module_state_dict(self) -> Dict[str, Any]:
    """Returns model state."""
    assert self.lightning_module.trainer.model is not None
    model = self.lightning_module.trainer.model
    full_state_dict_config = FullStateDictConfig(rank0_only=True)
    # Gather the full (unsharded) parameters before reading the state dict.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        wrapped_state_dict = model.state_dict()
    return {k.replace("_forward_module.", ""): v for k, v in wrapped_state_dict.items()}

Additional context

  1. For both suggested implementations, I remove the _forward_module. prefix from the state dict keys. This is because the user model is wrapped into _LightningModuleWrapperBase.
  2. For the second option, there is also a way to offload the gathered state dict to the CPU (see the sketch after this list). This is useful when a model is too large to fit on a single GPU.
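As an illustration of point 2, here is a sketch of the offloading variant (assuming an already FSDP-wrapped module and an initialized process group; the helper name is mine, not part of the proposal):

from typing import Any, Dict

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def full_state_dict_on_rank0(model: FSDP) -> Dict[str, Any]:
    """Gather the full (unsharded) state dict, offloaded to CPU and materialized only on rank 0."""
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
        state_dict = model.state_dict()
    # Lightning's wrapper produces keys like "_forward_module.layer.weight",
    # so strip the prefix to match the user's LightningModule.
    return {k.replace("_forward_module.", ""): v for k, v in state_dict.items()}

With rank0_only=True, non-zero ranks receive an empty dict, so only rank 0 should write the checkpoint.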

cc @awaelchli @carmocca

Metadata

Labels

bug (Something isn't working), checkpointing (Related to checkpointing), strategy: fsdp (Fully Sharded Data Parallel)
