diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ff6693823b586..098956a703a8a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1962,7 +1962,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: + def add_to_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1970,11 +1970,10 @@ def add_to_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: queue: the instance of the queue to append the data. .. deprecated:: v1.5 - This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.add_to_queue` - and will be removed in v1.7. + This method was deprecated in v1.5 and will be removed in v1.7. """ - def get_from_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: + def get_from_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -1982,8 +1981,7 @@ def get_from_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: queue: the instance of the queue from where to get the data. .. deprecated:: v1.5 - This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.get_from_queue` - and will be removed in v1.7. + This method was deprecated in v1.5 and will be removed in v1.7. """ @contextmanager diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 95beb85b1cdad..218decdddd969 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -27,7 +27,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import PLUGIN_INPUT -from pytorch_lightning.strategies import DDPSpawnStrategy, DeepSpeedStrategy, Strategy, TPUSpawnStrategy +from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device @@ -399,17 +399,16 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) return seed_everything(seed=seed, workers=workers) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - self._strategy.setup_environment() - # apply sharded context to prevent OOM - run_method = partial(self._run_with_sharded_context, run_method) + run_method = partial(self._run_with_strategy_setup, run_method) - if isinstance(self._strategy, DDPSpawnStrategy): - return self._strategy.spawn(run_method, *args, **kwargs) + if self._strategy.launcher is not None: + return self._strategy.launcher.launch(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) - def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + self._strategy.setup_environment() with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): return run_method(*args, **kwargs) 
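For illustration, the dispatch that `_run_impl` (and, further below, `Trainer._call_and_handle_interrupt`) now performs boils down to the following minimal standalone sketch. `DirectLauncher` and `DummyStrategy` are hypothetical stand-ins for this note only, not Lightning classes:

.. code-block:: python

    from typing import Any, Callable, Optional


    class DirectLauncher:
        """Hypothetical launcher that simply runs the function in the current process."""

        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            return function(*args, **kwargs)


    class DummyStrategy:
        """Hypothetical strategy exposing a read-only ``launcher`` property."""

        def __init__(self, launcher: Optional[DirectLauncher] = None) -> None:
            self._launcher = launcher

        @property
        def launcher(self) -> Optional[DirectLauncher]:
            return self._launcher


    def run_impl(strategy: DummyStrategy, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # same branching as `_run_impl` above: delegate to the launcher when the
        # strategy provides one, otherwise call the method directly in-process
        if strategy.launcher is not None:
            return strategy.launcher.launch(run_method, *args, **kwargs)
        return run_method(*args, **kwargs)


    assert run_impl(DummyStrategy(), lambda x: x + 1, 41) == 42
    assert run_impl(DummyStrategy(DirectLauncher()), lambda x: x + 1, 41) == 42

Keeping the branch behind an optional ``launcher`` property means single-process strategies can simply leave ``_launcher`` unset and run in-process, while multiprocess strategies attach a launcher in ``_configure_launcher``.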
diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py index 205d4acb8c115..f06edfa53ec7a 100644 --- a/pytorch_lightning/strategies/__init__.py +++ b/pytorch_lightning/strategies/__init__.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pathlib import Path from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401 diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index feff575719ad4..6444e489c9b2f 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -15,16 +15,11 @@ import os import shutil import signal -import subprocess -import sys import tempfile import time from pathlib import Path -from time import sleep from typing import Any, Dict, List, Optional, Union -import __main__ -import numpy as np import torch import torch.distributed from torch.nn import Module @@ -37,11 +32,11 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( _FAIRSCALE_AVAILABLE, - _HYDRA_AVAILABLE, _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, @@ -58,9 +53,6 @@ if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS -if _HYDRA_AVAILABLE: - from hydra.core.hydra_config import HydraConfig - from hydra.utils import get_original_cwd, to_absolute_path if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -69,11 +61,7 @@ class DDPStrategy(ParallelStrategy): - """Plugin for multi-process single-device training on one or multiple nodes. - - The main process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of - devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. 
- """ + """Strategy for multi-process single-device training on one or multiple nodes.""" distributed_backend = _StrategyType.DDP @@ -98,7 +86,6 @@ def __init__( precision_plugin=precision_plugin, ) log.detail(f"{self.__class__.__name__}: initializing DDP plugin") - self.interactive_ddp_procs = [] self._num_nodes = 1 self.sync_batchnorm = False self._ddp_kwargs = kwargs @@ -108,7 +95,7 @@ def __init__( self._model_averaging_period = model_averaging_period self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None - self._rank_0_has_called_call_children_scripts: bool = False + self._rank_0_will_call_children_scripts: bool = False self.set_world_ranks() @property @@ -142,18 +129,19 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def setup_environment(self) -> None: - # start the other scripts + def _configure_launcher(self) -> None: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) if not self.cluster_environment.creates_processes_externally: - self._call_children_scripts() + self._rank_0_will_call_children_scripts = True + def setup_environment(self) -> None: self.setup_distributed() super().setup_environment() def setup(self, trainer: "pl.Trainer") -> None: super().setup(trainer) # share ddp pids to all processes - self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts) + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() @@ -174,68 +162,6 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) - def _call_children_scripts(self): - # bookkeeping of spawned processes - self._check_can_spawn_children() - - # DDP Environment variables - os.environ["MASTER_ADDR"] = self.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - - # allow the user to pass the node rank - os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) - os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) - - # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` - # See https://docs.python.org/3/reference/import.html#main-spec - if __main__.__spec__ is None: # pragma: no-cover - # Script called as `python a/b/c.py` - # when user is using hydra find the absolute path - path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path - - # pull out the commands used to run the script and resolve the abs file path - command = sys.argv - try: - full_path = path_lib(command[0]) - except Exception: - full_path = os.path.abspath(command[0]) - - command[0] = full_path - # use the same python interpreter and actually running - command = [sys.executable] + command - else: # Script called as `python -m a.b.c` - command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] - - os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" - - self.interactive_ddp_procs = [] - - for local_rank in range(1, self.num_processes): - env_copy = os.environ.copy() - env_copy["LOCAL_RANK"] = f"{local_rank}" - - # remove env var if global seed not set - if os.environ.get("PL_GLOBAL_SEED") is None and 
"PL_GLOBAL_SEED" in env_copy: - del env_copy["PL_GLOBAL_SEED"] - - # start process - # if hydra is available and initialized, make sure to set the cwd correctly - cwd: Optional[str] = None - if _HYDRA_AVAILABLE: - if HydraConfig.initialized(): - cwd = get_original_cwd() - os_cwd = f'"{os.getcwd()}"' - command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] - proc = subprocess.Popen(command, env=env_copy, cwd=cwd) - self.interactive_ddp_procs.append(proc) - - # starting all processes at once can cause issues - # with dataloaders delay between 1-10 seconds - delay = np.random.uniform(1, 5, 1)[0] - sleep(delay) - - self._rank_0_has_called_call_children_scripts = True - def setup_distributed(self): log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() @@ -251,14 +177,6 @@ def setup_distributed(self): # where to store ip_table init_dist_connection(self.cluster_environment, self.torch_distributed_backend) - def _check_can_spawn_children(self): - if self.local_rank != 0: - raise RuntimeError( - "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." - " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," - " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." - ) - def set_world_ranks(self) -> None: if self.cluster_environment is None: return @@ -436,7 +354,7 @@ def _should_run_deadlock_detection(self) -> bool: By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler / parent process to perform the process termination, external to Lightning. """ - return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_has_called_call_children_scripts + return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts def _share_information_to_prevent_deadlock(self) -> None: self._share_pids() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 03407e1c14232..9b58137d2719d 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -13,14 +13,10 @@ # limitations under the License. 
import logging import os -from collections import UserList -from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, Optional, Union -import numpy as np import torch import torch.distributed -import torch.multiprocessing as mp from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel @@ -30,18 +26,17 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.trainer.states import TrainerFn, TrainerState +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.enums import _StrategyType -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -114,6 +109,9 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self): return True + def _configure_launcher(self): + self._launcher = _SpawnLauncher(self) + def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) super().setup(trainer) @@ -144,33 +142,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - """Spawn processes that run the given function. - - Args: - function: The function to spawn processes from. - *args: Optional positional arguments that will be passed to the function in addition to the process index. - These arguments must be pickleable. - **kwargs: Optional named arguments that will be passed to the function in addition to the process index. - These arguments must be pickleable. - - Return: - The output of the function of process 0. 
- """ - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - context = mp.get_context("spawn") - return_queue = context.SimpleQueue() - mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes) - return return_queue.get() - - def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue - ) -> None: - self._worker_setup(process_idx) - result = function(*args, **kwargs) - if self.local_rank == 0: - return_queue.put(move_data_to_device(result, "cpu")) - def _worker_setup(self, process_idx: int): reset_seed() self.set_world_ranks(process_idx) @@ -216,55 +187,6 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: - rank_zero_debug("Finalizing the DDP spawn environment.") - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - if self.global_rank != 0: - return - - # save the last weights - weights_path = None - if trainer.state.fn == TrainerFn.FITTING: - weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self.checkpoint_io.save_checkpoint(state_dict, weights_path) - - # adds the `callback_metrics` to the queue - extra = _FakeQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - self.add_to_queue(trainer, extra) - - return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) - - def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: - # transfer back the best path to the trainer - if trainer.checkpoint_callback: - trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path - - # TODO: pass also best score - # load last weights - if spawn_output.weights_path is not None: - ckpt = self.checkpoint_io.load_checkpoint( - spawn_output.weights_path, map_location=(lambda storage, loc: storage) - ) - self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) - - trainer.state = spawn_output.trainer_state - - # get the `callback_metrics` and set it to the trainer - if is_overridden("get_from_queue", self.lightning_module): - # only in case the user does not override it. - # TODO: Remove the if in v1.7 - self.lightning_module.get_from_queue(spawn_output.extra) - self.get_from_queue(trainer, spawn_output.extra) - def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return @@ -334,31 +256,6 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: - """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory - sharing, we cast the data to numpy. - - Args: - trainer: reference to the Trainer. - queue: the instance of the queue to append the data. 
- """ - callback_metrics: dict = apply_to_collection( - trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() - ) # send as numpy to avoid issues with memory sharing - queue.put(callback_metrics) - - def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: - """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, - we cast back the data to ``torch.Tensor``. - - Args: - trainer: reference to the Trainer. - queue: the instance of the queue from where to get the data. - """ - # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.get() - trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) - @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register( @@ -381,24 +278,3 @@ def teardown(self) -> None: self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() - - -class _FakeQueue(UserList): - """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" - - def get(self) -> Any: - return self.pop(0) - - def put(self, item: Any) -> None: - self.append(item) - - def empty(self) -> bool: - return len(self) == 0 - - -class _SpawnOutput(NamedTuple): - best_model_path: Optional[_PATH] - weights_path: Optional[_PATH] - trainer_state: TrainerState - trainer_results: Any - extra: _FakeQueue diff --git a/pytorch_lightning/strategies/launchers/__init__.py b/pytorch_lightning/strategies/launchers/__init__.py new file mode 100644 index 0000000000000..340a2c0160b0e --- /dev/null +++ b/pytorch_lightning/strategies/launchers/__init__.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.strategies.launchers.base import _Launcher +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher + +__all__ = [ + "_Launcher", + "_SpawnLauncher", + "_SubprocessScriptLauncher", + "_XLASpawnLauncher", +] diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py new file mode 100644 index 0000000000000..293c0a2ce4508 --- /dev/null +++ b/pytorch_lightning/strategies/launchers/base.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Any, Callable
+
+
+class _Launcher(ABC):
+    r"""
+    Abstract base class for all Launchers.
+
+    Launchers are responsible for the creation and instrumentation of new processes so that the
+    :class:`~pytorch_lightning.strategies.strategy.Strategy` can set up communication among all of them.
+
+    Subclass this class and override any of the relevant methods to provide a custom implementation depending on
+    cluster environment, hardware, strategy, etc.
+    """
+
+    @abstractmethod
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Launches the processes."""
diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
new file mode 100644
index 0000000000000..d1349fd39cd97
--- /dev/null
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -0,0 +1,191 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from collections import UserList
+from multiprocessing.queues import SimpleQueue
+from typing import Any, Callable, NamedTuple, Optional
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+
+import pytorch_lightning as pl
+from pytorch_lightning.strategies.launchers.base import _Launcher
+from pytorch_lightning.strategies.strategy import Strategy
+from pytorch_lightning.trainer.states import TrainerFn, TrainerState
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug
+from pytorch_lightning.utilities.types import _PATH
+
+
+class _SpawnLauncher(_Launcher):
+    r"""Spawns processes that run a given function in parallel, and joins them all at the end.
+
+    The main process in which this launcher is invoked creates N so-called worker processes (using
+    :func:`torch.multiprocessing.spawn`) that run the given function.
+    Worker processes have a rank that ranges from 0 to N - 1.
+
+    Note:
+        - This launcher requires all objects to be pickleable.
+        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+
+    Args:
+        strategy: A reference to the strategy that is used together with this launcher.
+    """
+
+    def __init__(self, strategy: Strategy) -> None:
+        self._strategy = strategy
+
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Spawns processes that run the given function in parallel.
+
+        The function is allowed to have a return value. However, when all processes join, only the return value
+        of worker process 0 gets returned from this `launch` method in the main process.
+
+        Arguments:
+            function: The entry point for all spawned processes.
+            *args: Optional positional arguments to be passed to the given function.
+ **kwargs: Optional keyword arguments to be passed to the given function. + If a keyword argument named `trainer` is present and is an instance of + :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get + restored in the main process after processes join. The `trainer` keyword argument will NOT be passed + into the function. + """ + trainer = kwargs.pop("trainer", None) + os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) + context = mp.get_context("spawn") + return_queue = context.SimpleQueue() + mp.spawn( + self._wrapping_function, + args=(trainer, function, args, kwargs, return_queue), + nprocs=self._strategy.num_processes, + ) + spawn_output = return_queue.get() + if trainer is None: + return spawn_output + + self._recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results + + def _wrapping_function( + self, + process_idx: int, + trainer: Optional["pl.Trainer"], + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, + ) -> None: + self._strategy._worker_setup(process_idx) + results = function(*args, **kwargs) + + if trainer is not None: + results = self._collect_rank_zero_results(trainer, results) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: + # transfer back the best path to the trainer + if trainer.checkpoint_callback: + trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path) + + # TODO: pass also best score + # load last weights + if spawn_output.weights_path is not None: + ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path) + trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type] + self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) + + trainer.state = spawn_output.trainer_state + + # get the `callback_metrics` and set it to the trainer + if is_overridden("get_from_queue", trainer.lightning_module): + # only in case the user does not override it. + # TODO: Remove the if in v1.7 + trainer.lightning_module.get_from_queue(spawn_output.extra) + self.get_from_queue(trainer, spawn_output.extra) + + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + rank_zero_debug("Finalizing the DDP spawn environment.") + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = trainer.lightning_module.state_dict() + + if self._strategy.global_rank != 0: + return None + + # save the last weights + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + + # adds the `callback_metrics` to the queue + extra = _FakeQueue() + if is_overridden("add_to_queue", trainer.lightning_module): + # TODO: Remove the if in v1.7 + trainer.lightning_module.add_to_queue(extra) + self.add_to_queue(trainer, extra) + + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) + + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: + """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. 
To avoid issues with memory
+        sharing, we cast the data to numpy.
+
+        Args:
+            trainer: reference to the Trainer.
+            queue: the instance of the queue to append the data.
+        """
+        callback_metrics: dict = apply_to_collection(
+            trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
+        )  # send as numpy to avoid issues with memory sharing
+        queue.put(callback_metrics)
+
+    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
+        """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
+        we cast back the data to ``torch.Tensor``.
+
+        Args:
+            trainer: reference to the Trainer.
+            queue: the instance of the queue from where to get the data.
+        """
+        # NOTE: `add_to_queue` needs to be called before
+        callback_metrics: dict = queue.get()
+        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
+
+
+class _FakeQueue(UserList):
+    """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
+
+    def get(self) -> Any:
+        return self.pop(0)
+
+    def put(self, item: Any) -> None:
+        self.append(item)
+
+    def empty(self) -> bool:
+        return len(self) == 0
+
+
+class _SpawnOutput(NamedTuple):
+    best_model_path: Optional[_PATH]
+    weights_path: Optional[_PATH]
+    trainer_state: TrainerState
+    trainer_results: Any
+    extra: _FakeQueue
diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py
new file mode 100644
index 0000000000000..e4b41500412d3
--- /dev/null
+++ b/pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -0,0 +1,158 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import subprocess
+import sys
+from subprocess import Popen
+from time import sleep
+from typing import Any, Callable, List, Optional
+
+import __main__
+import numpy as np
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.strategies.launchers.base import _Launcher
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE
+
+if _HYDRA_AVAILABLE:
+    from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
+
+
+class _SubprocessScriptLauncher(_Launcher):
+    r"""
+    A process launcher that invokes the current script as many times as desired on a single node.
+
+    This launcher needs to be invoked on each node.
+    In its default behavior, the main process in each node then spawns N-1 child processes via :func:`subprocess.Popen`,
+    where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.run`
+    launches processes.
+
+    For example, if the script gets invoked with the command
+
+    .. code-block:: bash
+
+        python train.py --devices 4
+
+    the launcher will create three additional subprocesses that get called like so:
+
+    .. code-block:: bash
+
+        LOCAL_RANK=1 python train.py --devices 4
+        LOCAL_RANK=2 python train.py --devices 4
+        LOCAL_RANK=3 python train.py --devices 4
+
+    It is implied that the main process which launched the others has ``LOCAL_RANK=0``.
+    Besides the local rank, the following environment variables also get set; unlike the local rank, these are
+    determined by the cluster environment:
+
+    1. `MASTER_ADDR`: The IP address of the main node.
+    2. `MASTER_PORT`: The port number of the main node through which all processes communicate.
+    3. `NODE_RANK`: The index of the node the current process is running on. Ranges from 0 to ``num_nodes - 1``.
+    4. `WORLD_SIZE`: The total number of processes across all nodes, i.e., ``num_processes * num_nodes``.
+
+    Arguments:
+        cluster_environment: A cluster environment that provides access to world size, node rank, etc.
+        num_processes: The number of processes to launch on the current node.
+        num_nodes: The total number of nodes that participate in this process group.
+    """
+
+    def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
+        super().__init__()
+        self.cluster_environment = cluster_environment
+        self.num_processes = num_processes
+        self.num_nodes = num_nodes
+        self.interactive_ddp_procs: List[Popen] = []
+
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Creates new processes, then calls the given function.
+
+        Arguments:
+            function: A callback function to execute after all processes have been created.
+                It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
+            *args: Optional positional arguments to be passed to the given function.
+            **kwargs: Optional keyword arguments to be passed to the given function.
+ """ + kwargs.pop("trainer", None) + if not self.cluster_environment.creates_processes_externally: + self._call_children_scripts() + return function(*args, **kwargs) + + def _call_children_scripts(self) -> None: + # bookkeeping of spawned processes + self._check_can_spawn_children() + + # DDP Environment variables + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + + # allow the user to pass the node rank + os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) + os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) + + # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` + # See https://docs.python.org/3/reference/import.html#main-spec + if __main__.__spec__ is None: # pragma: no-cover + # Script called as `python a/b/c.py` + # when user is using hydra find the absolute path + path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path + + # pull out the commands used to run the script and resolve the abs file path + command = sys.argv + try: + full_path = path_lib(command[0]) + except Exception: + full_path = os.path.abspath(command[0]) + + command[0] = full_path + # use the same python interpreter and actually running + command = [sys.executable] + command + else: # Script called as `python -m a.b.c` + command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] + + os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" + + self.interactive_ddp_procs = [] + + for local_rank in range(1, self.num_processes): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + + # remove env var if global seed not set + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + del env_copy["PL_GLOBAL_SEED"] + + # start process + # if hydra is available and initialized, make sure to set the cwd correctly + cwd: Optional[str] = None + if _HYDRA_AVAILABLE: + if HydraConfig.initialized(): + cwd = get_original_cwd() + os_cwd = f'"{os.getcwd()}"' + command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] + proc = subprocess.Popen(command, env=env_copy, cwd=cwd) + self.interactive_ddp_procs.append(proc) + + # starting all processes at once can cause issues + # with dataloaders delay between 1-10 seconds + delay = np.random.uniform(1, 5, 1)[0] + sleep(delay) + + def _check_can_spawn_children(self) -> None: + if self.cluster_environment.local_rank() != 0: + raise RuntimeError( + "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." + " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," + " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." + ) diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py new file mode 100644 index 0000000000000..8bac7888c568b --- /dev/null +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -0,0 +1,128 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from multiprocessing.queues import SimpleQueue
+from typing import Any, Callable, Optional
+
+import torch.multiprocessing as mp
+
+import pytorch_lightning as pl
+from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput
+from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug
+
+if _TPU_AVAILABLE:
+    import torch_xla.distributed.xla_multiprocessing as xmp
+else:
+    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
+
+
+class _XLASpawnLauncher(_SpawnLauncher):
+    r"""Spawns processes that run a given function in parallel on XLA-supported hardware, and joins them all at the end.
+
+    The main process in which this launcher is invoked creates N so-called worker processes (using
+    ``torch_xla``'s :func:`xmp.spawn`) that run the given function.
+    Worker processes have a rank that ranges from 0 to N - 1.
+
+    Note:
+        - This launcher requires all objects to be pickleable.
+        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+    """
+
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Spawns processes that run the given function in parallel.
+
+        The function is allowed to have a return value. However, when all processes join, only the return value
+        of worker process 0 gets returned from this `launch` method in the main process.
+
+        Arguments:
+            function: The entry point for all spawned processes.
+            *args: Optional positional arguments to be passed to the given function.
+            **kwargs: Optional keyword arguments to be passed to the given function.
+                If a keyword argument named `trainer` is present and is an instance of
+                :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get
+                restored in the main process after processes join. The `trainer` keyword argument will NOT be passed
+                into the function.
+ """ + trainer = kwargs.pop("trainer", None) + context = mp.get_context(self._strategy.start_method or "fork") + return_queue = context.SimpleQueue() + xmp.spawn( + self._wrapping_function, + args=(trainer, function, args, kwargs, return_queue), + **self._strategy.get_mp_spawn_kwargs() + ) + spawn_output = return_queue.get() + if trainer is None: + return spawn_output + + self._recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results + + def _wrapping_function( + self, + process_idx: int, + trainer: Optional["pl.Trainer"], + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, + ) -> None: + self._strategy._worker_setup(process_idx) + results = function(*args, **kwargs) + + if trainer is not None: + results = self._collect_rank_zero_results(trainer, results) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 + self._strategy.barrier("end-process") + + # Ensure that the rank 0 process is the one exiting last + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + if self._strategy.local_rank == 0: + time.sleep(2) + + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + rank_zero_debug("Finalizing the TPU spawn environment.") + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = trainer.lightning_module.state_dict() + + # save the last weights + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + + # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training + if self._strategy.local_rank != 0: + return None + + # adds the `callback_metrics` to the queue + extra = _FakeQueue() + if is_overridden("add_to_queue", trainer.lightning_module): + # TODO: Remove the if in v1.7 + trainer.lightning_module.add_to_queue(extra) + self.add_to_queue(trainer, extra) + + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 629911911b780..37b9b435b7413 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -27,6 +27,7 @@ from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -48,6 +49,7 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: self.accelerator = accelerator + self._launcher: Optional[_Launcher] = None self._model: Optional[Module] = None self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin @@ -61,6 +63,10 @@ def __init__( f" Move your implementation to `{self.__class__.__name__}.teardown()` instead." 
) + @property + def launcher(self) -> Optional[_Launcher]: + return self._launcher + @property def accelerator(self) -> "pl.accelerators.accelerator.Accelerator": return self._accelerator @@ -100,6 +106,9 @@ def connect(self, model: Module) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" self.model = model + def _configure_launcher(self): + """Attach the launcher based on Strategy.""" + def setup_environment(self) -> None: """Setup any processes or distributed connections. diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index a6e82441da296..f3d855b43f8a6 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -13,12 +13,9 @@ # limitations under the License. import io import os -import time -from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -import torch.multiprocessing as mp from torch.nn import Module from torch.utils.data import DataLoader @@ -26,16 +23,15 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnStrategy +from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy +from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters -from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT @@ -120,6 +116,9 @@ def connect(self, model: "pl.LightningModule") -> None: self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) + def _configure_launcher(self): + self._launcher = _XLASpawnLauncher(self) + def setup(self, trainer: "pl.Trainer") -> None: self.start_method = "fork" self.accelerator.setup(trainer) @@ -175,33 +174,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: - rank_zero_debug("Finalizing the TPU spawn environment.") - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - # save the last weights - weights_path = None - if trainer.state.fn == TrainerFn.FITTING: - weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - 
self.checkpoint_io.save_checkpoint(state_dict, weights_path) - - # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if self.local_rank != 0: - return - - # adds the `callback_metrics` to the queue - extra = _FakeQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - self.add_to_queue(trainer, extra) - - return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) - def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj @@ -244,28 +216,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - context = mp.get_context(self.start_method or "fork") - return_queue = context.SimpleQueue() - xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) - return return_queue.get() - - def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue - ) -> None: - self._worker_setup(process_idx) - result = function(*args, **kwargs) - if self.local_rank == 0: - return_queue.put(move_data_to_device(result, "cpu")) - - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.barrier("end-process") - - # Ensure that the rank 0 process is the one exiting last - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if self.local_rank == 0: - time.sleep(2) - def _worker_setup(self, process_idx: int): reset_seed() self.tpu_local_core_rank = xm.get_local_ordinal() diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 5734914c3bd13..1abad1464f8c0 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -261,13 +261,11 @@ def _check_add_get_queue(model: "pl.LightningModule") -> None: """ if is_overridden("add_to_queue", model): rank_zero_deprecation( - "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in " - "favor of `DDPSpawnStrategy.add_to_queue`" + "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7." ) if is_overridden("get_from_queue", model): rank_zero_deprecation( - "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in " - "favor of `DDPSpawnStrategy.get_from_queue`" + "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7." 
) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index fd65975618f02..8f770feb790c3 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -172,6 +172,7 @@ def __init__( self._set_devices_if_none() self.strategy = self.final_strategy() + self.strategy._configure_launcher() self.accelerator = self.strategy.accelerator self._check_plugin_compatibility() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b6a0d7fa452e0..658c7ed4d09ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -58,7 +58,7 @@ XLAProfiler, ) from pytorch_lightning.strategies import ParallelStrategy, Strategy -from pytorch_lightning.strategies.ddp_spawn import _SpawnOutput, DDPSpawnStrategy +from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector @@ -669,10 +669,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - if isinstance(self.strategy, DDPSpawnStrategy): - spawn_output: _SpawnOutput = self.strategy.spawn(trainer_fn, *args, **kwargs) - self.strategy._recover_results_in_main_process(spawn_output, self) - return spawn_output.trainer_results + if self.strategy.launcher is not None: + return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs) else: return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 @@ -1203,9 +1201,6 @@ def _run( self.state.status = TrainerStatus.FINISHED self.state.stage = None - if isinstance(self.strategy, DDPSpawnStrategy): - results = self.strategy._collect_rank_zero_results(self, results) - return results def _log_hyperparams(self) -> None: diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 18bb04bd0ae17..ecdcc743ea822 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -40,7 +40,12 @@ def test_evaluate(tmpdir, trainer_kwargs): dm = ClassifDataModule() model = CustomClassificationModelDP() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=2, limit_train_batches=10, limit_val_batches=10, **trainer_kwargs + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + **trainer_kwargs ) trainer.fit(model, datamodule=dm) diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 2c0dff23aafd3..ef33a3ba56a43 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -20,6 +20,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.strategies import DDPSpawnStrategy +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -82,16 +83,21 @@ def test_ddp_spawn_extra_parameters(tmpdir): assert model.test_val == "test_val" -class 
TestDDPSpawnStrategy(DDPSpawnStrategy): +class CustomSpawnLauncher(_SpawnLauncher): def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) def get_from_queue(self, trainer: Trainer, queue) -> None: - self.new_test_val = queue.get() + trainer.strategy.new_test_val = queue.get() return super().get_from_queue(trainer, queue) +class TestDDPSpawnStrategy(DDPSpawnStrategy): + def _configure_launcher(self): + self._launcher = CustomSpawnLauncher(self) + + @RunIf(skip_windows=True, skip_49370=True) def test_ddp_spawn_add_get_queue(tmpdir): """Tests add_to_queue/get_from_queue with DDPSpawnStrategy.""" @@ -148,13 +154,13 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPSpawnStrategy() - strategy.model = model - trainer = Trainer(default_root_dir=tmpdir) + trainer = Trainer(default_root_dir=tmpdir, strategy=strategy) + trainer.strategy.connect(model) trainer.state.fn = trainer_fn # pretend we are in a particular trainer state temp_file = Path(tmpdir, ".temp.ckpt") assert not temp_file.exists() - spawn_output = strategy._collect_rank_zero_results(trainer, {}) + spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {}) model.state_dict.assert_called_once() if trainer_fn == TrainerFn.FITTING: @@ -165,6 +171,6 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): assert not temp_file.exists() # <-- here would normally be the multiprocessing boundary - strategy._recover_results_in_main_process(spawn_output, trainer) + strategy._launcher._recover_results_in_main_process(spawn_output, trainer) assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) assert not temp_file.exists()
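For reference, the callback-metrics hand-off that `add_to_queue`/`get_from_queue` implement — cast tensors to numpy on the worker, restore tensors in the main process — can be reproduced standalone. A minimal sketch of that convention; ``FakeQueue`` below is an illustrative stand-in mirroring the private ``_FakeQueue`` above, not an import from Lightning:

.. code-block:: python

    from collections import UserList
    from typing import Any

    import torch


    class FakeQueue(UserList):
        """List-backed queue mirroring the ``get``/``put`` interface of ``_FakeQueue``."""

        def get(self) -> Any:
            return self.pop(0)

        def put(self, item: Any) -> None:
            self.append(item)


    queue = FakeQueue()
    metrics = {"val_loss": torch.tensor(0.25)}

    # worker side (as in `add_to_queue`): cast tensors to numpy before they
    # cross the process boundary, to avoid memory-sharing issues
    queue.put({k: v.cpu().numpy() for k, v in metrics.items()})

    # main-process side (as in `get_from_queue`): cast back to `torch.Tensor`
    restored = {k: torch.tensor(v) for k, v in queue.get().items()}
    assert torch.equal(restored["val_loss"], metrics["val_loss"])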