
Commit 98cb7e8

awaelchli and tchaton authored
1/n Simplify spawn plugins: Simplify handling of multiprocessing queue (#10034)
Co-authored-by: thomas chaton <[email protected]>
1 parent 541b983 commit 98cb7e8
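
The gist of the simplification: instead of keeping a `torch.multiprocessing.SimpleQueue` alive on the plugin (`self.mp_queue`) and draining it in `post_dispatch`, the spawned worker now returns everything rank zero needs to report as a `_SpawnOutput` NamedTuple, and the main process consumes that return value right after spawning. The sketch below illustrates one common way a spawn helper can hand a worker's return value back to the parent; the helper's internals are not part of this diff, so treat this as a standalone illustration of the pattern, not Lightning code.

from typing import Any, Callable, NamedTuple, Optional
import torch.multiprocessing as mp


class SpawnOutput(NamedTuple):  # stand-in for the `_SpawnOutput` added below
    best_model_path: Optional[str]
    trainer_results: Any


def _wrapped(process_idx: int, function: Callable, args: tuple, return_queue) -> None:
    output = function(process_idx, *args)  # the worker simply returns its output
    if process_idx == 0:
        return_queue.put(output)  # only rank zero reports back to the main process


def spawn(function: Callable, *args: Any, nprocs: int = 2) -> Any:
    # the queue lives only for the duration of this call, so nothing has to be pickled later
    return_queue = mp.get_context("spawn").SimpleQueue()
    mp.spawn(_wrapped, args=(function, args, return_queue), nprocs=nprocs)
    return return_queue.get()  # e.g. a SpawnOutput collected from rank zero


def worker(process_idx: int, message: str) -> SpawnOutput:
    return SpawnOutput(best_model_path=None, trainer_results=f"{message} from rank {process_idx}")


if __name__ == "__main__":
    print(spawn(worker, "hello").trainer_results)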

File tree

8 files changed (+122, -125 lines)


CHANGELOG.md

Lines changed: 11 additions & 1 deletion
@@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))


--
+- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
+- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+

 ### Deprecated


@@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867))


+- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
+- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034))
+
+
 - Removed unnessesary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849))

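
For user code, the changelog entry about `add_to_queue`/`get_from_queue` means the object handed to these hooks is now Lightning's list-backed `_FakeQueue` rather than a `torch.multiprocessing.SimpleQueue`, but it keeps the same `put()`/`get()` calls, so existing overrides keep the same shape. A hedged illustration follows; the model class and the extra payload are made up for this example.

import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def add_to_queue(self, queue) -> None:
        # runs in the spawned rank-zero process; the default implementation adds `callback_metrics`
        super().add_to_queue(queue)
        queue.put({"num_bad_batches": 3})  # extra payload to report back to the main process

    def get_from_queue(self, queue) -> None:
        # runs in the main process after the spawned workers have joined
        super().get_from_queue(queue)
        extra = queue.get()
        print("recovered from workers:", extra)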

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
@@ -1917,7 +1917,7 @@ def model_size(self) -> float:
         )
         return get_model_size_mb(self)

-    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.

@@ -1931,7 +1931,7 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
         if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin):
             self.trainer.training_type_plugin.add_to_queue(self.trainer, queue)

-    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.


pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 65 additions & 59 deletions
@@ -14,8 +14,9 @@
 import logging
 import os
 import re
+from collections import UserList
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

 import numpy as np
 import torch
@@ -45,7 +46,7 @@
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ def __init__(
         self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
-        self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@ def num_nodes(self, num_nodes: int) -> None:
     def local_rank(self) -> int:
         return self._local_rank

-    def __getstate__(self):
-        """Makes this plugin pickleable without destroying the queue in the current process."""
-        state = self.__dict__.copy()
-        state["mp_queue"] = None
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-
     @property
     def root_device(self):
         return self.parallel_devices[self.local_rank]
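
Context for the deleted `__getstate__`/`__setstate__` pair: it existed only because a live `SimpleQueue` attribute makes the plugin unpicklable under the spawn start method. With the `mp_queue` attribute gone, the workaround is no longer needed. A minimal standalone demonstration of the underlying limitation (not Lightning code):

import multiprocessing as mp
import pickle


class Holder:
    def __init__(self) -> None:
        self.mp_queue = mp.get_context("spawn").SimpleQueue()


try:
    pickle.dumps(Holder())  # queues may only cross process boundaries via process inheritance
except RuntimeError as err:
    print(f"pickling failed: {err}")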
@@ -125,9 +116,6 @@ def _is_single_process_single_device(self):

     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
-        # pass in a state q
-        smp = mp.get_context("spawn")
-        self.mp_queue = smp.SimpleQueue()
         super().setup(trainer)

     def _setup_model(self, model: Module) -> DistributedDataParallel:
@@ -145,18 +133,24 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}

-    def start_training(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_training(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
+        return spawn_output.trainer_results

-    def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results

-    def start_predicting(self, trainer: "pl.Trainer") -> None:
-        self.spawn(self.new_process, trainer, self.mp_queue)
+    def start_predicting(self, trainer: "pl.Trainer") -> Any:
+        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
+        self.__recover_results_in_main_process(spawn_output, trainer)
+        return spawn_output.trainer_results

-    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         """Spawn processes that run the given function.

         Args:
@@ -191,9 +185,7 @@ def _worker_setup(self, process_idx: int):
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )

-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
-
+    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
         # move the model to the correct device
         self.model_to_device()

@@ -208,28 +200,11 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         self.barrier()

         results = trainer.run_stage()
-
-        # persist info in ddp_spawn
-        self.__transfer_distrib_spawn_state_on_fit_end(trainer, results)
+        outputs = self.__collect_rank_zero_results(trainer, results)

         # ensure that spawned processes go through teardown before joining
         trainer._call_teardown_hook()
-
-    def post_dispatch(self, trainer: "pl.Trainer"):
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-        # get the `callback_metrics` and set it to the trainer
-        # only in case the user does not override it.
-        # TODO: Remove the if in v1.7
-        if is_overridden("get_from_queue", self.lightning_module):
-            self.lightning_module.get_from_queue(self.mp_queue)
-        else:
-            self.get_from_queue(trainer, self.mp_queue)
-
-        # recover the weights of the processes trained in the children
-        self.__recover_child_process_weights(best_path, last_path)
+        return outputs

     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
@@ -268,7 +243,7 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]

-    def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+    def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_warn("cleaning up ddp environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
             self.checkpoint_io.save_checkpoint(state_dict, last_path)

-        # todo, pass complete checkpoint as state dictionary
-        self.mp_queue.put(best_model_path)
-        self.mp_queue.put(last_path)
-        self.mp_queue.put(results)
         # adds the `callback_metrics` to the queue
-        # TODO: Remove the if in v1.7
+        extra = _FakeQueue()
         if is_overridden("add_to_queue", self.lightning_module):
-            self.lightning_module.add_to_queue(self.mp_queue)
+            # TODO: Remove the if in v1.7
+            self.lightning_module.add_to_queue(extra)
         else:
-            self.add_to_queue(trainer, self.mp_queue)
+            self.add_to_queue(trainer, extra)

-    def __recover_child_process_weights(self, best_path, last_path):
+        return _SpawnOutput(best_model_path, last_path, results, extra)
+
+    def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
         # transfer back the best path to the trainer
         if self.lightning_module.trainer.checkpoint_callback:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
+            self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path

+        # TODO: pass also best score
         # load last weights
-        if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage))
+        if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            ckpt = self.checkpoint_io.load_checkpoint(
+                spawn_output.last_path, map_location=(lambda storage, loc: storage)
+            )
             self.lightning_module.load_state_dict(ckpt)

+        # get the `callback_metrics` and set it to the trainer
+        if is_overridden("get_from_queue", self.lightning_module):
+            # only in case the user does not override it.
+            # TODO: Remove the if in v1.7
+            self.lightning_module.get_from_queue(spawn_output.extra)
+        else:
+            self.get_from_queue(trainer, spawn_output.extra)
+
     def barrier(self, *args, **kwargs) -> None:
         if not distributed_available():
             return
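
The weight hand-off that `__collect_rank_zero_results` and `__recover_results_in_main_process` implement boils down to: rank zero writes the final state dict to a temporary `.tmp_end.ckpt` next to the best checkpoint, and the main process loads it back after the workers exit. A simplified sketch with plain `torch.save`/`torch.load` standing in for the `CheckpointIO` plugin and a made-up path:

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the LightningModule

# spawned rank-zero process: persist the trained weights before the process exits
torch.save(model.state_dict(), "last.tmp_end.ckpt")

# main process, after the spawned processes have joined: restore those weights
state_dict = torch.load("last.tmp_end.ckpt", map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)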
@@ -372,23 +356,25 @@ def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True

-    def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.

         Args:
+            trainer: reference to the Trainer.
             queue: the instance of the queue to append the data.
         """
         callback_metrics: dict = apply_to_collection(
             trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
         )  # send as numpy to avoid issues with memory sharing
         queue.put(callback_metrics)

-    def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None:
+    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
         we cast back the data to ``torch.Tensor``.

         Args:
+            trainer: reference to the Trainer.
             queue: the instance of the queue from where to get the data.
         """
         # NOTE: `add_to_queue` needs to be called before
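
What these two hooks actually move across the process boundary is `callback_metrics`, converted to numpy on the worker (to avoid CUDA tensors and memory-sharing issues) and back to tensors on the main process. A small standalone sketch of that round trip using the same `apply_to_collection` utility (import path assumed from the 1.5.x layout):

import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

callback_metrics = {"val_loss": torch.tensor(0.25), "val_acc": torch.tensor(0.91)}

# rank zero, before `queue.put(...)`: tensors -> numpy
as_numpy = apply_to_collection(callback_metrics, torch.Tensor, lambda x: x.cpu().numpy())

# main process, after `queue.get()`: numpy -> tensors again
restored = apply_to_collection(as_numpy, np.ndarray, lambda x: torch.tensor(x))
assert torch.equal(restored["val_loss"], callback_metrics["val_loss"])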
@@ -413,3 +399,23 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
         # clean up memory
         torch.cuda.empty_cache()
+
+
+class _FakeQueue(UserList):
+    """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
+
+    def get(self) -> Any:
+        return self.pop(0)
+
+    def put(self, item: Any) -> None:
+        self.append(item)
+
+    def empty(self) -> bool:
+        return len(self) == 0
+
+
+class _SpawnOutput(NamedTuple):
+    best_model_path: Optional[_PATH]
+    last_path: Optional[_PATH]
+    trainer_results: Any
+    extra: _FakeQueue
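
A quick look at the behaviour the new `_FakeQueue` preserves: FIFO `put()`/`get()`/`empty()`, just backed by a plain list so it can travel inside the returned `_SpawnOutput` without any inter-process machinery. Illustrative usage, assuming it is imported from the module above:

from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue

q = _FakeQueue()
q.put("best.ckpt")
q.put({"val_loss": 0.25})
assert q.get() == "best.ckpt"  # first in, first out, like SimpleQueue.get()
assert not q.empty()
assert q.get() == {"val_loss": 0.25}
assert q.empty()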

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 4 additions & 5 deletions
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple

 import torch
 from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -115,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
     def post_training_step(self):
         pass

-    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+    def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
             self._precision_plugin.scaler = ShardedGradScaler()
-        return super().new_process(trainer, mp_queue)
+        return super().new_process(trainer)

     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
