🚀 Feature
Background:
We are auditing the Lightning components and APIs to assess opportunities for improvement:
- Review Lightning architecture & API #7740
- https://docs.google.com/document/d/1xHU7-iQSpp9KJTjI3As2EM0mfNHHr37WZYpDpwLkivA/edit#
Motivation
Lightning has the concept of a Training Type Plugin, which functions as the distribution strategy used during execution. For this, Lightning offers an API and out-of-the-box implementations of data-parallel approaches (DataParallel, DistributedDataParallel, ShardedDataParallel, DeepSpeed, etc.). The Trainer wraps the whole LightningModule in a module wrapper to facilitate gradient synchronization.
This allows users to easily specify distributed training options like:
```python
trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."])
module = MyLightningModule(...)
trainer.fit(module)
```
The situation we'd like to explicitly avoid is the user making successive calls to Trainer entry points like this:
```python
trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."])
module = MyLightningModule(...)
trainer.fit(module)
trainer.test(module)
```
potentially causing the wrapper to be applied to the LightningModule multiple times, e.g. `DistributedDataParallel(DistributedDataParallel(LightningModule))`.
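As a rough, self-contained illustration of that failure mode with plain PyTorch (a sketch only: it assumes a CPU-only, single-process gloo group, and the module here is a stand-in, not a Lightning internal):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Single-process process group so DistributedDataParallel can be constructed locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
wrapped = DistributedDataParallel(model)      # what the first Trainer call leaves behind
rewrapped = DistributedDataParallel(wrapped)  # what a second, guard-less call would produce
print(type(rewrapped.module))                 # DistributedDataParallel, not Linear

dist.destroy_process_group()
```

The original module is now reachable only as `rewrapped.module.module`, which breaks any code that expects a single level of wrapping.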
Pitch
Inside each of the following plugins, skip wrapping the plugin's `self._model` if the model is already an instance of the corresponding wrapper type.
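A minimal sketch of the guard for the DDP plugin linked below; the surrounding method body paraphrases the linked lines rather than copying them verbatim, so helper names such as `pre_configure_ddp`, `determine_ddp_device_ids`, and `_ddp_kwargs` should be checked against the actual plugin code:

```python
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.overrides import LightningDistributedModule


def configure_ddp(self) -> None:
    # If a previous Trainer entry point (e.g. fit) already wrapped the module,
    # wrapping again would produce DistributedDataParallel(DistributedDataParallel(...)).
    if isinstance(self._model, DistributedDataParallel):
        return
    self.pre_configure_ddp()
    self._model = DistributedDataParallel(
        LightningDistributedModule(self.model),
        device_ids=self.determine_ddp_device_ids(),
        **self._ddp_kwargs,
    )
```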
Plugins to update:
- DDP: https://github.com/PyTorchLightning/pytorch-lightning/blob/80c529351439a0f8d3d6e9449cd47d16ba3abbec/pytorch_lightning/plugins/training_type/ddp.py#L249-L256
- DDP Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/ddp_spawn.py#L247-L252
- Sharded Data Parallel: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded.py#L37-L45
- Sharded Data Parallel Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded_spawn.py#L36-L41
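For the FairScale-based plugins (sharded.py and sharded_spawn.py above), the same pattern applies, but the type to test for is FairScale's ShardedDataParallel rather than torch's DistributedDataParallel. Again a sketch only: the method body paraphrases the linked lines, and the import paths are assumptions based on Lightning 1.3-era code:

```python
from fairscale.nn.data_parallel import ShardedDataParallel

from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel


def configure_ddp(self) -> None:
    # Skip re-wrapping if an earlier Trainer call already produced a ShardedDataParallel.
    if isinstance(self._model, ShardedDataParallel):
        return
    self._wrap_optimizers()
    self._model = ShardedDataParallel(
        LightningShardedDataParallel(self.model),
        sharded_optimizer=self.lightning_module.trainer.optimizers,
    )
```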
Alternatives
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.