
Commit 81f149e

Rename spawn-based launchers (#13743)
1 parent fa886f2 commit 81f149e

7 files changed: 49 additions, 48 deletions

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import (
@@ -134,7 +134,7 @@ def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend

     def _configure_launcher(self):
-        self._launcher = _SpawnLauncher(self, start_method=self._start_method)
+        self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
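For context, the rename is purely internal to the strategy: user code keeps selecting the strategy through the Trainer and never touches the launcher class directly. A minimal sketch of the unchanged user-facing entry point (the BoringModel demo module is used purely for illustration):

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# Selecting the spawn-based DDP strategy is unchanged by this commit; internally,
# DDPSpawnStrategy._configure_launcher now builds a _MultiProcessingLauncher
# instead of a _SpawnLauncher.
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_epochs=1)
trainer.fit(BoringModel())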

src/pytorch_lightning/strategies/launchers/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.strategies.launchers.base import _Launcher
-from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
-from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
+from pytorch_lightning.strategies.launchers.xla import _XLALauncher

 __all__ = [
     "_Launcher",
-    "_SpawnLauncher",
+    "_MultiProcessingLauncher",
     "_SubprocessScriptLauncher",
-    "_XLASpawnLauncher",
+    "_XLALauncher",
 ]

src/pytorch_lightning/strategies/launchers/spawn.py renamed to src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 20 additions & 20 deletions
@@ -31,8 +31,8 @@
 from pytorch_lightning.utilities.types import _PATH


-class _SpawnLauncher(_Launcher):
-    r"""Spawns processes that run a given function in parallel, and joins them all at the end.
+class _MultiProcessingLauncher(_Launcher):
+    r"""Launches processes that run a given function in parallel, and joins them all at the end.

     The main process in which this launcher is invoked creates N so-called worker processes (using
     :func:`torch.multiprocessing.start_processes`) that run the given function.
@@ -71,20 +71,20 @@ def is_interactive_compatible(self) -> bool:
         return self._start_method == "fork"

     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
-        """Spawns processes that run the given function in parallel.
+        """Launches processes that run the given function in parallel.

         The function is allowed to have a return value. However, when all processes join, only the return value
         of worker process 0 gets returned from this `launch` method in the main process.

         Arguments:
-            function: The entry point for all spawned processes.
+            function: The entry point for all launched processes.
             *args: Optional positional arguments to be passed to the given function.
             trainer: Optional reference to the :class:`~pytorch_lightning.trainer.trainer.Trainer` for which
                 a selected set of attributes get restored in the main process after processes join.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
         # The default cluster environment in Lightning chooses a random free port number
-        # This needs to be done in the main process here before spawning to ensure each rank will connect
+        # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
         context = mp.get_context(self._start_method)
@@ -95,12 +95,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             nprocs=self._strategy.num_processes,
             start_method=self._start_method,
         )
-        spawn_output = return_queue.get()
+        worker_output = return_queue.get()
         if trainer is None:
-            return spawn_output
+            return worker_output

-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
+        self._recover_results_in_main_process(worker_output, trainer)
+        return worker_output.trainer_results

     def _wrapping_function(
         self,
@@ -120,25 +120,25 @@
         if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))

-    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
+    def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer: "pl.Trainer") -> None:
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
-            trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
+            trainer.checkpoint_callback.best_model_path = str(worker_output.best_model_path)

         # TODO: pass also best score
         # load last weights
-        if spawn_output.weights_path is not None:
-            ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path)
+        if worker_output.weights_path is not None:
+            ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
             trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
-            self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path)
+            self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

-        trainer.state = spawn_output.trainer_state
+        trainer.state = worker_output.trainer_state

         # get the `callback_metrics` and set it to the trainer
-        self.get_from_queue(trainer, spawn_output.extra)
+        self.get_from_queue(trainer, worker_output.extra)

-    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
-        rank_zero_debug("Finalizing the DDP spawn environment.")
+    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
+        rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = (
             checkpoint_callback.best_model_path
@@ -162,7 +162,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         extra = _FakeQueue()
         self.add_to_queue(trainer, extra)

-        return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
+        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

     def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
@@ -203,7 +203,7 @@ def empty(self) -> bool:
         return len(self) == 0


-class _SpawnOutput(NamedTuple):
+class _WorkerOutput(NamedTuple):
     best_model_path: Optional[_PATH]
     weights_path: Optional[_PATH]
     trainer_state: TrainerState
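To make the renamed launcher's mechanism concrete: it hands a wrapping function to torch.multiprocessing.start_processes and reads the rank-0 result back through a queue created from the chosen start-method context. A standalone sketch of that pattern with simplified, illustrative names (not the Lightning internals):

import torch.multiprocessing as mp

def _worker(process_idx: int, return_queue) -> None:
    # Each worker runs the target function; only rank 0 reports its result back,
    # mirroring what the launcher's wrapping function does.
    result = f"result from rank {process_idx}"
    if process_idx == 0:
        return_queue.put(result)

if __name__ == "__main__":
    start_method = "spawn"  # "fork"/"forkserver" are alternatives where available
    context = mp.get_context(start_method)
    return_queue = context.SimpleQueue()
    # Blocks until all workers join, like the launcher's `launch` method.
    mp.start_processes(_worker, args=(return_queue,), nprocs=2, start_method=start_method)
    print(return_queue.get())  # the rank-0 result, analogous to `worker_output`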

src/pytorch_lightning/strategies/launchers/xla_spawn.py renamed to src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 13 additions & 12 deletions
@@ -19,7 +19,7 @@
 import torch.multiprocessing as mp

 import pytorch_lightning as pl
-from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput
+from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -34,8 +34,9 @@
     from pytorch_lightning.strategies import Strategy


-class _XLASpawnLauncher(_SpawnLauncher):
-    r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end.
+class _XLALauncher(_MultiProcessingLauncher):
+    r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
+    end.

     The main process in which this launcher is invoked creates N so-called worker processes (using the
     `torch_xla` :func:`xmp.spawn`) that run the given function.
@@ -57,13 +58,13 @@ def is_interactive_compatible(self) -> bool:
         return True

     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
-        """Spawns processes that run the given function in parallel.
+        """Launches processes that run the given function in parallel.

         The function is allowed to have a return value. However, when all processes join, only the return value
         of worker process 0 gets returned from this `launch` method in the main process.

         Arguments:
-            function: The entry point for all spawned processes.
+            function: The entry point for all launched processes.
             *args: Optional positional arguments to be passed to the given function.
             trainer: Optional reference to the :class:`~pytorch_lightning.trainer.trainer.Trainer` for which
                 a selected set of attributes get restored in the main process after processes join.
@@ -77,12 +78,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
             nprocs=len(self._strategy.parallel_devices),
             start_method=self._start_method,
         )
-        spawn_output = return_queue.get()
+        worker_output = return_queue.get()
         if trainer is None:
-            return spawn_output
+            return worker_output

-        self._recover_results_in_main_process(spawn_output, trainer)
-        return spawn_output.trainer_results
+        self._recover_results_in_main_process(worker_output, trainer)
+        return worker_output.trainer_results

     def _wrapping_function(
         self,
@@ -110,8 +111,8 @@
         if self._strategy.local_rank == 0:
             time.sleep(2)

-    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
-        rank_zero_debug("Finalizing the TPU spawn environment.")
+    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
+        rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = (
             checkpoint_callback.best_model_path
@@ -136,4 +137,4 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         extra = _FakeQueue()
         self.add_to_queue(trainer, extra)

-        return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
+        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
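The XLA variant follows the same collect-from-rank-zero pattern but delegates process creation to torch_xla, as the docstring above notes. A rough, hypothetical sketch of the underlying xmp.spawn call, assuming a TPU host with torch_xla installed (names simplified, not the Lightning internals):

import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(process_idx: int) -> None:
    # Each XLA worker receives its index; real code would build the model and train here.
    print(f"XLA worker {process_idx} started")

if __name__ == "__main__":
    xmp.spawn(_worker, args=(), nprocs=8)  # nprocs=8 matches a typical single TPU host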

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
-from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
+from pytorch_lightning.strategies.launchers.xla import _XLALauncher
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
@@ -120,7 +120,7 @@ def connect(self, model: "pl.LightningModule") -> None:
         return super().connect(model)

     def _configure_launcher(self):
-        self._launcher = _XLASpawnLauncher(self)
+        self._launcher = _XLALauncher(self)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)

tests/tests_pytorch/strategies/launchers/test_spawn.py renamed to tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 5 additions & 5 deletions
@@ -16,20 +16,20 @@

 import pytest

-from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher


-@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=[])
+@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
 def test_spawn_launcher_forking_on_unsupported_platform(_):
     with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
-        _SpawnLauncher(strategy=Mock(), start_method="fork")
+        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


 @pytest.mark.parametrize("start_method", ["spawn", "fork"])
-@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp")
+@mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp")
 def test_spawn_launcher_start_method(mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
-    launcher = _SpawnLauncher(strategy=Mock(), start_method=start_method)
+    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
     mp_mock.get_context.assert_called_with(start_method)
     mp_mock.start_processes.assert_called_with(
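The tests above only needed their class name and mock.patch targets updated: mock.patch must point at the module where the patched name is looked up, which is now launchers.multiprocessing. A small, illustrative check of the renamed class (assuming pytorch_lightning is installed; "spawn" is available on every platform, so nothing needs to be mocked):

from unittest.mock import Mock

from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher

launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
# Only the "fork" start method counts as interactive-compatible for this launcher.
assert launcher.is_interactive_compatible is False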

tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
 from pytorch_lightning.strategies import DDPSpawnStrategy
-from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher
+from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf

@@ -59,7 +59,7 @@ def test_ddp_cpu():
     trainer.fit(model)


-class CustomSpawnLauncher(_SpawnLauncher):
+class CustomMultiProcessingLauncher(_MultiProcessingLauncher):
     def add_to_queue(self, trainer, queue) -> None:
         queue.put("test_val")
         return super().add_to_queue(trainer, queue)
@@ -71,7 +71,7 @@ def get_from_queue(self, trainer: Trainer, queue) -> None:

 class TestDDPSpawnStrategy(DDPSpawnStrategy):
     def _configure_launcher(self):
-        self._launcher = CustomSpawnLauncher(self)
+        self._launcher = CustomMultiProcessingLauncher(self)


 @RunIf(skip_windows=True)
