
Support serialized checkpoint loading #9406


Description

@ananthsub

🚀 Feature

Motivation

Currently, all processes load the checkpoint at the same time. For large models, this concurrent loading can cause CPU OOMs. Such use cases, especially with mixture-of-experts models, may require serialized loading of checkpoint dicts across ranks (i.e., loading the checkpoint one rank at a time per node). Could we enable this for DDP?
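
For illustration, here is a minimal sketch of the idea using plain torch.distributed primitives rather than any Lightning API. It assumes a torchrun-style launch where LOCAL_RANK and LOCAL_WORLD_SIZE are set in the environment and that every node runs the same number of processes; each local rank takes its turn to materialize the checkpoint on CPU while the others wait at a barrier.

```python
import os

import torch
import torch.distributed as dist


def load_checkpoint_serialized(checkpoint_path: str) -> dict:
    """Load a checkpoint one local rank at a time on each node (hypothetical helper)."""
    # LOCAL_RANK / LOCAL_WORLD_SIZE are assumed to be set by the launcher (e.g. torchrun).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))

    checkpoint = None
    for turn in range(local_world_size):
        if turn == local_rank:
            # Only one process per node holds a freshly loaded copy at a time,
            # so peak CPU memory is roughly one checkpoint per node, not per process.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Everyone waits until the current rank has finished loading.
        dist.barrier()
    return checkpoint
```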

Prior work: #8515

Pitch

This would be controlled per training type plugin. Example pseudocode: https://gist.github.com/ananthsub/4ceedff56b2049a63bbb05ccd283b919

To work through:

  1. Should the TrainingTypePlugin be responsible for calling LightningModule.on_load_checkpoint, instead of the Trainer/checkpoint connector? This would make sense, as the TTP "owns" the LightningModule inside the Trainer and already offers load_model_state_dict: https://github.com/PyTorchLightning/pytorch-lightning/blob/41ba639859cf6c6bf319eb33e5b3394504315962/pytorch_lightning/plugins/training_type/training_type_plugin.py#L159-L160

DeepSpeed already eschews most of the checkpoint connector logic when loading the LightningModule state. This could leave a gap for metrics, and it means on_load_checkpoint could be called multiple times with certain plugins. In my opinion, this points to having all LightningModule state loading, saving, and alteration sit inside the training type plugin.
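
As a concrete illustration of that direction, here is a hypothetical sketch of a DDP training type plugin that both serializes the restore across local ranks and takes over calling the LightningModule hook. Only DDPPlugin, load_model_state_dict, and on_load_checkpoint come from existing Lightning code; the class name, the environment-variable handling, and the barrier-based turn taking are assumptions, not the proposed API.

```python
import os

import torch.distributed as dist
from pytorch_lightning.plugins import DDPPlugin


class SerializedLoadDDPPlugin(DDPPlugin):
    """Sketch: the plugin owns restoring LightningModule state, one local rank at a time."""

    def load_model_state_dict(self, checkpoint: dict) -> None:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        for turn in range(local_world_size):
            if turn == local_rank:
                # The plugin, not the checkpoint connector, invokes the hook
                # and restores the weights for its rank.
                self.lightning_module.on_load_checkpoint(checkpoint)
                self.lightning_module.load_state_dict(checkpoint["state_dict"])
            dist.barrier()
```

A user could then opt a run into serialized loading by passing the plugin to the Trainer (e.g. Trainer(plugins=SerializedLoadDDPPlugin(...))) without any changes to the checkpoint connector.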

Alternatives

Additional context




Labels

checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), help wanted (Open to be worked on), let's do it! (approved to implement)
