
Commit a4083df

2/n Simplify spawn plugins: Spawn immediately (#10896)
1 parent 3fcfd02 commit a4083df

8 files changed: +77 -146 lines

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
@@ -96,6 +96,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))
 
 
+- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
+    * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}`
+    * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under an initialized process group for spawn-based plugins, just like their non-spawn counterparts
+    * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+
+
 ### Deprecated
 
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
@@ -239,7 +245,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed method `training_step`, `test_step`, `validation_step` and `predict_step` from the `Accelerator` ([#10890](https://github.com/PyTorchLightning/pytorch-lightning/pull/10890))
 
 
-- Removed `HorovodPlugin.start_{training,evaluating,predicting}` hooks ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989))
+- Removed `TrainingTypePlugin.start_{training,evaluating,predicting}` hooks and the same in all subclasses ([#10989](https://github.com/PyTorchLightning/pytorch-lightning/pull/10989), [#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896))
 
 
 - Removed `Accelerator.on_train_start` ([#10999](https://github.com/PyTorchLightning/pytorch-lightning/pull/10999))
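
In practice, the redesign is most visible from user code in two places: hooks such as `setup` now execute inside the spawned worker processes with the distributed process group already initialized, and errors raised inside a worker propagate through torch's spawn machinery instead of being raised directly in the main process. The snippet below is a minimal illustrative sketch, not part of this commit; the module, the `Trainer` arguments, and the guarded import are assumptions chosen only to demonstrate the behavior described in the changelog entry above.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

try:
    # on recent torch versions, worker-side errors surface as this exception type
    from torch.multiprocessing import ProcessRaisedException
except ImportError:  # older torch re-raises a plain Exception in the main process
    ProcessRaisedException = Exception


class SpawnDemoModule(LightningModule):
    """Hypothetical module used only to show where the hooks now run."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def setup(self, stage=None):
        # with this commit, spawn-based plugins call `setup` inside the worker,
        # so the default process group is already initialized here (as for non-spawn DDP)
        if dist.is_available() and dist.is_initialized():
            print(f"setup() on rank {dist.get_rank()} of {dist.get_world_size()}")

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8)


if __name__ == "__main__":
    trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2, max_epochs=1)
    try:
        trainer.fit(SpawnDemoModule())  # worker processes are spawned immediately
    except ProcessRaisedException as err:
        # misconfigurations hit inside a worker now arrive wrapped by the spawn machinery
        print(f"spawned worker failed: {err}")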

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 4 additions & 28 deletions
@@ -132,23 +132,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
     def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
         return {"nprocs": self.num_processes}
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        # reset optimizers, since main process is never used for training and thus does not have a valid optim state
-        trainer.optimizers = []
-        return spawn_output.trainer_results
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
-
     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
         """Spawn processes that run the given function.
@@ -184,7 +167,9 @@ def _worker_setup(self, process_idx: int):
             self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
         )
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
+    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
+        super().pre_dispatch(trainer)
+
         # move the model to the correct device
         self.model_to_device()
@@ -196,15 +181,6 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
         if trainer_fn == TrainerFn.FITTING:
             self.configure_ddp()
 
-        self.barrier()
-
-        results = trainer.run_stage()
-        outputs = self._collect_rank_zero_results(trainer, results)
-
-        # ensure that spawned processes go through teardown before joining
-        trainer._call_teardown_hook()
-        return outputs
-
     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
@@ -268,7 +244,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
 
         return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
 
-    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
+    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
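
The three `start_*` hooks removed above shared one pattern: spawn, recover the rank-zero results in the main process, and hand back `spawn_output.trainer_results`. After this commit that responsibility moves to the caller (the `Trainer` side is not shown in this diff), so the helper below is only a hedged sketch of the calling pattern implied by the removed bodies, built from the names that do appear in the diff; the function itself is hypothetical.

from typing import Any, Callable

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn


def run_with_spawn(plugin: DDPSpawnPlugin, function: Callable, trainer: "pl.Trainer") -> Any:
    # hypothetical helper mirroring the removed `start_training` body above
    spawn_output = plugin.spawn(function, trainer)  # worker processes are created right away
    plugin._recover_results_in_main_process(spawn_output, trainer)
    if trainer.state.fn == TrainerFn.FITTING:
        # the main process never trains, so it does not hold a valid optimizer state
        trainer.optimizers = []
    return spawn_output.trainer_results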

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
+from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -114,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
     def post_training_step(self):
         pass
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
+    def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
             self._precision_plugin.scaler = ShardedGradScaler()
-        return super().new_process(trainer)
+        return super().pre_dispatch(trainer)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 21 additions & 62 deletions
@@ -23,7 +23,6 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
@@ -118,10 +117,23 @@ def connect(self, model: "pl.LightningModule") -> None:
         return super().connect(model)
 
     def pre_dispatch(self, trainer: "pl.Trainer") -> None:
-        super().pre_dispatch(trainer)
+        self._move_optimizer_state()
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
+        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
+            trainer.progress_bar_callback.disable()
+
+        shared_params = find_shared_parameters(self.model)
+        self.model_to_device()
+        if is_overridden("on_post_move_to_device", self.lightning_module):
+            self.model.module.on_post_move_to_device()
+        else:
+            set_shared_parameters(self.model.module, shared_params)
+
+        self.setup_optimizers(trainer)
+        self.precision_plugin.connect(self._model, None, None)
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
         super().setup(trainer)
@@ -154,37 +166,6 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None:
     def set_world_ranks(self, process_idx: int = 0) -> None:
         pass
 
-    def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
-        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
-            trainer.progress_bar_callback.disable()
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        if is_overridden("on_post_move_to_device", self.lightning_module):
-            self.model.module.on_post_move_to_device()
-        else:
-            set_shared_parameters(self.model.module, shared_params)
-
-        trainer.training_type_plugin.setup_optimizers(trainer)
-        trainer.precision_plugin.connect(self._model, None, None)
-
-        self.barrier("pre-run-stage")
-
-        results = trainer.run_stage()
-
-        outputs = self._collect_rank_zero_results(trainer, results)
-
-        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
-        self.barrier("end-process")
-
-        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
-        if self.local_rank == 0:
-            time.sleep(2)
-
-        # ensure that spawned processes go through teardown before joining
-        trainer._call_teardown_hook()
-        return outputs
-
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
@@ -215,8 +196,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         if is_overridden("add_to_queue", self.lightning_module):
             # TODO: Remove the if in v1.7
             self.lightning_module.add_to_queue(extra)
-        else:
-            self.add_to_queue(trainer, extra)
+        self.add_to_queue(trainer, extra)
 
         return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
 
@@ -263,6 +243,9 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
         }
 
     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
+        # todo: precision pluging is call in accelerator setup and should be moved
+        if "XLA_USE_BF16" in os.environ:
+            del os.environ["XLA_USE_BF16"]
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -276,7 +259,10 @@ def _wrapped_function(
         if self.local_rank == 0:
             return_queue.put(move_data_to_device(result, "cpu"))
 
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
         self.barrier("end-process")
+
+        # Ensure that the rank 0 process is the one exiting last
         # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
         if self.local_rank == 0:
             time.sleep(2)
@@ -287,21 +273,6 @@ def _worker_setup(self, process_idx: int):
         self.tpu_global_core_rank = xm.get_ordinal()
         rank_zero_only.rank = self.global_rank
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
-        self._clean_logger(trainer)
-        return super().start_training(trainer)
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        self._clean_logger(trainer)
-        return super().start_evaluating(trainer)
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        self._clean_logger(trainer)
-        return super().start_predicting(trainer)
-
     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)
@@ -358,9 +329,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         return xm.all_gather(tensor)
 
     def teardown(self) -> None:
-        # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
-        self.barrier("teardown")
 
     @property
     def should_rank_save_checkpoint(self) -> bool:
@@ -377,13 +346,3 @@ def checkpoint_io(self) -> CheckpointIO:
     @checkpoint_io.setter
     def checkpoint_io(self, plugin: CheckpointIO) -> None:
         raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.")
-
-    @staticmethod
-    def _clean_logger(trainer: "pl.Trainer") -> None:
-        loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger]
-        for logger in loggers:
-            if isinstance(logger, TensorBoardLogger) and logger._experiment is not None:
-                # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang.
-                # we want to make sure these are closed before we spawn our own threads.
-                # assuming nothing else references the experiment object, python should instantly `__del__` it.
-                logger._experiment = None
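
Note that the worker exit sequence referenced by the two pytorch/xla issue links is now owned by `_wrapped_function` instead of being repeated in `new_process`. Condensed into a standalone form it looks roughly like the sketch below; this assumes a TPU environment with `torch_xla` available, and the function name is hypothetical.

import time

import torch_xla.core.xla_model as xm


def _exit_worker(local_rank: int) -> None:
    # every replica meets here before any of them begins tearing down
    # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
    xm.rendezvous("end-process")
    # keep rank 0 alive a moment longer so it is the last process to exit
    # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
    if local_rank == 0:
        time.sleep(2)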

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 12 deletions
@@ -307,18 +307,6 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
 
-    def start_training(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the training loop
-        return trainer.run_stage()
-
-    def start_evaluating(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the test loop
-        return trainer.run_stage()
-
-    def start_predicting(self, trainer: "pl.Trainer") -> Any:
-        # double dispatch to initiate the predicting loop
-        return trainer.run_stage()
-
     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         """The actual training step.
