Proposed refactor
Discuss and refactor the `get_mp_spawn_kwargs` hook of `TPUSpawnPlugin`/`DDPSpawnPlugin`.
Motivation
As part of our effort to converge to a stable and well-organized Strategy + Accelerator API (#10416), we are looking for opportunities to simplify a couple of hooks. One of them is `get_mp_spawn_kwargs` in the spawn plugins:
```python
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
    return {
        "nprocs": len(self.parallel_devices),
        "start_method": self.start_method,
    }
```
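For context, the returned dict is unpacked into the plugin's process-launch call. Below is a minimal, self-contained sketch of that pattern using `torch.multiprocessing.start_processes`; the `_worker` function and the literal kwargs are hypothetical stand-ins, and the real call sites also forward the worker function and trainer state:

```python
import torch.multiprocessing as mp


def _worker(process_idx: int) -> None:
    # Hypothetical per-process entry point, standing in for the plugin's
    # real worker function.
    print(f"hello from process {process_idx}")


if __name__ == "__main__":
    # Stand-in for the dict returned by `get_mp_spawn_kwargs`.
    spawn_kwargs = {"nprocs": 2, "start_method": "spawn"}
    mp.start_processes(_worker, **spawn_kwargs)
```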
Pitch

Proposal 1:
- Remove/deprecate the method.
- Use `start_method` directly as an attribute and/or constructor argument.
- `nprocs` can already be set in the constructor.
- A custom implementation that needs to pass additional arguments would instead override the new `spawn()` method (see the sketch after this list).
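A minimal sketch of what Proposal 1 could look like. The class name, constructor arguments, and the `spawn()` signature here are assumptions for illustration, not the final API:

```python
from typing import Any, Callable, List, Optional

import torch.multiprocessing as mp


class DDPSpawnPluginSketch:
    """Hypothetical shape of Proposal 1; not the actual Lightning class."""

    def __init__(
        self,
        parallel_devices: Optional[List[Any]] = None,
        start_method: str = "spawn",
    ) -> None:
        # `start_method` and the device list are plain constructor arguments,
        # so no extra hook is needed to assemble the spawn kwargs.
        self.parallel_devices = parallel_devices or []
        self.start_method = start_method

    def spawn(self, function: Callable, *args: Any) -> None:
        # Plugins that need extra spawn arguments override this method
        # instead of `get_mp_spawn_kwargs`.
        mp.start_processes(
            function,
            args=args,
            nprocs=len(self.parallel_devices),
            start_method=self.start_method,
        )
```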
Proposal 2:
- Keep the method.
- Remove the `trainer` argument.
- Mark it as protected, since the method is only used internally by the plugin (see the sketch after this list).
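A minimal sketch of what Proposal 2 could look like; the class name and constructor are hypothetical scaffolding for illustration:

```python
from typing import Any, Dict, List, Optional


class ProtectedHookSketch:
    """Hypothetical shape of Proposal 2; not the actual Lightning class."""

    def __init__(
        self,
        parallel_devices: Optional[List[Any]] = None,
        start_method: str = "spawn",
    ) -> None:
        self.parallel_devices = parallel_devices or []
        self.start_method = start_method

    def _get_mp_spawn_kwargs(self) -> Dict[str, Any]:
        # Same body as today, but protected (leading underscore) and without
        # the unused `trainer` parameter.
        return {
            "nprocs": len(self.parallel_devices),
            "start_method": self.start_method,
        }
```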
Proposal 3:
- Keep as is
Additional context
Original discussion started in #10034 by @four4fish
If you enjoy Lightning, check out our other projects! ⚡

- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.