4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -30,7 +30,7 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import (
@@ -134,7 +134,7 @@ def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

def _configure_launcher(self):
self._launcher = _SpawnLauncher(self, start_method=self._start_method)
self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
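At a glance, this PR is a pure rename of the spawn-based launcher internals: _SpawnLauncher becomes _MultiProcessingLauncher (its module moves from launchers/spawn.py to launchers/multiprocessing.py), _XLASpawnLauncher becomes _XLALauncher, and the "spawn_output"/_SpawnOutput names become "worker_output"/_WorkerOutput. For the strategy above, only the import path and class name change; the constructor call is otherwise identical. A minimal sketch of the new spelling, using a Mock in place of a real DDPSpawnStrategy just as the updated unit tests further down do:

from unittest.mock import Mock

from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher

# Stand-in for the strategy; DDPSpawnStrategy passes itself plus its start method.
launcher = _MultiProcessingLauncher(Mock(), start_method="spawn")
print(type(launcher).__name__)  # _MultiProcessingLauncher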
8 changes: 4 additions & 4 deletions src/pytorch_lightning/strategies/launchers/__init__.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher

__all__ = [
"_Launcher",
"_SpawnLauncher",
"_MultiProcessingLauncher",
"_SubprocessScriptLauncher",
"_XLASpawnLauncher",
"_XLALauncher",
]
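With the __init__.py update, the launchers subpackage exports the renamed classes. A quick check of the new surface (assuming a regular CPU-only install, which should be able to import the subpackage):

from pytorch_lightning.strategies import launchers

print(launchers.__all__)
# ['_Launcher', '_MultiProcessingLauncher', '_SubprocessScriptLauncher', '_XLALauncher']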
src/pytorch_lightning/strategies/launchers/multiprocessing.py (renamed from src/pytorch_lightning/strategies/launchers/spawn.py)
@@ -31,8 +31,8 @@
from pytorch_lightning.utilities.types import _PATH


class _SpawnLauncher(_Launcher):
r"""Spawns processes that run a given function in parallel, and joins them all at the end.
class _MultiProcessingLauncher(_Launcher):
r"""Launches processes that run a given function in parallel, and joins them all at the end.

The main process in which this launcher is invoked creates N so-called worker processes (using
:func:`torch.multiprocessing.start_processes`) that run the given function.
@@ -71,20 +71,20 @@ def is_interactive_compatible(self) -> bool:
return self._start_method == "fork"

def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
"""Spawns processes that run the given function in parallel.
"""Launches processes that run the given function in parallel.

The function is allowed to have a return value. However, when all processes join, only the return value
of worker process 0 gets returned from this `launch` method in the main process.

Arguments:
function: The entry point for all spawned processes.
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
trainer: Optional reference to the :class:`~pytorch_lightning.trainer.trainer.Trainer` for which
a selected set of attributes get restored in the main process after processes join.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before spawning to ensure each rank will connect
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
context = mp.get_context(self._start_method)
@@ -95,12 +95,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
nprocs=self._strategy.num_processes,
start_method=self._start_method,
)
spawn_output = return_queue.get()
worker_output = return_queue.get()
if trainer is None:
return spawn_output
return worker_output

self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
self._recover_results_in_main_process(worker_output, trainer)
return worker_output.trainer_results

def _wrapping_function(
self,
@@ -120,25 +120,25 @@ def _wrapping_function(
if self._strategy.local_rank == 0:
return_queue.put(move_data_to_device(results, "cpu"))

def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
trainer.checkpoint_callback.best_model_path = str(worker_output.best_model_path)

# TODO: pass also best score
# load last weights
if spawn_output.weights_path is not None:
ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path)
if worker_output.weights_path is not None:
ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type]
self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path)
self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

trainer.state = spawn_output.trainer_state
trainer.state = worker_output.trainer_state

# get the `callback_metrics` and set it to the trainer
self.get_from_queue(trainer, spawn_output.extra)
self.get_from_queue(trainer, worker_output.extra)

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the DDP spawn environment.")
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
rank_zero_debug("Collecting results from rank 0 process.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = (
checkpoint_callback.best_model_path
@@ -162,7 +162,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
extra = _FakeQueue()
self.add_to_queue(trainer, extra)

return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
@@ -203,7 +203,7 @@ def empty(self) -> bool:
return len(self) == 0


class _SpawnOutput(NamedTuple):
class _WorkerOutput(NamedTuple):
best_model_path: Optional[_PATH]
weights_path: Optional[_PATH]
trainer_state: TrainerState
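Apart from the new names, the launcher's mechanics are untouched: it still builds on torch.multiprocessing.start_processes plus a SimpleQueue that carries rank 0's results back to the main process, exactly as the launch method above shows. A minimal, self-contained sketch of that underlying pattern (plain PyTorch, not Lightning code; the worker function is hypothetical):

import torch.multiprocessing as mp


def _worker(process_idx, return_queue):
    # Every worker runs the function; only rank 0 reports its result back.
    result = process_idx * 10
    if process_idx == 0:
        return_queue.put(result)


if __name__ == "__main__":
    context = mp.get_context("spawn")
    return_queue = context.SimpleQueue()
    # join=True by default, so this blocks until all worker processes have exited.
    mp.start_processes(_worker, args=(return_queue,), nprocs=2, start_method="spawn")
    print(return_queue.get())  # 0, the value put by worker 0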
src/pytorch_lightning/strategies/launchers/xla.py (renamed from src/pytorch_lightning/strategies/launchers/xla_spawn.py)
@@ -19,7 +19,7 @@
import torch.multiprocessing as mp

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput
from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -34,8 +34,9 @@
from pytorch_lightning.strategies import Strategy


class _XLASpawnLauncher(_SpawnLauncher):
r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end.
class _XLALauncher(_MultiProcessingLauncher):
r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
end.

The main process in which this launcher is invoked creates N so-called worker processes (using the
`torch_xla` :func:`xmp.spawn`) that run the given function.
@@ -57,13 +58,13 @@ def is_interactive_compatible(self) -> bool:
return True

def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
"""Spawns processes that run the given function in parallel.
"""Launches processes that run the given function in parallel.

The function is allowed to have a return value. However, when all processes join, only the return value
of worker process 0 gets returned from this `launch` method in the main process.

Arguments:
function: The entry point for all spawned processes.
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
trainer: Optional reference to the :class:`~pytorch_lightning.trainer.trainer.Trainer` for which
a selected set of attributes get restored in the main process after processes join.
@@ -77,12 +78,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
nprocs=len(self._strategy.parallel_devices),
start_method=self._start_method,
)
spawn_output = return_queue.get()
worker_output = return_queue.get()
if trainer is None:
return spawn_output
return worker_output

self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results
self._recover_results_in_main_process(worker_output, trainer)
return worker_output.trainer_results

def _wrapping_function(
self,
@@ -110,8 +111,8 @@ def _wrapping_function(
if self._strategy.local_rank == 0:
time.sleep(2)

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the TPU spawn environment.")
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
rank_zero_debug("Collecting results from rank 0 process.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = (
checkpoint_callback.best_model_path
@@ -136,4 +137,4 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
extra = _FakeQueue()
self.add_to_queue(trainer, extra)

return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
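The class hierarchy is preserved under the new names: _XLALauncher subclasses _MultiProcessingLauncher just as _XLASpawnLauncher subclassed _SpawnLauncher, and it still delegates process creation to torch_xla's xmp.spawn. A quick sanity check (importing the module should not require TPU hardware; only actually launching does):

from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher

assert issubclass(_XLALauncher, _MultiProcessingLauncher)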
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -27,7 +27,7 @@
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
@@ -120,7 +120,7 @@ def connect(self, model: "pl.LightningModule") -> None:
return super().connect(model)

def _configure_launcher(self):
self._launcher = _XLASpawnLauncher(self)
self._launcher = _XLALauncher(self)

def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
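None of this is user facing: strategy selection is unchanged, and the rename only affects which launcher class the strategy wires up in _configure_launcher. A hedged sketch with standard Trainer arguments from this release line:

from pytorch_lightning import Trainer

# Same strategy flag as before; internally the strategy now builds a
# _MultiProcessingLauncher (or an _XLALauncher for "tpu_spawn") instead of the
# old _SpawnLauncher / _XLASpawnLauncher.
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2, max_epochs=1)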
@@ -16,20 +16,20 @@

import pytest

from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher


@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=[])
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_spawn_launcher_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
_SpawnLauncher(strategy=Mock(), start_method="fork")
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp")
@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
def test_spawn_launcher_start_method(mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _SpawnLauncher(strategy=Mock(), start_method=start_method)
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
mp_mock.get_context.assert_called_with(start_method)
mp_mock.start_processes.assert_called_with(
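The substantive change in these tests is the patch target, which has to track the module rename: doubles that previously patched pytorch_lightning.strategies.launchers.spawn now patch ...launchers.multiprocessing. A context-manager form of the same setup would look like:

from unittest import mock

# Patch the renamed module, not the old ...launchers.spawn path.
with mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp") as mp_mock:
    mp_mock.get_all_start_methods.return_value = ["fork", "spawn"]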
6 changes: 3 additions & 3 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -23,7 +23,7 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf

@@ -59,7 +59,7 @@ def test_ddp_cpu():
trainer.fit(model)


class CustomSpawnLauncher(_SpawnLauncher):
class CustomMultiProcessingLauncher(_MultiProcessingLauncher):
def add_to_queue(self, trainer, queue) -> None:
queue.put("test_val")
return super().add_to_queue(trainer, queue)
@@ -71,7 +71,7 @@ def get_from_queue(self, trainer: Trainer, queue) -> None:

class TestDDPSpawnStrategy(DDPSpawnStrategy):
def _configure_launcher(self):
self._launcher = CustomSpawnLauncher(self)
self._launcher = CustomMultiProcessingLauncher(self)


@RunIf(skip_windows=True)
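The extension hooks exercised here are untouched by the rename: add_to_queue and get_from_queue still exchange extra data through the launcher's _FakeQueue, an in-memory stand-in for a multiprocessing queue. A small sketch of that helper, with behavior inferred from its use above and the empty() definition shown in the launcher diff:

from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue

queue = _FakeQueue()
queue.put("test_val")   # what CustomMultiProcessingLauncher.add_to_queue does above
print(queue.get())      # test_val
print(queue.empty())    # True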