
Simplify multiprocessing logic in DDPSpawn plugins #10059

@awaelchli

Description


Proposed refactoring or deprecation

When working on #9987 and the corresponding refactors around the spawn plugins, I realized that the logic around the multiprocessing queue, how results are handled, and how they are returned to the trainer is quite outdated and overly complicated. This logic has outlived many changes, but we never saw a way to make it simpler. With this issue I propose several steps towards a clean, easy-to-follow, easy-to-debug code path for

  1. spawning processes
  2. moving results from the subprocess back to the main process
  3. returning the result back to the trainer

Motivation

There are several components in the DDPSpawn plugin around spawning processes and handling of results that are obscure and not well documented.
On top of that, result handling bleeds into the TrainingType base class

https://github.com/PyTorchLightning/pytorch-lightning/blob/aa1540410ff55854e050ff117c2d66f22741d182/pytorch_lightning/plugins/training_type/training_type_plugin.py#L38

and also into the trainer:

https://github.com/PyTorchLightning/pytorch-lightning/blob/aa1540410ff55854e050ff117c2d66f22741d182/pytorch_lightning/trainer/trainer.py#L1123-L1125

This is quite confusing to anyone not familiar with the peculiarities of DDP spawn. But it does not have to be that way. The situation can be drastically improved!

Pitch

Step 1

Remove the self.mp_queue attribute from the plugin. It is not required: the queue can be created and used locally within the recently introduced DDPSpawnPlugin.spawn method (#10018), as sketched below.
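
A rough sketch of the idea, assuming the queue is currently created in setup() and stored as an attribute (names and signatures below are illustrative, not a final API):

import torch.multiprocessing as mp

# today (roughly): the queue is a long-lived plugin attribute, created in setup()
def setup(self):
    smp = mp.get_context("spawn")
    self.mp_queue = smp.SimpleQueue()

# proposed: no attribute; the queue exists only for the duration of spawn()
def spawn(self, function, *args, **kwargs):
    queue = mp.get_context("spawn").SimpleQueue()
    mp.spawn(self._wrapped_function, args=(function, args, kwargs, queue), nprocs=self.num_processes)
    return queue.get()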

Step 2

Instead of adding items like last_path, best_path, or results to the queue one by one, add all data at once as one result tuple to the queue.

This logic

https://github.com/PyTorchLightning/pytorch-lightning/blob/aa1540410ff55854e050ff117c2d66f22741d182/pytorch_lightning/plugins/training_type/ddp_spawn.py#L224-L238

becomes

# inside the spawned processes
def new_process(self, ...):
    # ...
    return best_path, last_path, results

# in spawn wrapper:
def _wrapped_function(...):
    result = function(*args, **kwargs)
    if self.is_global_zero:
        queue.put(move_data_to_device(result, "cpu"))  # the tuple of data

# in main process:
def spawn(self, ...):
    queue = SimpleQueue()
    mp.spawn(self._wrapped_function, args=(function, args, kwargs, queue), nprocs=self.num_processes)
    return queue.get()  # the tuple of data

This allows us to standardize the queue usage to a single put() and a single corresponding get(). It is less error prone and easier to understand for everyone working with custom plugins. The only part with a steep learning curve for the reader is the code snippet above; everything else becomes dramatically simpler.

Step 3

With 1) and 2) in place, we can directly return the results from the spawned function instead of caching them in the self._results attribute.
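
As an illustration (the exact method names here are placeholders, not necessarily the real trainer API), the call site could then read the result directly instead of a cached attribute:

# before (roughly): the result is stashed on the plugin and read back later
self.training_type_plugin.start_training(self)
results = self.training_type_plugin._results

# after: the return value of the spawned function flows straight back to the caller
results = self.training_type_plugin.spawn(self.run_stage)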

Step 4

Finally, we can get rid of dispatch and post_dispatch

https://github.com/PyTorchLightning/pytorch-lightning/blob/aa1540410ff55854e050ff117c2d66f22741d182/pytorch_lightning/trainer/trainer.py#L1102-L1107

and combine them into a single plugin.run call or similar:

self.training_type_plugin.run(self.train)  # or whatever other trainer method

This then cleanly generalizes across all plugins. The confusing concept of dispatch and post_dispatch is gone.
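
A minimal sketch of what such a run method could look like, assuming the spawn plugins override it and the base class simply calls the function in-place (the name run and the signature are assumptions of this sketch):

# base TrainingTypePlugin (sketch): no process creation, just call the function
def run(self, trainer_fn, *args, **kwargs):
    return trainer_fn(*args, **kwargs)

# DDPSpawnPlugin (sketch): spawning, result transfer and teardown all live behind one call
def run(self, trainer_fn, *args, **kwargs):
    return self.spawn(trainer_fn, *args, **kwargs)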

Step 5

Proposed by @ananthsub, the next step would be to directly spawn processes with the Trainer.fit() call. This is how we do it in Lite as well:

https://github.com/PyTorchLightning/pytorch-lightning/blob/412d507a73c79f5e4f7ef14289cefe2eb2af94a7/pytorch_lightning/lite/lite.py#L387-L396

The benefits of this last step are ultimately (#10059 (comment)):

  • consistent execution flow between spawn and non-spawn versions since _run is shared across both
    • No longer will it be required to exclude some hooks or change hook order based on the plugin spawn type. Example: setup() hook
    • No more regressions with loggers being accidentally created on the main process for spawn plugins. With step 5, we will no longer need to artificially delay any logger.experiment calls.
  • We can move distributed initialization earlier for spawn plugins because processes are spawned earlier. This means no subtle differences in implementations of collectives, which @four4fish found out the tedious and difficult way with "Deprecate LightningDistributed and keep logic in ddp/ddpSpawn directly" (#9691).
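
Building on the run call from step 4, Trainer.fit could then look roughly like this (a sketch only; _fit_impl is a hypothetical internal method and the signature is simplified):

# Trainer.fit (sketch): processes are spawned at the very top of the call,
# so everything inside _fit_impl / _run is shared between spawn and non-spawn plugins
def fit(self, model, train_dataloaders=None, val_dataloaders=None, datamodule=None):
    return self.training_type_plugin.run(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)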

Next steps

A quick and dirty draft is available here in the form of a PR (excluding some steps): #10034

Additional context

The add_to_queue and get_from_queue methods were recently introduced, initially on the LightningModule, and are now in a deprecation phase. We would need to incorporate them into this design as well. Suggestions welcome.
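
One possible direction (purely a suggestion; the _FakeQueue helper below is hypothetical) is to let the wrapper collect the extra user data and ship it inside the same single result tuple:

from collections import UserList

class _FakeQueue(UserList):
    """Hypothetical list-backed stand-in offering the queue interface the hooks expect."""

    def put(self, item):
        self.append(item)

    def get(self):
        return self.pop(0)

# in the spawn wrapper (sketch): the extra data rides along in the same single put()
def _wrapped_function(self, process_idx, function, args, kwargs, queue):
    result = function(*args, **kwargs)
    if self.is_global_zero:
        extra = _FakeQueue()
        self.lightning_module.add_to_queue(extra)  # deprecated hook, kept working for now
        queue.put(move_data_to_device((result, extra), "cpu"))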


