🚀 Feature
Extract an interface that allows us to abstract and disentangle process creation / spawning from the plugins.
Motivation
The simplifications that #10059 introduced brought the DDPSpawnPlugin and the DDPPlugin closer together in their function, execution order, and API. The fundamental difference between the two, however, remains in how the processes are created.
DDPSpawnPlugin
The spawning logic in DDPSpawnPlugin comprises mainly these three methods (a simplified sketch follows the links below):
https://github.com/PyTorchLightning/pytorch-lightning/blob/aeb0b5595fd73d086f4ae0f99d3f1f112f6a4c29/pytorch_lightning/plugins/training_type/ddp_spawn.py#L152
https://github.com/PyTorchLightning/pytorch-lightning/blob/aeb0b5595fd73d086f4ae0f99d3f1f112f6a4c29/pytorch_lightning/plugins/training_type/ddp_spawn.py#L245
https://github.com/PyTorchLightning/pytorch-lightning/blob/aeb0b5595fd73d086f4ae0f99d3f1f112f6a4c29/pytorch_lightning/plugins/training_type/ddp_spawn.py#L271
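For orientation, here is a minimal, hypothetical sketch of what this spawn flow boils down to. The names and signatures are simplified and do not match the linked methods one-to-one; in particular, the real methods carry more state (trainer reference, extra result queues, etc.):

```python
import torch.multiprocessing as mp


class DDPSpawnPlugin:
    # Hypothetical, simplified sketch; not the real plugin's full signature.
    def __init__(self, num_processes: int):
        self.num_processes = num_processes

    def spawn(self, function, *args, **kwargs):
        context = mp.get_context("spawn")
        return_queue = context.SimpleQueue()
        # Start `num_processes` fresh processes, each running the wrapped function.
        mp.spawn(
            self._wrapped_function,
            args=(function, args, kwargs, return_queue),
            nprocs=self.num_processes,
        )
        return return_queue.get()

    def _wrapped_function(self, process_idx, function, args, kwargs, return_queue):
        # Every spawned process runs the function; only rank 0 reports back.
        results = function(*args, **kwargs)
        if process_idx == 0:
            return_queue.put(results)
```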
DDPPlugin
As with the spawn plugin, the creation of the subprocesses is well isolated in a single method in the DDPPlugin.
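For comparison, a rough, hypothetical sketch of that script-based creation. The real method (_call_children_scripts) also manages environment variables, the working directory, and cleanup beyond what is shown here:

```python
import os
import subprocess
import sys


class DDPPlugin:
    # Illustrative sketch only; not the real plugin's full implementation.
    def __init__(self, num_processes: int):
        self.num_processes = num_processes
        self.processes = []

    def _call_children_scripts(self):
        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = str(local_rank)
            # Re-launch the user's training script once per additional rank.
            proc = subprocess.Popen([sys.executable] + sys.argv, env=env_copy)
            self.processes.append(proc)
```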
The Trainer today (after #10896) has to differentiate between the two and call them differently:

```python
if isinstance(self.training_type_plugin, DDPSpawnPlugin):
    spawn_output = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs)
    self.training_type_plugin._recover_results_in_main_process(spawn_output, self)
    return spawn_output.trainer_results
else:
    return trainer_fn(*args, **kwargs)
```

Here, the plugin type check leaks into the Trainer. This, together with the fact that the spawning logic is quite isolated inside the respective plugins, motivates a refactor that separates them. Two designs have been proposed so far.
Pitch
Proposal 1 (@ananthsub):

```python
class AbstractSpawn(ABC):
    @abstractmethod
    def spawn(self, ...):
        pass

    @abstractmethod
    def collect_rank_zero_results(...):
        pass

    @abstractmethod
    def recover_results_in_main_process(...):
        pass


class DDPSpawnPlugin(ParallelPlugin, AbstractSpawn):
    def spawn(self, function, *args, **kwargs):
        ...

    def recover_results_in_main_process(...):
        pass

    def collect_rank_zero_results(...):
        pass
```

In this proposal, the Trainer call reduces to:
```python
if isinstance(self.training_type_plugin, AbstractSpawn):
    ...
else:
    ...
```
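Filled in, the two branches would presumably mirror today's dispatch, just typed against the abstract interface; a sketch, assuming the call sequence stays as it is now:

```python
# Sketch only: mirrors the current Trainer code, but against AbstractSpawn.
if isinstance(self.training_type_plugin, AbstractSpawn):
    spawn_output = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs)
    self.training_type_plugin.recover_results_in_main_process(spawn_output, self)
    return spawn_output.trainer_results
else:
    return trainer_fn(*args, **kwargs)
```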
Proposal 2 (@awaelchli):

```python
class Executor(ABC):
    def create_processes(self, ...):
        ...


class ScriptExecutor(Executor):
    # Calls the script in subprocesses, like the current DDPPlugin.
    ...


class SpawnExecutor(Executor):
    # Spawns processes from the Trainer function, like the current DDPSpawnPlugin.
    # Draft implementation:
    def create_processes(self, fn):
        # trainer reference up for debate
        output = self._spawn(self._wrap(fn))
        return self.recover_results_in_main_process(trainer, output)

    def _wrap(self, fn):
        fn()  # run it
        return self.collect_rank_zero_results()
```

The plugins would then own an instance of this executor. The DDPPlugin and DDPSpawnPlugin would collapse into a single class (call it DDPNew for the sake of demonstration) that owns either a ScriptExecutor or a SpawnExecutor:
```python
class DDPNew(ParallelPlugin):
    def __init__(self, ..., executor: Executor):
        self.checkpoint_io = ...
        self.executor = executor
```
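The payoff is that the Trainer could then dispatch uniformly, without any isinstance check; a sketch, assuming the create_processes entry point from the draft above:

```python
# Hypothetical sketch: the Trainer delegates to whichever executor the plugin
# owns, so spawn- and script-based DDP share a single code path.
def _run_with_plugin(self, trainer_fn, *args, **kwargs):
    executor = self.training_type_plugin.executor
    return executor.create_processes(lambda: trainer_fn(*args, **kwargs))
```

Whether create_processes receives a plain callable or a trainer reference is exactly the open point flagged in the draft above.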
Alternatives

Additional context
At this point, this is a very open discussion. The proposal may be updated depending on the feedback and discussions.
#10896 (comment)
Thanks @ananthsub for kicking off the discussion.
cc @Borda @tchaton @justusschock @awaelchli @kaushikb11 @akihironitta