
API change, expose model's state_dict to accelerator.training_type_plugin #7470

@shuyingsunshine21


🚀 Feature

Currently, in CheckpointConnector.dump_checkpoint, we have

model = self.trainer.lightning_module

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': model.state_dict(),
}

so the model's state dict is extracted here. However, letting accelerator.training_type_plugin control this logic might make more sense, especially for sharded plugins, where we might need to access the local (i.e. sharded) state instead of the full state.

Motivation

#6152 (comment)

We would like to produce a customized model state dict for a specific training type plugin. We could override the training_type_plugin.on_save method to modify the state dict, but that would duplicate the call that extracts the model state dict.
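For illustration, a minimal sketch of that workaround, assuming a custom plugin derived from DDPPlugin (the class name MyShardedPlugin and the helper _local_state_dict are hypothetical):

from pytorch_lightning.plugins import DDPPlugin


class MyShardedPlugin(DDPPlugin):
    def on_save(self, checkpoint: dict) -> dict:
        # dump_checkpoint has already called model.state_dict() to fill
        # checkpoint['state_dict']; replacing it here means the model state
        # is effectively extracted twice per checkpoint.
        checkpoint['state_dict'] = self._local_state_dict()
        return checkpoint

    def _local_state_dict(self) -> dict:
        # hypothetical helper that would return only this rank's shard
        return self.lightning_module.state_dict()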

Pitch

define a new method for TrainingTypePlugin

def state_dict(self) -> dict:
    model = self.lightning_module
    return model.state_dict()

and in CheckpointConnector.dump_checkpoint,

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': self.trainer.accelerator.training_type_plugin.state_dict(),
}
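With this hook in place, a sharded plugin could override state_dict to return only its local shard rather than the consolidated model state. A rough sketch, again using the hypothetical MyShardedPlugin; how the shard is obtained depends on the sharding backend:

from pytorch_lightning.plugins import DDPPlugin


class MyShardedPlugin(DDPPlugin):
    def state_dict(self) -> dict:
        # Return only this rank's shard instead of the full state dict.
        # The call below assumes the wrapped model exposes a
        # local_state_dict() method (as e.g. FairScale's FSDP does);
        # other backends would substitute their own shard-extraction API.
        return self.model.local_state_dict()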

Alternatives

Additional context

Labels

checkpointing, feature, help wanted
