@@ -13,10 +13,12 @@
 # limitations under the License.
 import os
 import time
+from functools import wraps
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING

 import torch.multiprocessing as mp
+from torch.multiprocessing import ProcessContext

 import pytorch_lightning as pl
 from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
@@ -26,9 +28,10 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug

 if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 else:
-    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
+    xm, xmp = None, None

 if TYPE_CHECKING:
     from pytorch_lightning.strategies import Strategy
@@ -72,7 +75,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         """
         context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
-        xmp.spawn(
+        _save_spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=len(self._strategy.parallel_devices),
@@ -103,14 +106,6 @@ def _wrapping_function(
         if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))

-        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
-        self._strategy.barrier("end-process")
-
-        # Ensure that the rank 0 process is the one exiting last
-        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
-        if self._strategy.local_rank == 0:
-            time.sleep(2)
-
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback
@@ -138,3 +133,30 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         self.add_to_queue(trainer, extra)

         return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
+
+
+def _save_spawn(
+    fn: Callable,
+    args: Tuple = (),
+    nprocs: Optional[int] = None,
+    join: bool = True,
+    daemon: bool = False,
+    start_method: str = "spawn",
+) -> Optional[ProcessContext]:
+    """Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
+    processes."""
+
+    @wraps(fn)
+    def wrapped(rank: int, *_args: Any) -> None:
+        fn(rank, *_args)
+
+        # Make all processes wait for each other before joining
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
+        xm.rendezvous("end-process")
+
+        # Ensure that the rank 0 process is the one exiting last
+        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
+        if rank == 0:
+            time.sleep(1)
+
+    return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
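For reference, below is a minimal sketch of the teardown ordering that the new _save_spawn wrapper enforces (finish the work, rendezvous across all ranks, then let rank 0 exit last), written against plain torch.multiprocessing and a CPU "gloo" process group so it can run without a TPU. It is illustrative only and not part of this diff: the _worker name, the hard-coded rendezvous address, and the use of torch.distributed.barrier as a stand-in for xm.rendezvous are assumptions made for the example.

# Illustrative sketch only (not from the diff): each worker finishes its work,
# all ranks rendezvous, and rank 0 is deliberately the last process to exit.
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # A CPU (gloo) process group stands in for the XLA runtime in this sketch.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    print(f"rank {rank} finished its work")

    # Make all processes wait for each other before joining,
    # mirroring xm.rendezvous("end-process") in _save_spawn.
    dist.barrier()

    # Ensure that the rank 0 process is the one exiting last.
    if rank == 0:
        time.sleep(1)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)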