|
15 | 15 | import os |
16 | 16 | import re |
17 | 17 | import time |
18 | | -from typing import Any, Dict, List, Optional, Union |
| 18 | +from multiprocessing.queues import SimpleQueue |
| 19 | +from typing import Any, Callable, Dict, List, Optional, Union |
19 | 20 |
|
20 | 21 | import torch |
21 | 22 | import torch.multiprocessing as mp |
@@ -148,17 +149,9 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: |
148 | 149 | def set_world_ranks(self, process_idx: int = 0) -> None: |
149 | 150 | pass |
150 | 151 |
|
151 | | - def new_process(self, process_idx: int, trainer, mp_queue) -> None: |
| 152 | + def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: |
152 | 153 | self.mp_queue = mp_queue |
153 | 154 |
|
154 | | - reset_seed() |
155 | | - |
156 | | - self.tpu_local_core_rank = xm.get_local_ordinal() |
157 | | - self.tpu_global_core_rank = xm.get_ordinal() |
158 | | - |
159 | | - # set warning rank |
160 | | - rank_zero_only.rank = self.global_rank |
161 | | - |
162 | 155 | if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: |
163 | 156 | trainer.progress_bar_callback.disable() |
164 | 157 |
|
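Note on the removed block above: the seed reset and rank bookkeeping are not dropped, they are relocated into the new `_worker_setup` hook added in the next hunk, which runs once in each spawned process. A minimal standalone sketch of the underlying torch_xla entry point this relies on (only `xmp.spawn`, `xm.get_local_ordinal` and `xm.get_ordinal` are real torch_xla APIs; `_train_fn` and the `nprocs=8` value are illustrative assumptions):

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _train_fn(process_idx, *args):
        # xmp.spawn always passes the process index as the first argument,
        # which is why the per-worker setup (seeding, rank bookkeeping) can
        # live in a wrapper instead of inside new_process itself.
        print(f"proc {process_idx}: local={xm.get_local_ordinal()} global={xm.get_ordinal()}")

    if __name__ == "__main__":
        # nprocs=8 assumes a single TPU v3-8 host; adjust for your device
        xmp.spawn(_train_fn, args=(), nprocs=8, start_method="fork")
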
@@ -261,26 +254,31 @@ def _close_logger(self, trainer) -> None: |
261 | 254 | if trainer.logger is not None: |
262 | 255 | trainer.logger.finalize("success") |
263 | 256 |
|
264 | | - def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict: |
| 257 | + def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: |
265 | 258 | return { |
266 | | - "args": (trainer, self.mp_queue), |
267 | 259 | "nprocs": len(self.parallel_devices), |
268 | 260 | "start_method": self.start_method, |
269 | 261 | } |
270 | 262 |
|
| 263 | + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: |
| 264 | + xmp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs()) |
| 265 | + |
| 266 | + def _worker_setup(self, process_idx: int): |
| 267 | + reset_seed() |
| 268 | + self.tpu_local_core_rank = xm.get_local_ordinal() |
| 269 | + self.tpu_global_core_rank = xm.get_ordinal() |
| 270 | + rank_zero_only.rank = self.global_rank |
| 271 | + |
271 | 272 | def start_training(self, trainer: "pl.Trainer") -> None: |
272 | 273 | # todo: precision plugin is called in accelerator setup and should be moved |
273 | 274 | if "XLA_USE_BF16" in os.environ: |
274 | 275 | del os.environ["XLA_USE_BF16"] |
275 | 276 | self._close_logger(trainer) |
276 | | - xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) |
| 277 | + return super().start_training(trainer) |
277 | 278 |
|
278 | 279 | def start_evaluating(self, trainer: "pl.Trainer") -> None: |
279 | 280 | self._close_logger(trainer) |
280 | | - xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) |
281 | | - |
282 | | - def start_predicting(self, trainer: "pl.Trainer") -> None: |
283 | | - xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) |
| 281 | + return super().start_evaluating(trainer) |
284 | 282 |
|
285 | 283 | def training_step(self, *args, **kwargs): |
286 | 284 | return self.model(*args, **kwargs) |
|
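Taken together, the expected call order after this refactor is roughly the following (a sketch of the assumed flow, not the base class implementation; `_wrapped_function` is referenced only because the new `spawn` override above hands it to `xmp.spawn`):

    # start_training(trainer) / start_evaluating(trainer)
    #   -> super().start_training(trainer)                        # base spawn plugin
    #   -> self.spawn(self.new_process, trainer, self.mp_queue)   # assumed base-class behaviour
    #   -> xmp.spawn(self._wrapped_function, args=(function, args, kwargs),
    #                **self.get_mp_spawn_kwargs())
    #   -> in each process: self._worker_setup(process_idx), then function(*args, **kwargs)
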