1414import logging
1515import os
1616import re
17+ from collections import UserList
1718from multiprocessing .queues import SimpleQueue
18- from typing import Any , Callable , Dict , List , Optional , Union
19+ from typing import Any , Callable , Dict , List , NamedTuple , Optional , Union
1920
2021import numpy as np
2122import torch
4546from pytorch_lightning .utilities .enums import _StrategyType
4647from pytorch_lightning .utilities .model_helpers import is_overridden
4748from pytorch_lightning .utilities .seed import reset_seed
48- from pytorch_lightning .utilities .types import STEP_OUTPUT
49+ from pytorch_lightning .utilities .types import _PATH , STEP_OUTPUT
4950
5051if _TORCH_GREATER_EQUAL_1_8 :
5152 from pytorch_lightning .utilities .distributed import register_ddp_comm_hook
@@ -80,7 +81,6 @@ def __init__(
8081 self .sync_batchnorm = False
8182 self ._ddp_kwargs = kwargs
8283 self .num_processes = len (parallel_devices ) if parallel_devices is not None else 0
83- self .mp_queue = None
8484 self ._ddp_comm_state = ddp_comm_state
8585 self ._ddp_comm_hook = ddp_comm_hook
8686 self ._ddp_comm_wrapper = ddp_comm_wrapper
@@ -101,15 +101,6 @@ def num_nodes(self, num_nodes: int) -> None:
101101 def local_rank (self ) -> int :
102102 return self ._local_rank
103103
104- def __getstate__ (self ):
105- """Makes this plugin pickleable without destroying the queue in the current process."""
106- state = self .__dict__ .copy ()
107- state ["mp_queue" ] = None
108- return state
109-
110- def __setstate__ (self , state ):
111- self .__dict__ = state
112-
113104 @property
114105 def root_device (self ):
115106 return self .parallel_devices [self .local_rank ]
@@ -125,9 +116,6 @@ def _is_single_process_single_device(self):
125116
126117 def setup (self , trainer : "pl.Trainer" ) -> None :
127118 os .environ ["MASTER_PORT" ] = str (self .cluster_environment .main_port )
128- # pass in a state q
129- smp = mp .get_context ("spawn" )
130- self .mp_queue = smp .SimpleQueue ()
131119 super ().setup (trainer )
132120
133121 def _setup_model (self , model : Module ) -> DistributedDataParallel :
@@ -145,18 +133,24 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
145133 def get_mp_spawn_kwargs (self , trainer : Optional ["pl.Trainer" ] = None ) -> Dict [str , Any ]:
146134 return {"nprocs" : self .num_processes }
147135
148- def start_training (self , trainer : "pl.Trainer" ) -> None :
149- self .spawn (self .new_process , trainer , self .mp_queue )
136+ def start_training (self , trainer : "pl.Trainer" ) -> Any :
137+ spawn_output : _SpawnOutput = self .spawn (self .new_process , trainer )
138+ self .__recover_results_in_main_process (spawn_output , trainer )
150139 # reset optimizers, since main process is never used for training and thus does not have a valid optim state
151140 trainer .optimizers = []
141+ return spawn_output .trainer_results
152142
153- def start_evaluating (self , trainer : "pl.Trainer" ) -> None :
154- self .spawn (self .new_process , trainer , self .mp_queue )
143+ def start_evaluating (self , trainer : "pl.Trainer" ) -> Any :
144+ spawn_output : _SpawnOutput = self .spawn (self .new_process , trainer )
145+ self .__recover_results_in_main_process (spawn_output , trainer )
146+ return spawn_output .trainer_results
155147
156- def start_predicting (self , trainer : "pl.Trainer" ) -> None :
157- self .spawn (self .new_process , trainer , self .mp_queue )
148+ def start_predicting (self , trainer : "pl.Trainer" ) -> Any :
149+ spawn_output : _SpawnOutput = self .spawn (self .new_process , trainer )
150+ self .__recover_results_in_main_process (spawn_output , trainer )
151+ return spawn_output .trainer_results
158152
159- def spawn (self , function : Callable , * args : Any , ** kwargs : Any ) -> Optional [Any ]:
153+ def spawn (self , function : Callable , * args : Any , ** kwargs : Any ) -> Optional [Union [ Any , "_SpawnOutput" ] ]:
160154 """Spawn processes that run the given function.
161155
162156 Args:
@@ -191,9 +185,7 @@ def _worker_setup(self, process_idx: int):
191185 self .cluster_environment , self .torch_distributed_backend , self .global_rank , self .world_size
192186 )
193187
194- def new_process (self , trainer : "pl.Trainer" , mp_queue : SimpleQueue ) -> None :
195- self .mp_queue = mp_queue
196-
188+ def new_process (self , trainer : "pl.Trainer" ) -> Optional ["_SpawnOutput" ]:
197189 # move the model to the correct device
198190 self .model_to_device ()
199191
@@ -208,28 +200,11 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
208200 self .barrier ()
209201
210202 results = trainer .run_stage ()
211-
212- # persist info in ddp_spawn
213- self .__transfer_distrib_spawn_state_on_fit_end (trainer , results )
203+ outputs = self .__collect_rank_zero_results (trainer , results )
214204
215205 # ensure that spawned processes go through teardown before joining
216206 trainer ._call_teardown_hook ()
217-
218- def post_dispatch (self , trainer : "pl.Trainer" ):
219- # restore main state with best weights
220- best_path = self .mp_queue .get ()
221- last_path = self .mp_queue .get ()
222- self ._results = self .mp_queue .get ()
223- # get the `callback_metrics` and set it to the trainer
224- # only in case the user does not override it.
225- # TODO: Remove the if in v1.7
226- if is_overridden ("get_from_queue" , self .lightning_module ):
227- self .lightning_module .get_from_queue (self .mp_queue )
228- else :
229- self .get_from_queue (trainer , self .mp_queue )
230-
231- # recover the weights of the processes trained in the children
232- self .__recover_child_process_weights (best_path , last_path )
207+ return outputs
233208
234209 def pre_configure_ddp (self ):
235210 # if unset, default `find_unused_parameters` `True`
@@ -268,7 +243,7 @@ def determine_ddp_device_ids(self):
268243 return None
269244 return [self .root_device .index ]
270245
271- def __transfer_distrib_spawn_state_on_fit_end (self , trainer : "pl.Trainer" , results : Any ) -> None :
246+ def __collect_rank_zero_results (self , trainer : "pl.Trainer" , results : Any ) -> Optional [ "_SpawnOutput" ] :
272247 rank_zero_warn ("cleaning up ddp environment..." )
273248 checkpoint_callback = trainer .checkpoint_callback
274249 best_model_path = checkpoint_callback .best_model_path if checkpoint_callback else None
@@ -285,28 +260,37 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
285260 last_path = re .sub (".ckpt" , ".tmp_end.ckpt" , best_model_path )
286261 self .checkpoint_io .save_checkpoint (state_dict , last_path )
287262
288- # todo, pass complete checkpoint as state dictionary
289- self .mp_queue .put (best_model_path )
290- self .mp_queue .put (last_path )
291- self .mp_queue .put (results )
292263 # adds the `callback_metrics` to the queue
293- # TODO: Remove the if in v1.7
264+ extra = _FakeQueue ()
294265 if is_overridden ("add_to_queue" , self .lightning_module ):
295- self .lightning_module .add_to_queue (self .mp_queue )
266+ # TODO: Remove the if in v1.7
267+ self .lightning_module .add_to_queue (extra )
296268 else :
297- self .add_to_queue (trainer , self . mp_queue )
269+ self .add_to_queue (trainer , extra )
298270
299- def __recover_child_process_weights (self , best_path , last_path ):
271+ return _SpawnOutput (best_model_path , last_path , results , extra )
272+
273+ def __recover_results_in_main_process (self , spawn_output : "_SpawnOutput" , trainer ) -> None :
300274 # transfer back the best path to the trainer
301275 if self .lightning_module .trainer .checkpoint_callback :
302- self .lightning_module .trainer .checkpoint_callback .best_model_path = best_path
303- # todo, pass also best score
276+ self .lightning_module .trainer .checkpoint_callback .best_model_path = spawn_output .best_model_path
304277
278+ # TODO: pass also best score
305279 # load last weights
306- if last_path is not None and self .lightning_module .trainer .state .fn == TrainerFn .FITTING :
307- ckpt = self .checkpoint_io .load_checkpoint (last_path , map_location = (lambda storage , loc : storage ))
280+ if spawn_output .last_path is not None and self .lightning_module .trainer .state .fn == TrainerFn .FITTING :
281+ ckpt = self .checkpoint_io .load_checkpoint (
282+ spawn_output .last_path , map_location = (lambda storage , loc : storage )
283+ )
308284 self .lightning_module .load_state_dict (ckpt )
309285
286+ # get the `callback_metrics` and set it to the trainer
287+ if is_overridden ("get_from_queue" , self .lightning_module ):
288+ # only in case the user does not override it.
289+ # TODO: Remove the if in v1.7
290+ self .lightning_module .get_from_queue (spawn_output .extra )
291+ else :
292+ self .get_from_queue (trainer , spawn_output .extra )
293+
310294 def barrier (self , * args , ** kwargs ) -> None :
311295 if not distributed_available ():
312296 return
@@ -372,23 +356,25 @@ def post_training_step(self):
372356 if not self .lightning_module .automatic_optimization :
373357 self .model .require_backward_grad_sync = True
374358
375- def add_to_queue (self , trainer : "pl.Trainer" , queue : torch . multiprocessing . SimpleQueue ) -> None :
359+ def add_to_queue (self , trainer : "pl.Trainer" , queue : "_FakeQueue" ) -> None :
376360 """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
377361 sharing, we cast the data to numpy.
378362
379363 Args:
364+ trainer: reference to the Trainer.
380365 queue: the instance of the queue to append the data.
381366 """
382367 callback_metrics : dict = apply_to_collection (
383368 trainer .callback_metrics , torch .Tensor , lambda x : x .cpu ().numpy ()
384369 ) # send as numpy to avoid issues with memory sharing
385370 queue .put (callback_metrics )
386371
387- def get_from_queue (self , trainer : "pl.Trainer" , queue : torch . multiprocessing . SimpleQueue ) -> None :
372+ def get_from_queue (self , trainer : "pl.Trainer" , queue : "_FakeQueue" ) -> None :
388373 """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
389374 we cast back the data to ``torch.Tensor``.
390375
391376 Args:
377+ trainer: reference to the Trainer.
392378 queue: the instance of the queue from where to get the data.
393379 """
394380 # NOTE: `add_to_queue` needs to be called before
@@ -413,3 +399,23 @@ def teardown(self) -> None:
413399 self .lightning_module .cpu ()
414400 # clean up memory
415401 torch .cuda .empty_cache ()
402+
403+
404+ class _FakeQueue (UserList ):
405+ """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
406+
407+ def get (self ) -> Any :
408+ return self .pop (0 )
409+
410+ def put (self , item : Any ) -> None :
411+ self .append (item )
412+
413+ def empty (self ) -> bool :
414+ return len (self ) == 0
415+
416+
417+ class _SpawnOutput (NamedTuple ):
418+ best_model_path : Optional [_PATH ]
419+ last_path : Optional [_PATH ]
420+ trainer_results : Any
421+ extra : _FakeQueue
0 commit comments