
Commit e0c83ee

Update TPUSpawnPlugin spawn methods (#10022)
1 parent e44921e · commit e0c83ee

2 files changed: 19 additions & 18 deletions

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -207,7 +207,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
-    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
+    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
     * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
     * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
@@ -508,6 +508,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Remove deprecated `distributed_backend` from `Trainer` ([#10017](https://github.com/PyTorchLightning/pytorch-lightning/pull/10017))
 
 
+- Removed `process_idx` from the `{DDPSpawnPlugin,TPUSpawnPlugin}.new_process` methods ([#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
+
+
 ### Fixed
 
 
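For context on the `DDPSpawnPlugin.spawn()` entry above: per the diff below, `spawn()` takes a function plus arbitrary positional and keyword arguments and runs that function once in each spawned worker process. A minimal usage sketch, where `plugin` (an already-configured spawn plugin) and `train_fn` are illustrative names rather than code from this commit:

    # Hypothetical usage of the spawn() API described in the changelog entry.
    # `plugin` is assumed to be a configured DDPSpawnPlugin/TPUSpawnPlugin instance.
    def train_fn(name: str, max_steps: int = 100) -> None:
        # executed once inside every spawned worker process
        print(f"worker training {name} for {max_steps} steps")

    plugin.spawn(train_fn, "my-model", max_steps=10)  # *args/**kwargs are forwarded to train_fn
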
pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 15 additions & 17 deletions
@@ -15,7 +15,8 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union
+from multiprocessing.queues import SimpleQueue
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -148,17 +149,9 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass
 
-    def new_process(self, process_idx: int, trainer, mp_queue) -> None:
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         self.mp_queue = mp_queue
 
-        reset_seed()
-
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()
 
@@ -261,26 +254,31 @@ def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")
 
-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {
-            "args": (trainer, self.mp_queue),
             "nprocs": len(self.parallel_devices),
             "start_method": self.start_method,
         }
 
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        xmp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
+
+    def _worker_setup(self, process_idx: int):
+        reset_seed()
+        self.tpu_local_core_rank = xm.get_local_ordinal()
+        self.tpu_global_core_rank = xm.get_ordinal()
+        rank_zero_only.rank = self.global_rank
+
     def start_training(self, trainer: "pl.Trainer") -> None:
         # todo: precision pluging is call in accelerator setup and should be moved
         if "XLA_USE_BF16" in os.environ:
             del os.environ["XLA_USE_BF16"]
         self._close_logger(trainer)
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        return super().start_training(trainer)
 
     def start_evaluating(self, trainer: "pl.Trainer") -> None:
         self._close_logger(trainer)
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
-
-    def start_predicting(self, trainer: "pl.Trainer") -> None:
-        xmp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        return super().start_evaluating(trainer)
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
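
Taken together, these changes mean `start_training` and `start_evaluating` no longer call `xmp.spawn(self.new_process, ...)` directly: they delegate to the `DDPSpawnPlugin` base class, which goes through the new `spawn()` method, while the per-process seeding and rank bookkeeping that used to live in `new_process` now happens in `_worker_setup`. A simplified sketch of the resulting flow, assuming the base class's `_wrapped_function` (added in #10018) calls `_worker_setup` with the process index before invoking the target function; this is an illustration, not the library source:

    # Simplified sketch of the post-commit spawn flow (illustrative, not the actual source).
    def spawn(self, function, *args, **kwargs):
        # one worker per TPU core; xmp.spawn prepends the process index to the callback's arguments
        xmp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())

    def _wrapped_function(self, process_idx, function, args, kwargs):
        # assumed behaviour of the DDPSpawnPlugin helper from #10018
        self._worker_setup(process_idx)   # seed RNGs, record local/global TPU ordinals, set warning rank
        function(*args, **kwargs)         # e.g. the trainer's per-process entry point

With the ranks established in `_worker_setup`, `new_process` only needs `(trainer, mp_queue)`, which is why `process_idx` was dropped from its signature (and from `DDPSpawnPlugin.new_process`, per the changelog).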
