diff --git a/CHANGELOG.md b/CHANGELOG.md
index b85f3c76cad46..29585a0e82ac8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
-    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
+    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
     * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))

 ### Changed

@@ -500,6 +500,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Remove deprecated `distributed_backend` from `Trainer` ([#10017](https://github.com/PyTorchLightning/pytorch-lightning/pull/10017))

+- Removed `process_idx` from the `{DDPSpawnPlugin,TPUSpawnPlugin}.new_process` methods ([#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
+
+
 ### Fixed

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 55e62aade809d..e71d1c64f8d65 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,8 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union
+from multiprocessing.queues import SimpleQueue
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 import torch.multiprocessing as mp
@@ -148,17 +149,9 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass

-    def new_process(self, process_idx: int, trainer, mp_queue) -> None:
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         self.mp_queue = mp_queue

-        reset_seed()
-
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()

@@ -261,26 +254,31 @@ def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")

-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {
-            "args": (trainer, self.mp_queue),
             "nprocs": len(self.parallel_devices),
             "start_method": self.start_method,
         }

+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        xmp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
+
+    def _worker_setup(self, process_idx: int):
+        reset_seed()
+        self.tpu_local_core_rank = xm.get_local_ordinal()
+        self.tpu_global_core_rank = xm.get_ordinal()
+        rank_zero_only.rank = self.global_rank
+
     def start_training(self, trainer: "pl.Trainer") -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
         self._close_logger(trainer)
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        return super().start_training(trainer)

     def start_evaluating(self, trainer: "pl.Trainer") -> None:
         self._close_logger(trainer)
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
-
-    def start_predicting(self, trainer: "pl.Trainer") -> None:
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        return super().start_evaluating(trainer)

     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
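
For orientation, below is a minimal, self-contained sketch of the spawn/worker-setup split this diff applies: per-process bootstrapping (seeding, rank bookkeeping) moves out of `new_process` into a `_worker_setup(process_idx)` hook, and a generic `spawn(function, *args)` wrapper launches arbitrary functions in the worker processes. The `SpawnLauncherSketch` class, its attributes, and `_wrapped_function` here are illustrative stand-ins, not the actual plugin API, and `torch.multiprocessing.spawn` is used in place of `xmp.spawn` so the example runs without TPUs.

```python
import torch.multiprocessing as mp


class SpawnLauncherSketch:
    """Illustrative stand-in for the generic spawn/worker-setup pattern in this diff."""

    def __init__(self, nprocs: int = 2, start_method: str = "spawn"):
        self.nprocs = nprocs
        self.start_method = start_method

    def get_mp_spawn_kwargs(self) -> dict:
        # spawn kwargs no longer carry a (trainer, mp_queue) "args" tuple;
        # the wrapped function receives its arguments directly instead
        return {"nprocs": self.nprocs, "start_method": self.start_method}

    def spawn(self, function, *args, **kwargs) -> None:
        # launch `function` in every worker process
        mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())

    def _wrapped_function(self, process_idx: int, function, args, kwargs) -> None:
        # per-process bootstrapping happens here, not inside the spawned function
        self._worker_setup(process_idx)
        function(*args, **kwargs)

    def _worker_setup(self, process_idx: int) -> None:
        # stands in for reset_seed() and the local/global rank bookkeeping
        self.local_rank = process_idx


if __name__ == "__main__":
    SpawnLauncherSketch(nprocs=2).spawn(print, "hello from a spawned worker")
```

With this split, `start_training`/`start_evaluating` no longer need TPU-specific spawning code and can defer to the base class, which calls `spawn()` with the function to run.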