Proposed refactoring or deprecation
Directly call TrainingTypePlugin APIs instead of going through the Accelerator wherever possible
@four4fish @awaelchli @justusschock @SeanNaren
Motivation
This carries forward the discussion from #9373 (comment)
Most of the Accelerator class today is a shell class that delegates calls to its attached TrainingTypePlugin. This creates an unnecessary level of indirection in many places. It also creates doubt as to whether custom accelerators should override these functions or not.
As most of the strategy around model distribution is embedded in the training type plugin, this is the hub where the following logic lives:
- Rank information
- Which ranks conduct IO for checkpoint saving/loading
- Control/Ownership of the LightningModule
- Collective communications
However, the accelerator is positioned as the gateway component the trainer interacts with for this functionality. As a result, much of the training type plugin's logic is currently replicated on the accelerator, which creates undesirable coupling (the exposed API surface is nearly doubled). We could cut out this level of indirection by having the trainer call the training type plugin directly wherever applicable, which would shrink the accelerator interface. Ultimately, this would allow the accelerator to live as a component inside the training type plugin, where it can manage device logic as part of the overall parallelization strategy.
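To make the indirection concrete, here is a minimal sketch of the current delegation pattern. It is illustrative only, not the actual Lightning source; the method bodies and the exact signatures are simplified assumptions.

```python
# Illustrative sketch of the current pass-through pattern (not the real
# Lightning implementation). Most Accelerator methods simply forward to the
# attached TrainingTypePlugin without adding accelerator-specific logic.
class TrainingTypePlugin:
    def barrier(self, name=None):
        ...  # collective communication owned by the strategy

    def save_checkpoint(self, checkpoint, filepath):
        ...  # rank-aware checkpoint IO owned by the strategy


class Accelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def barrier(self, name=None):
        # pure delegation, no accelerator-specific behavior
        return self.training_type_plugin.barrier(name=name)

    def save_checkpoint(self, checkpoint, filepath):
        # pure delegation, no accelerator-specific behavior
        return self.training_type_plugin.save_checkpoint(checkpoint, filepath)


# Today the Trainer takes the extra hop:
#   trainer.accelerator.barrier()  ->  training_type_plugin.barrier()
```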
Pitch
Have the Trainer call the training type plugin APIs directly for these methods, then deprecate and remove the corresponding APIs from the accelerator.
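A rough sketch of what the migration could look like, assuming a one-release deprecation shim on the accelerator. The names and the use of a plain `DeprecationWarning` are illustrative assumptions, not the mechanism the issue prescribes.

```python
import warnings

# Inside the Trainer, call the plugin directly instead of going through the
# accelerator (illustrative sketch):
#   before: self.accelerator.barrier()
#   after:  self.training_type_plugin.barrier()

class Accelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def barrier(self, name=None):
        # keep a deprecation shim for external callers before removal
        warnings.warn(
            "`Accelerator.barrier` is deprecated; call "
            "`TrainingTypePlugin.barrier` directly.",
            DeprecationWarning,
        )
        return self.training_type_plugin.barrier(name=name)
```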
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.