24 changes: 23 additions & 1 deletion src/lightning_lite/strategies/launchers/multiprocessing.py
@@ -24,7 +24,7 @@
 from lightning_lite.strategies.launchers.base import _Launcher
 from lightning_lite.strategies.strategy import Strategy
 from lightning_lite.utilities.apply_func import move_data_to_device
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
 from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states


@@ -82,6 +82,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
             *args: Optional positional arguments to be passed to the given function.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
+        if self._start_method in ("fork", "forkserver"):
+            _check_bad_cuda_fork()
+
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
@@ -166,3 +169,22 @@ def restore(self) -> None:
         torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
         torch.backends.cudnn.benchmark = self.cudnn_benchmark
         _set_rng_states(self.rng_states)
+
+
+def _check_bad_cuda_fork() -> None:
+    """Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.
+
+    The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
+    Lightning users.
+    """
+    if not torch.cuda.is_initialized():
+        return
+
+    message = (
+        "Lightning can't create new processes if CUDA is already initialized. Did you manually call"
+        " `torch.cuda.*` functions, move the model to the device, or allocate memory on the GPU in any"
+        " other way? Please remove any such calls, or change the selected strategy."
+    )
+    if _IS_INTERACTIVE:
+        message += " You will have to restart the Python kernel."
+    raise RuntimeError(message)
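
For readers unfamiliar with the underlying limitation, here is a minimal sketch of the failure mode this guard front-runs. It is not part of the diff and assumes a Linux machine with a CUDA GPU; it uses plain `torch.multiprocessing` rather than Lightning's launcher:

```python
import torch
import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    # Any CUDA call in a forked child attempts to re-initialize CUDA and fails.
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    # Touching the GPU here initializes CUDA in the parent process ...
    torch.zeros(1, device="cuda")
    assert torch.cuda.is_initialized()
    # ... so forking afterwards is unsafe: the children die with PyTorch's
    # "Cannot re-initialize CUDA in forked subprocess" error. The launcher's new
    # `_check_bad_cuda_fork()` raises a clearer RuntimeError before forking instead.
    mp.start_processes(_worker, nprocs=2, start_method="fork")
```
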
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -47,6 +47,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



+- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))
+
+
+
 ### Changed

 - The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -27,6 +27,7 @@

 import pytorch_lightning as pl
 from lightning_lite.strategies.launchers.base import _Launcher
+from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning_lite.utilities.types import _PATH
@@ -90,6 +91,9 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
         self._check_torchdistx_support()
+        if self._start_method in ("fork", "forkserver"):
+            _check_bad_cuda_fork()
+
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
10 changes: 10 additions & 0 deletions tests/tests_lite/strategies/launchers/test_multiprocessing.py
@@ -84,3 +84,13 @@ def test_global_state_snapshot():
     assert torch.are_deterministic_algorithms_enabled()
     assert not torch.backends.cudnn.benchmark
     assert torch.initial_seed() == 123
+
+
+@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
+@mock.patch("torch.cuda.is_initialized", return_value=True)
+@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
+def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
+    mp_mock.get_all_start_methods.return_value = [start_method]
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
+    with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
+        launcher.launch(function=Mock())
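
The test patches `torch.cuda.is_initialized` to simulate a pre-initialized CUDA context and mocks `mp` so that the parametrized start method appears available on every platform, so no real processes are ever spawned. Outside of such mocking, the guard can be exercised directly; a minimal sketch, assuming a machine with at least one CUDA device:

```python
import torch
from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork

_check_bad_cuda_fork()           # no CUDA context yet -> returns silently
torch.zeros(1, device="cuda")    # the first CUDA op initializes the context in this process
_check_bad_cuda_fork()           # now raises RuntimeError with the Lightning-specific message
```
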