Description
🚀 Feature
Motivation
Currently, all processes load the checkpoint at the same time. For large models, this can cause CPU OOMs when processes load the checkpoint concurrently. These use cases, especially ones like mixture of experts, might require loading checkpoint dicts serially across ranks (i.e., loading the checkpoint one rank at a time per node). Could we enable this for DDP?
Prior work: #8515
Pitch
This would be controlled per training type plugin. Example pseudocode: https://gist.github.com/ananthsub/4ceedff56b2049a63bbb05ccd283b919
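A minimal sketch of serialized loading, independent of the linked gist. It assumes every node has the same number of local ranks and that a global barrier (for example `torch.distributed.barrier`) is available. The helper name `run_serially_per_node` is hypothetical, not a Lightning API. Each rank runs its load step only on its own turn and calls the barrier once per turn, so the collective calls stay matched and only one rank per node holds a CPU copy of the checkpoint at a time. The `simulate` function uses threads as stand-in ranks so the ordering can be checked without a GPU cluster.

```python
import threading
import time
from typing import Callable, List, Tuple


def run_serially_per_node(
    fn: Callable[[], None],
    local_rank: int,
    local_world_size: int,
    barrier: Callable[[], object],
) -> None:
    """Hypothetical helper: run `fn` on one local rank at a time.

    Every rank executes exactly `local_world_size` barriers. With a global
    barrier, local rank k on every node runs `fn` during turn k.
    """
    for turn in range(local_world_size):
        if turn == local_rank:
            # e.g. load the checkpoint, apply it to the model, drop the CPU copy
            fn()
        barrier()


def simulate(local_world_size: int) -> Tuple[List[int], int]:
    """Threads stand in for ranks on one node; return load order and peak concurrency."""
    order: List[int] = []
    active = 0
    peak = 0
    lock = threading.Lock()
    barrier = threading.Barrier(local_world_size)

    def worker(rank: int) -> None:
        def fake_load() -> None:
            nonlocal active, peak
            with lock:
                active += 1
                peak = max(peak, active)
                order.append(rank)
            time.sleep(0.01)  # a real rank would call torch.load(...) here
            with lock:
                active -= 1

        run_serially_per_node(fake_load, rank, local_world_size, barrier.wait)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(local_world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order, peak


if __name__ == "__main__":
    print(simulate(4))
```

In a real DDP job, one could pass `torch.distributed.barrier` as `barrier` and read the local rank and local world size from the `LOCAL_RANK` and `LOCAL_WORLD_SIZE` environment variables set by `torchrun`; whether a training type plugin would expose this as a flag is still an open design question.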
To work through:
- Should the TrainingTypePlugin be responsible for calling `LightningModule.on_load_checkpoint` instead of the Trainer/connector? This would make sense because the TTP "owns" the LightningModule inside the trainer and already offers `load_model_state_dict`: https://github.com/PyTorchLightning/pytorch-lightning/blob/41ba639859cf6c6bf319eb33e5b3394504315962/pytorch_lightning/plugins/training_type/training_type_plugin.py#L159-L160
DeepSpeed already bypasses most of the checkpoint connector logic when loading the LightningModule state. This could leave a gap for metrics, and it means `on_load_checkpoint` could be called multiple times with certain plugins. In my opinion, this points to all LightningModule state loading, saving, and alteration needing to sit inside the training type plugin.
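A minimal sketch of the proposed ownership, using plain-Python stand-ins rather than the real Lightning classes. `SketchModule`, `SketchTrainingTypePlugin`, and `restore_lightning_module` are hypothetical names; only the hook name `on_load_checkpoint` and the method name `load_model_state_dict` come from the discussion above. The idea is a single plugin entry point that calls the hook exactly once and then applies the weights, so a subclass such as a DeepSpeed plugin would override only the weight-loading step and could not skip or duplicate the hook.

```python
from typing import Any, Dict


class SketchModule:
    """Stand-in for a LightningModule (hypothetical, not the real class)."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {}
        self.hook_calls = 0

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # User hook: may read or alter custom state stored in the checkpoint.
        self.hook_calls += 1

    def load_state_dict(self, state_dict: Dict[str, float]) -> None:
        self.weights = dict(state_dict)


class SketchTrainingTypePlugin:
    """Hypothetical plugin that owns every LightningModule restore step."""

    def __init__(self, module: SketchModule) -> None:
        self.lightning_module = module

    def load_model_state_dict(self, checkpoint: Dict[str, Any]) -> None:
        # Subclasses (e.g. a DeepSpeed-style plugin) would override only this.
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

    def restore_lightning_module(self, checkpoint: Dict[str, Any]) -> None:
        # Single entry point: the hook runs once, right before weights are applied.
        self.lightning_module.on_load_checkpoint(checkpoint)
        self.load_model_state_dict(checkpoint)


if __name__ == "__main__":
    module = SketchModule()
    SketchTrainingTypePlugin(module).restore_lightning_module({"state_dict": {"w": 1.0}})
    print(module.hook_calls, module.weights)
```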
Alternatives
Additional context