Description
Proposed refactoring or deprecation
While working on #9987 and the corresponding refactors around the spawn plugins, I realized that the logic around the multiprocessing queue, and how results are handled and returned to the trainer, is outdated and overly complicated. This logic has survived many changes, but we never found a way to make it simpler. With this issue, I propose several steps towards a clean, easy-to-follow, and easy-to-debug code path for:
- spawning processes
- moving results from the subprocess back to the main process
- returning the result back to the trainer
Motivation
There are several components in the DDPSpawn plugin around spawning processes and handling results that are obscure and poorly documented.
On top of that, result handling bleeds into the TrainingType base class
and also into the trainer:
This is quite confusing to anyone not familiar with the peculiarities of DDP spawn, but it does not have to be that way. The situation can be improved drastically!
Pitch
Step 1
Remove the self.mp_queue attribute from the plugin. It is not required; the queue can be created locally and used within the recently introduced DDPSpawnPlugin.spawn method (#10018).
Step 2
Instead of adding items like last_path, best_path, or results to the queue one by one, add all data at once as one result tuple to the queue.
This logic
becomes
```python
import torch.multiprocessing as mp
from pytorch_lightning.utilities.apply_func import move_data_to_device

# inside the spawned processes
def new_process(self, trainer):
    # ... run the trainer ...
    return best_path, last_path, results

# in the spawn wrapper (mp.spawn passes the process index as the first argument)
def _wrapped_function(self, process_idx, function, args, kwargs, queue):
    result = function(*args, **kwargs)
    if self.is_global_zero:  # assumes the rank was set from process_idx beforehand
        queue.put(move_data_to_device(result, "cpu"))  # the tuple of data

# in the main process
def spawn(self, function, *args, **kwargs):
    queue = mp.get_context("spawn").SimpleQueue()  # local queue, no self.mp_queue
    mp.spawn(self._wrapped_function, args=(function, args, kwargs, queue), nprocs=self.num_processes)
    return queue.get()  # the tuple of data
```

This allows us to standardize the queue usage to a single put() and a corresponding single get(), which is less error-prone and easier to understand for everyone working with custom plugins. The only part with a steep learning curve for the reader is the snippet above; everything else becomes much simpler.
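The single put()/get() pattern from steps 1 and 2 can be tried out without Lightning. The following is a minimal sketch using only the standard library's multiprocessing module (Lightning itself uses torch.multiprocessing.spawn). It assumes a POSIX system because it uses the "fork" start method, and the class SpawnSketch, its methods, and the fake result values are all hypothetical.

```python
import multiprocessing as mp


class SpawnSketch:
    """Hypothetical stand-in for a spawn plugin; note there is no self.mp_queue."""

    def __init__(self, num_processes: int = 2) -> None:
        self.num_processes = num_processes

    def new_process(self, rank: int):
        # Pretend work: every rank computes something, only rank 0's tuple is kept.
        best_path = f"best-rank{rank}.ckpt"
        last_path = f"last-rank{rank}.ckpt"
        results = {"loss": 0.1 * (rank + 1)}
        return best_path, last_path, results

    def _wrapped_function(self, rank: int, queue) -> None:
        result = self.new_process(rank)
        if rank == 0:  # stands in for self.is_global_zero
            queue.put(result)  # exactly one put(), carrying the whole tuple

    def spawn(self):
        ctx = mp.get_context("fork")  # assumption: POSIX; Lightning uses "spawn"
        queue = ctx.SimpleQueue()  # created locally, never stored on self
        procs = [
            ctx.Process(target=self._wrapped_function, args=(rank, queue))
            for rank in range(self.num_processes)
        ]
        for p in procs:
            p.start()
        result = queue.get()  # exactly one get(), matching the single put()
        for p in procs:
            p.join()
        return result


if __name__ == "__main__":
    best_path, last_path, results = SpawnSketch().spawn()
    print(best_path, last_path, results)
```

Note that if rank 0 fails before calling put(), the blocking get() never returns; a real implementation has to account for that.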
Step 3
With steps 1 and 2 in place, we can directly return the results from the spawned function instead of caching them in the attribute self._results.
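A hypothetical before/after sketch of step 3 (the class names, the start() method, and the results property are illustrative, not Lightning's actual API): in the old style the outcome is cached on self._results and read back in a separate call, while in the new style spawn() simply returns it.

```python
from typing import Any, Callable, Optional


class CachingSpawnPlugin:
    """Old style (sketch): the outcome is cached on the instance."""

    def __init__(self) -> None:
        self._results: Optional[Any] = None

    def start(self, function: Callable[[], Any]) -> None:
        # Side effect only; the caller has to fetch the value later.
        self._results = function()

    @property
    def results(self) -> Optional[Any]:
        return self._results  # None if start() was never called


class ReturningSpawnPlugin:
    """New style (sketch): the outcome travels back as a return value."""

    def spawn(self, function: Callable[[], Any]) -> Any:
        # A real plugin would launch processes and return queue.get() (see step 2).
        return function()


def train() -> dict:
    return {"loss": 0.25}
```

The returning version keeps no state that has to stay in sync, so callers no longer depend on calling methods in a particular order.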
Step 4
Finally, we can get rid of dispatch and post-dispatch and combine them into a single plugin.run call or similar:

```python
self.training_type_plugin.run(self.train)  # or whichever other trainer method
```

This cleanly generalizes across all plugins, and the confusing concept of dispatch and post-dispatch is gone.
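A minimal sketch of what a single run() entry point could look like, assuming hypothetical classes TrainingTypePluginSketch, DDPSpawnPluginSketch, and TrainerSketch: the base plugin calls the trainer method directly, while the spawn plugin routes the same call through its spawn() method. Here spawn() is simulated in-process so the example stays runnable; a real implementation would launch processes and return the queue result as in step 2.

```python
from typing import Any, Callable


class TrainingTypePluginSketch:
    """Hypothetical base class: run() replaces dispatch() + post_dispatch()."""

    def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Non-spawn plugins execute the trainer method in the current process.
        return function(*args, **kwargs)


class DDPSpawnPluginSketch(TrainingTypePluginSketch):
    """Hypothetical spawn plugin: same run() call, routed through spawn()."""

    def __init__(self) -> None:
        self.spawn_calls = 0

    def spawn(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Simulated: a real plugin would start processes and return queue.get().
        self.spawn_calls += 1
        return function(*args, **kwargs)

    def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        return self.spawn(function, *args, **kwargs)


class TrainerSketch:
    def __init__(self, plugin: TrainingTypePluginSketch) -> None:
        self.training_type_plugin = plugin

    def train(self) -> str:
        return "trained"

    def fit(self) -> str:
        # One call, identical for every plugin; no dispatch/post_dispatch pair.
        return self.training_type_plugin.run(self.train)
```

The trainer no longer needs to know whether a plugin spawns; that decision lives entirely inside run().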
Step 5
Proposed by @ananthsub, the next step would be to spawn processes directly in the Trainer.fit() call. This is how we do it in Lite as well.
The benefits of this last step are ultimately (#10059 (comment)):
- Consistent execution flow between spawn and non-spawn versions, since _run is shared across both.
- It will no longer be necessary to exclude some hooks or change the hook order based on the plugin's spawn type (example: the setup() hook).
- No more regressions from loggers being accidentally created on the main process for spawn plugins. With step 5, we will no longer need to artificially delay any logger.experiment calls.
- We can move the distributed initialization for spawn plugins earlier, because processes are spawned earlier. This means no subtle differences in the implementations of collectives, which @four4fish found out the tedious and difficult way with "Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly" #9691.
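Step 5 can be pictured with the following sketch. All names are hypothetical, the hook names are plain string labels, and processes are simulated in-process: fit() hands the whole _run to the plugin up front, so every "process" executes the same sequence (distributed init, setup, logger access, training) and nothing has to happen on the main process before spawning.

```python
from typing import Callable, List


class InProcessPlugin:
    """Non-spawn plugin: runs _run once, as rank 0."""

    def launch(self, function: Callable[[int], List[str]]) -> List[List[str]]:
        return [function(0)]


class SimulatedSpawnPlugin:
    """Stands in for a spawn plugin; a real one would start separate processes."""

    def __init__(self, num_processes: int = 2) -> None:
        self.num_processes = num_processes

    def launch(self, function: Callable[[int], List[str]]) -> List[List[str]]:
        return [function(rank) for rank in range(self.num_processes)]


class TrainerSketch:
    def __init__(self, plugin) -> None:
        self.plugin = plugin

    def fit(self) -> List[List[str]]:
        # Spawn first: everything in _run happens inside the worker processes.
        return self.plugin.launch(self._run)

    def _run(self, rank: int) -> List[str]:
        # rank is unused in this toy; a real _run would use it to pick a device.
        calls = []
        calls.append("init_process_group")  # distributed init can move up
        calls.append("setup")  # same hook order for every plugin
        calls.append("logger.experiment")  # no need to delay logger access
        calls.append("train")
        return calls
```

Because both plugins execute the identical _run, there is no plugin-specific hook ordering left to maintain.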
Next steps
A quick and dirty draft is available here in form of a PR (excluding some steps): #10034
Additional context
The add_to_queue and get_from_queue methods were introduced recently, initially on the LightningModule, and are now in a deprecation phase. We would need to incorporate them into this design as well. Suggestions are welcome.