🚀 Feature
Currently, in `CheckpointConnector.dump_checkpoint`, we have

```python
model = self.trainer.lightning_module
checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': model.state_dict(),
}
```
so the model's state dict is extracted here. However, letting `accelerator.training_type_plugin` control this logic might make more sense, especially for sharded plugins, where we may need to access the local (i.e. sharded) state instead of the full model state.
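For a concrete illustration of the sharded case, here is a hedged sketch assuming FairScale's `FullyShardedDataParallel`, which exposes both a gathered full state dict and a per-rank local one (exact API may vary by version):

```python
# Hedged illustration: a sharded wrapper can expose both a gathered full
# state dict and a per-rank local (sharded) one.
# Assumes a torch.distributed process group is already initialized.
import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(4, 4))

full_state = model.state_dict()         # gathers the full parameters
local_state = model.local_state_dict()  # only the shard held by this rank
```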
Motivation
We would like to build a customized model state dict for a specific training type plugin. We could override the `training_type_plugin.on_save` method to modify the state dict, but this would cause a duplicate call to extract the model state dict.
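A hedged sketch of that workaround, assuming a custom plugin subclassing `DDPPlugin` (the name `MyShardedPlugin` is hypothetical): by the time `on_save` runs, `dump_checkpoint` has already called `model.state_dict()`, so the full state is extracted once and then thrown away.

```python
from pytorch_lightning.plugins import DDPPlugin


class MyShardedPlugin(DDPPlugin):  # hypothetical example plugin
    def on_save(self, checkpoint: dict) -> dict:
        # checkpoint['state_dict'] already holds the full model.state_dict()
        # built in dump_checkpoint; replacing it here means that first
        # (potentially expensive) extraction was wasted work.
        checkpoint['state_dict'] = self.lightning_module.state_dict()  # e.g. a sharded/local variant
        return checkpoint
```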
Pitch
Define a new method for `TrainingTypePlugin`:

```python
def state_dict(self) -> dict:
    model = self.lightning_module
    return model.state_dict()
```
and in `CheckpointConnector.dump_checkpoint`:

```python
checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': self.trainer.accelerator.training_type_plugin.state_dict(),
}
```
Alternatives
Additional context