
Commit eb233ea

awaelchli and Borda authored
Snapshot selected globals and restore them in spawned process (#13921)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 91bdacf commit eb233ea

File tree

6 files changed: +113 -10 lines changed


src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))


+- Fixed an issue causing deterministic algorithms and other globals to get reset in spawned processes ([#13921](https://github.com/Lightning-AI/lightning/pull/13921))
+
+
 - Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))


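For context on this entry: with the `spawn` start method each worker starts a fresh Python interpreter, so module-level state configured in the main process (deterministic-algorithm flags, `torch.backends.cudnn.benchmark`, RNG seeds) is not carried over. A minimal sketch of the symptom this commit addresses, independent of Lightning (the worker function here is purely illustrative):

```python
import torch
import torch.multiprocessing as mp


def check_flag(rank: int) -> None:
    # In a freshly spawned interpreter the flag falls back to its default (False),
    # even though the parent process enabled it before spawning.
    print(rank, torch.are_deterministic_algorithms_enabled())


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    mp.spawn(check_flag, nprocs=2)  # typically prints "0 False" and "1 False"
```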

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 0 additions & 2 deletions
@@ -50,7 +50,6 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

 log = logging.getLogger(__name__)
@@ -175,7 +174,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()

     def _worker_setup(self, process_idx: int) -> None:
-        reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 63 additions & 2 deletions
@@ -13,11 +13,13 @@
 # limitations under the License.
 import os
 from collections import UserList
+from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, Callable, Dict, NamedTuple, Optional

 import numpy as np
 import torch
+import torch.backends.cudnn
 import torch.multiprocessing as mp
 from torch import Tensor
 from typing_extensions import Literal
@@ -27,7 +29,9 @@
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.trainer.states import TrainerFn, TrainerState
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug
+from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_states
 from pytorch_lightning.utilities.types import _PATH


@@ -89,9 +93,16 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
         context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
+
+        if self._start_method == "spawn":
+            global_states = _GlobalStateSnapshot.capture()
+            process_args = [trainer, function, args, kwargs, return_queue, global_states]
+        else:
+            process_args = [trainer, function, args, kwargs, return_queue]
+
         mp.start_processes(
             self._wrapping_function,
-            args=(trainer, function, args, kwargs, return_queue),
+            args=process_args,
             nprocs=self._strategy.num_processes,
             start_method=self._start_method,
         )
@@ -110,7 +121,10 @@ def _wrapping_function(
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
+        global_states: Optional["_GlobalStateSnapshot"] = None,
     ) -> None:
+        if global_states:
+            global_states.restore()
         self._strategy._worker_setup(process_idx)
         results = function(*args, **kwargs)

@@ -209,3 +223,50 @@ class _WorkerOutput(NamedTuple):
     trainer_state: TrainerState
     trainer_results: Any
     extra: _FakeQueue
+
+
+@dataclass
+class _GlobalStateSnapshot:
+    """Captures a hand-selected set of (global) variables in modules and provides a way to restore them.
+
+    It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state
+    across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.
+
+    Example:
+
+        .. code-block:: python
+
+            # in main process
+            snapshot = _GlobalStateSnapshot.capture()
+
+            # in worker process
+            snapshot.restore()
+    """
+
+    use_deterministic_algorithms: bool
+    use_deterministic_algorithms_warn_only: bool
+    cudnn_benchmark: bool
+    rng_states: Dict[str, Any]
+
+    @classmethod
+    def capture(cls) -> "_GlobalStateSnapshot":
+        """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
+        process."""
+        warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False
+        return cls(
+            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
+            use_deterministic_algorithms_warn_only=warn_only,
+            cudnn_benchmark=torch.backends.cudnn.benchmark,
+            rng_states=_collect_rng_states(),
+        )
+
+    def restore(self) -> None:
+        """Restores all globals to the values captured in the :meth:`capture` method."""
+        if _TORCH_GREATER_EQUAL_1_11:
+            torch.use_deterministic_algorithms(
+                self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
+            )
+        else:
+            torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
+        torch.backends.cudnn.benchmark = self.cudnn_benchmark
+        _set_rng_states(self.rng_states)
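Taken together, `launch` now captures these globals in the main process right before `mp.start_processes`, and `_wrapping_function` restores them as the first step in each spawned worker (the `fork` path skips this because forked children inherit the parent's memory). Below is a minimal sketch of the same capture/restore pattern used directly with `torch.multiprocessing.spawn`; `run_worker` is purely illustrative and `_GlobalStateSnapshot` is a private helper, so this clarifies the mechanism rather than documenting a supported API:

```python
import torch
import torch.multiprocessing as mp

from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot


def run_worker(rank: int, snapshot: _GlobalStateSnapshot) -> None:
    # Restore the parent's globals before doing any work in this process.
    snapshot.restore()
    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    # The snapshot (including RNG states) is pickled and sent to each spawned worker.
    snapshot = _GlobalStateSnapshot.capture()
    mp.spawn(run_worker, args=(snapshot,), nprocs=2)
```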

src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 7 additions & 1 deletion
@@ -21,7 +21,12 @@
 from torch.multiprocessing import ProcessContext

 import pytorch_lightning as pl
-from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
+from pytorch_lightning.strategies.launchers.multiprocessing import (
+    _FakeQueue,
+    _GlobalStateSnapshot,
+    _MultiProcessingLauncher,
+    _WorkerOutput,
+)
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -96,6 +101,7 @@ def _wrapping_function(
         args: Any,
         kwargs: Any,
         return_queue: SimpleQueue,
+        global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
         self._strategy._worker_setup(process_idx)
         results = function(*args, **kwargs)

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 0 additions & 2 deletions
@@ -37,7 +37,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _TPU_AVAILABLE:
@@ -206,7 +205,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

     def _worker_setup(self, process_idx: int):
         self._launched = True
-        reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank


tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 40 additions & 3 deletions
@@ -15,19 +15,20 @@
 from unittest.mock import ANY, Mock

 import pytest
+import torch

-from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher


 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
-def test_spawn_launcher_forking_on_unsupported_platform(_):
+def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
     with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
         _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


 @pytest.mark.parametrize("start_method", ["spawn", "fork"])
 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
-def test_spawn_launcher_start_method(mp_mock, start_method):
+def test_multiprocessing_launcher_start_method(mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
@@ -38,3 +39,39 @@ def test_spawn_launcher_start_method(mp_mock, start_method):
         nprocs=ANY,
         start_method=start_method,
     )
+
+
+@pytest.mark.parametrize("start_method", ["spawn", "fork"])
+@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
+def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
+    """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
+    mp_mock.get_all_start_methods.return_value = [start_method]
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
+    launcher.launch(function=Mock())
+    function_args = mp_mock.start_processes.call_args[1]["args"]
+    if start_method == "spawn":
+        assert len(function_args) == 6
+        assert isinstance(function_args[5], _GlobalStateSnapshot)
+    else:
+        assert len(function_args) == 5
+
+
+def test_global_state_snapshot():
+    """Test the capture() and restore() methods for the global state snapshot."""
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.benchmark = False
+    torch.manual_seed(123)
+
+    # capture the state of globals
+    snapshot = _GlobalStateSnapshot.capture()
+
+    # simulate there is a process boundary and flags get reset here
+    torch.use_deterministic_algorithms(False)
+    torch.backends.cudnn.benchmark = True
+    torch.manual_seed(321)
+
+    # restore the state of globals
+    snapshot.restore()
+    assert torch.are_deterministic_algorithms_enabled()
+    assert not torch.backends.cudnn.benchmark
+    assert torch.initial_seed() == 123
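To exercise just the new tests locally, a plain pytest run against the file shown above should be enough, e.g. `pytest tests/tests_pytorch/strategies/launchers/test_multiprocessing.py -k "restore_globals or global_state_snapshot"` (the `-k` expression is only a convenience filter matching the new test names).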
