From 9687e73e346525d58e834b4dce9ad1e5111812e3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 15:15:13 +0530 Subject: [PATCH 01/36] add executors for strategies --- pytorch_lightning/strategies/ddp.py | 101 ++------------- pytorch_lightning/strategies/ddp_spawn.py | 9 ++ .../strategies/executors/__init__.py | 13 ++ .../strategies/executors/base.py | 10 ++ pytorch_lightning/strategies/executors/ddp.py | 103 +++++++++++++++ .../strategies/executors/ddp_spawn.py | 122 ++++++++++++++++++ .../strategies/executors/single_process.py | 6 + .../strategies/executors/tpu_spawn.py | 38 ++++++ pytorch_lightning/strategies/parallel.py | 5 + pytorch_lightning/strategies/single_device.py | 5 + pytorch_lightning/strategies/strategy.py | 4 + pytorch_lightning/strategies/tpu_spawn.py | 6 + pytorch_lightning/trainer/trainer.py | 12 +- 13 files changed, 332 insertions(+), 102 deletions(-) create mode 100644 pytorch_lightning/strategies/executors/__init__.py create mode 100644 pytorch_lightning/strategies/executors/base.py create mode 100644 pytorch_lightning/strategies/executors/ddp.py create mode 100644 pytorch_lightning/strategies/executors/ddp_spawn.py create mode 100644 pytorch_lightning/strategies/executors/single_process.py create mode 100644 pytorch_lightning/strategies/executors/tpu_spawn.py diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 4aa67baaed422..d2f9387f11626 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -15,16 +15,11 @@ import os import shutil import signal -import subprocess -import sys import tempfile import time from pathlib import Path -from time import sleep from typing import Any, Dict, List, Optional, Union -import __main__ -import numpy as np import torch import torch.distributed from torch.nn import Module @@ -37,11 +32,12 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.executors.ddp import DDPSubprocessExecutor +from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( _FAIRSCALE_AVAILABLE, - _HYDRA_AVAILABLE, _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, @@ -57,15 +53,12 @@ sync_ddp_if_available, ) from pytorch_lightning.utilities.enums import _StrategyType -from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException +from pytorch_lightning.utilities.exceptions import DeadlockDetectedException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS -if _HYDRA_AVAILABLE: - from hydra.core.hydra_config import HydraConfig - from hydra.utils import get_original_cwd, to_absolute_path if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -103,7 +96,6 @@ def __init__( precision_plugin=precision_plugin, ) log.detail(f"{self.__class__.__name__}: initializing DDP plugin") - self.interactive_ddp_procs = [] self._num_nodes = 1 self.sync_batchnorm = False self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0 @@ -144,11 
+136,13 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def setup_environment(self) -> None: - # start the other scripts - if not self.cluster_environment.creates_processes_externally: - self._call_children_scripts() + def execute(self, trainer, function, *args, **kwargs): + executer = ( + SingleProcessExecutor if self.cluster_environment.creates_processes_externally else DDPSubprocessExecutor + )(self) + executer.execute(trainer, function, *args, **kwargs) + def setup_environment(self) -> None: self.setup_distributed() super().setup_environment() @@ -176,75 +170,6 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) - def _call_children_scripts(self): - # bookkeeping of spawned processes - self._check_can_spawn_children() - - # DDP Environment variables - os.environ["MASTER_ADDR"] = self.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - - # allow the user to pass the node rank - os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) - os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) - - # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` - # See https://docs.python.org/3/reference/import.html#main-spec - if __main__.__spec__ is None: # pragma: no-cover - # Script called as `python a/b/c.py` - # when user is using hydra find the absolute path - path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path - - # pull out the commands used to run the script and resolve the abs file path - command = sys.argv - try: - full_path = path_lib(command[0]) - except Exception: - full_path = os.path.abspath(command[0]) - - command[0] = full_path - # use the same python interpreter and actually running - command = [sys.executable] + command - else: # Script called as `python -m a.b.c` - command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] - - # the visible devices tell us how many GPUs we want to use. - # when the trainer script was called the device has already been scoped by the time - # code reaches this point. 
so, to call the scripts, we need to leave cuda visible devices alone - # but forward the GPUs selected via environment variables - if self.parallel_devices is None: - raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") - - os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" - - self.interactive_ddp_procs = [] - - for local_rank in range(1, self.num_processes): - env_copy = os.environ.copy() - env_copy["LOCAL_RANK"] = f"{local_rank}" - - # remove env var if global seed not set - if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: - del env_copy["PL_GLOBAL_SEED"] - - # start process - # if hydra is available and initialized, make sure to set the cwd correctly - cwd: Optional[str] = None - if _HYDRA_AVAILABLE: - if HydraConfig.initialized(): - cwd = get_original_cwd() - os_cwd = f'"{os.getcwd()}"' - command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] - proc = subprocess.Popen(command, env=env_copy, cwd=cwd) - self.interactive_ddp_procs.append(proc) - - # starting all processes at once can cause issues - # with dataloaders delay between 1-10 seconds - delay = np.random.uniform(1, 5, 1)[0] - sleep(delay) - - self._rank_0_has_called_call_children_scripts = True - def setup_distributed(self): log.detail(f"{self.__class__.__name__}: setting up distributed...") reset_seed() @@ -260,14 +185,6 @@ def setup_distributed(self): # where to store ip_table init_dist_connection(self.cluster_environment, self.torch_distributed_backend) - def _check_can_spawn_children(self): - if self.local_rank != 0: - raise RuntimeError( - "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." - " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," - " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." 
- ) - def set_world_ranks(self) -> None: if self.cluster_environment is None: return diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 097992dc1975e..e3ef0111a7d75 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -30,6 +30,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -44,6 +45,7 @@ sync_ddp_if_available, ) from pytorch_lightning.utilities.enums import _StrategyType +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT @@ -146,6 +148,10 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} + def execute(self, trainer, function, *args, **kwargs): + executor = DDPSpawnExecutor(self) + executor.execute(trainer, function, *args, **kwargs) + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: """Spawn processes that run the given function. @@ -159,6 +165,7 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union Return: The output of the function of process 0. 
""" + raise MisconfigurationException("this should not run") os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() @@ -219,6 +226,7 @@ def determine_ddp_device_ids(self): return [self.root_device.index] def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + raise MisconfigurationException("this should not run") rank_zero_debug("Finalizing the DDP spawn environment.") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -245,6 +253,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: + raise MisconfigurationException("this should not run") # transfer back the best path to the trainer if trainer.checkpoint_callback: trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path diff --git a/pytorch_lightning/strategies/executors/__init__.py b/pytorch_lightning/strategies/executors/__init__.py new file mode 100644 index 0000000000000..87964cb46996a --- /dev/null +++ b/pytorch_lightning/strategies/executors/__init__.py @@ -0,0 +1,13 @@ +from pytorch_lightning.strategies.executors.base import Executor +from pytorch_lightning.strategies.executors.ddp import DDPSubprocessExecutor +from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor +from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor +from pytorch_lightning.strategies.executors.tpu_spawn import TPUSpawnExecutor + +__all__ = [ + "DDPSpawnExecutor", + "DDPSubprocessExecutor", + "Executor", + "SingleProcessExecutor", + "TPUSpawnExecutor", +] diff --git a/pytorch_lightning/strategies/executors/base.py b/pytorch_lightning/strategies/executors/base.py new file mode 100644 index 0000000000000..a90ac648b4702 --- /dev/null +++ b/pytorch_lightning/strategies/executors/base.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + + +class Executor(ABC): + def __init__(self, strategy): + self.strategy = strategy + + @abstractmethod + def execute(self, trainer, fn, *args, **kwargs) -> bool: + """Executes the proceses.""" diff --git a/pytorch_lightning/strategies/executors/ddp.py b/pytorch_lightning/strategies/executors/ddp.py new file mode 100644 index 0000000000000..98350b61c4331 --- /dev/null +++ b/pytorch_lightning/strategies/executors/ddp.py @@ -0,0 +1,103 @@ +import os +import subprocess +import sys +from time import sleep +from typing import Optional + +import __main__ +import numpy as np + +from pytorch_lightning.strategies.executors.base import Executor +from pytorch_lightning.utilities import _HYDRA_AVAILABLE + +if _HYDRA_AVAILABLE: + from hydra.core.hydra_config import HydraConfig + from hydra.utils import get_original_cwd, to_absolute_path + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class DDPSubprocessExecutor(Executor): + def __init__(self, strategy): + super().__init__(self, strategy) + self.interactive_ddp_procs = [] + + def execute(self, trainer, function, *args, **kwargs): + self._call_children_scripts() + + def _call_children_scripts(self): + # bookkeeping of spawned processes + self._check_can_spawn_children() + + # DDP Environment variables + os.environ["MASTER_ADDR"] = 
self.strategy.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) + + # allow the user to pass the node rank + os.environ["NODE_RANK"] = str(self.strategy.cluster_environment.node_rank()) + os.environ["LOCAL_RANK"] = str(self.startegy.cluster_environment.local_rank()) + + # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` + # See https://docs.python.org/3/reference/import.html#main-spec + if __main__.__spec__ is None: # pragma: no-cover + # Script called as `python a/b/c.py` + # when user is using hydra find the absolute path + path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path + + # pull out the commands used to run the script and resolve the abs file path + command = sys.argv + try: + full_path = path_lib(command[0]) + except Exception: + full_path = os.path.abspath(command[0]) + + command[0] = full_path + # use the same python interpreter and actually running + command = [sys.executable] + command + else: # Script called as `python -m a.b.c` + command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] + + # the visible devices tell us how many GPUs we want to use. + # when the trainer script was called the device has already been scoped by the time + # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone + # but forward the GPUs selected via environment variables + if self.strategy.parallel_devices is None: + raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") + + os.environ["WORLD_SIZE"] = f"{self.strategy.num_processes * self.strategy.num_nodes}" + + self.interactive_ddp_procs = [] + + for local_rank in range(1, self.strategy.num_processes): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + + # remove env var if global seed not set + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + del env_copy["PL_GLOBAL_SEED"] + + # start process + # if hydra is available and initialized, make sure to set the cwd correctly + cwd: Optional[str] = None + if _HYDRA_AVAILABLE: + if HydraConfig.initialized(): + cwd = get_original_cwd() + os_cwd = f'"{os.getcwd()}"' + command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] + proc = subprocess.Popen(command, env=env_copy, cwd=cwd) + self.interactive_ddp_procs.append(proc) + + # starting all processes at once can cause issues + # with dataloaders delay between 1-10 seconds + delay = np.random.uniform(1, 5, 1)[0] + sleep(delay) + + self.strategy._rank_0_has_called_call_children_scripts = True + + def _check_can_spawn_children(self): + if self.strategy.local_rank != 0: + raise RuntimeError( + "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." + " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," + " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." 
+ ) diff --git a/pytorch_lightning/strategies/executors/ddp_spawn.py b/pytorch_lightning/strategies/executors/ddp_spawn.py new file mode 100644 index 0000000000000..f9086410c6419 --- /dev/null +++ b/pytorch_lightning/strategies/executors/ddp_spawn.py @@ -0,0 +1,122 @@ +import os +from collections import UserList +from typing import Any, NamedTuple, Optional + +import torch +import torch.multiprocessing as mp + +import pytorch_lightning as pl +from pytorch_lightning.strategies.executors.base import Executor +from pytorch_lightning.trainer.states import TrainerFn, TrainerState +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device +from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import _PATH + + +class DDPSpawnExecutor(Executor): + def __init__(self, strategy): + super().__init__(strategy) + + def execute(self, trainer, function, *args, **kwargs): + os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) + context = mp.get_context("spawn") + return_queue = context.SimpleQueue() + mp.spawn( + self._wrapped_function, + args=(trainer, function, args, kwargs, return_queue), + nprocs=self.strategy.num_processes, + ) + spawn_output = return_queue.get() + self._recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results + + def _wrapped_function(self, process_idx, trainer, function, args, kwargs, return_queue): + self.strategy._worker_setup(process_idx) + results = function(*args, **kwargs) + results = self._collect_rank_zero_results(trainer, results) + if self.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: + # transfer back the best path to the trainer + if trainer.checkpoint_callback: + trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + + # TODO: pass also best score + # load last weights + if spawn_output.weights_path is not None: + ckpt = self.checkpoint_io.load_checkpoint( + spawn_output.weights_path, map_location=(lambda storage, loc: storage) + ) + self.lightning_module.load_state_dict(ckpt) + self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) + + trainer.state = spawn_output.trainer_state + + # get the `callback_metrics` and set it to the trainer + if is_overridden("get_from_queue", self.lightning_module): + # only in case the user does not override it. 
+ # TODO: Remove the if in v1.7 + self.lightning_module.get_from_queue(spawn_output.extra) + self.get_from_queue(trainer, spawn_output.extra) + + def _collect_ranku_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + rank_zero_debug("Finalizing the DDP spawn environment.") + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.strategy.lightning_module.state_dict() + + if self.strategy.global_rank != 0: + return + + # save the last weights + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + + # adds the `callback_metrics` to the queue + extra = _FakeQueue() + if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 + self.strategy.lightning_module.add_to_queue(extra) + self.add_to_queue(trainer, extra) + + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) + + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: + """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory + sharing, we cast the data to numpy. + + Args: + trainer: reference to the Trainer. + queue: the instance of the queue to append the data. + """ + callback_metrics: dict = apply_to_collection( + trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() + ) # send as numpy to avoid issues with memory sharing + queue.put(callback_metrics) + + +class _FakeQueue(UserList): + """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" + + def get(self) -> Any: + return self.pop(0) + + def put(self, item: Any) -> None: + self.append(item) + + def empty(self) -> bool: + return len(self) == 0 + + +class _SpawnOutput(NamedTuple): + best_model_path: Optional[_PATH] + weights_path: Optional[_PATH] + trainer_state: TrainerState + trainer_results: Any + extra: _FakeQueue diff --git a/pytorch_lightning/strategies/executors/single_process.py b/pytorch_lightning/strategies/executors/single_process.py new file mode 100644 index 0000000000000..3158187c090dd --- /dev/null +++ b/pytorch_lightning/strategies/executors/single_process.py @@ -0,0 +1,6 @@ +from pytorch_lightning.strategies.executors.base import Executor + + +class SingleProcessExecutor(Executor): + def execute(self, function, *args, **kwargs): + return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/executors/tpu_spawn.py b/pytorch_lightning/strategies/executors/tpu_spawn.py new file mode 100644 index 0000000000000..f53a793e0ab93 --- /dev/null +++ b/pytorch_lightning/strategies/executors/tpu_spawn.py @@ -0,0 +1,38 @@ +import time +from multiprocessing.queues import SimpleQueue +from typing import Any, Callable + +import torch.multiprocessing as mp + +from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.apply_func import move_data_to_device + +if _TPU_AVAILABLE: + import torch_xla.distributed.xla_multiprocessing as xmp +else: + xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 + + +class TPUSpawnExecutor(DDPSpawnExecutor): + def execute(self, trainer, function, *args, **kwargs): + context = 
mp.get_context(self.start_method or "fork") + return_queue = context.SimpleQueue() + xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) + return return_queue.get() + + def _wrapped_function( + self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue + ) -> None: + self._worker_setup(process_idx) + result = function(*args, **kwargs) + if self.local_rank == 0: + return_queue.put(move_data_to_device(result, "cpu")) + + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 + self.barrier("end-process") + + # Ensure that the rank 0 process is the one exiting last + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + if self.local_rank == 0: + time.sleep(2) diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 5d7d487a214e3..9fcf823986ef0 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -24,6 +24,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -125,6 +126,10 @@ def block_backward_sync(self): else: yield None + def execute(self, trainer, function, *args, **kwargs): + executer = SingleProcessExecutor(self) + executer.execute(trainer, function, *args, **kwargs) + def teardown(self) -> None: self.cluster_environment.teardown() super().teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index 440c73afce8fc..16c2901f27379 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.types import _DEVICE @@ -40,6 +41,10 @@ def __init__( self.local_rank = 0 self.world_size = 1 + def execute(self, trainer, function, *args, **kwargs): + executer = SingleProcessExecutor(self) + executer.execute(trainer, function, *args, **kwargs) + def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates with a single device, the reduction is simply the identity. 
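
The hunks above and the strategy.py hunk that follows introduce the delegation pattern this series is built around: every Strategy grows an execute(trainer, function, *args, **kwargs) method that hands process creation to an Executor, while the strategy itself keeps owning devices, collectives and checkpoints. Below is a minimal, self-contained sketch of that pattern — the executor classes mirror the ones added in this patch, DemoStrategy is an invented stand-in rather than the real Strategy API, and the `return` anticipates the "return results" commit later in the series.

from abc import ABC, abstractmethod


class Executor(ABC):
    """Decides how the trainer function is launched for a given strategy."""

    def __init__(self, strategy):
        self.strategy = strategy

    @abstractmethod
    def execute(self, trainer, function, *args, **kwargs):
        """Run ``function`` and hand its result back to the main process."""


class SingleProcessExecutor(Executor):
    # Nothing to launch: run the function in the current process.
    def execute(self, trainer, function, *args, **kwargs):
        return function(*args, **kwargs)


class DemoStrategy:
    # Invented stand-in for a Lightning strategy; the real classes pick the
    # executor based on e.g. cluster_environment.creates_processes_externally.
    def execute(self, trainer, function, *args, **kwargs):
        executor = SingleProcessExecutor(self)
        return executor.execute(trainer, function, *args, **kwargs)

With this in place the trainer can route trainer_fn through strategy.execute(...) (done in the "fix issues" commit further down) and no longer needs to special-case spawn-based strategies — the trainer.py hunk below already removes that branch.
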
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index be77932fc6250..0ca635df34924 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -139,6 +139,10 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_schedulers = schedulers + @abstractmethod + def execute(self, trainer, function, *args, **kwargs): + """Executes the proceses using an Executor.""" + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the appropriate device if needed.""" for opt in self.optimizers: diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 3592352f8f2a1..17b0d2bf4afa6 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnStrategy +from pytorch_lightning.strategies.executors.tpu_spawn import TPUSpawnExecutor from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters @@ -243,7 +244,12 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } + def execute(self, trainer, fn, *args, **kwargs): + executor = TPUSpawnExecutor(self) + executor.execute(trainer, fn, *args, **kwargs) + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: + raise MisconfigurationException("this should not run") context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac01227fd00ac..1f0fa14dad0e8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -58,7 +58,7 @@ XLAProfiler, ) from pytorch_lightning.strategies import ParallelStrategy, Strategy -from pytorch_lightning.strategies.ddp_spawn import _SpawnOutput, DDPSpawnStrategy +from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector @@ -668,12 +668,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - if isinstance(self.strategy, DDPSpawnStrategy): - spawn_output: _SpawnOutput = self.strategy.spawn(trainer_fn, *args, **kwargs) - self.strategy._recover_results_in_main_process(spawn_output, self) - return spawn_output.trainer_results - else: - return trainer_fn(*args, **kwargs) + return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") @@ -1192,9 +1187,6 @@ def _run( 
self.state.status = TrainerStatus.FINISHED self.state.stage = None - if isinstance(self.strategy, DDPSpawnStrategy): - results = self.strategy._collect_rank_zero_results(self, results) - return results def _log_hyperparams(self) -> None: From 7f325f50cf5c922e9df0b77e4ac9c960a5bc0da2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 17:01:27 +0530 Subject: [PATCH 02/36] add license --- pytorch_lightning/strategies/__init__.py | 13 +++++++++++++ pytorch_lightning/strategies/executors/__init__.py | 13 +++++++++++++ pytorch_lightning/strategies/executors/base.py | 13 +++++++++++++ pytorch_lightning/strategies/executors/ddp.py | 13 +++++++++++++ pytorch_lightning/strategies/executors/ddp_spawn.py | 13 +++++++++++++ .../strategies/executors/single_process.py | 13 +++++++++++++ pytorch_lightning/strategies/executors/tpu_spawn.py | 13 +++++++++++++ 7 files changed, 91 insertions(+) diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py index 9b06e2fe587ee..22d0d1d5db902 100644 --- a/pytorch_lightning/strategies/__init__.py +++ b/pytorch_lightning/strategies/__init__.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pathlib import Path from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401 diff --git a/pytorch_lightning/strategies/executors/__init__.py b/pytorch_lightning/strategies/executors/__init__.py index 87964cb46996a..127bd5258ef70 100644 --- a/pytorch_lightning/strategies/executors/__init__.py +++ b/pytorch_lightning/strategies/executors/__init__.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pytorch_lightning.strategies.executors.base import Executor from pytorch_lightning.strategies.executors.ddp import DDPSubprocessExecutor from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor diff --git a/pytorch_lightning/strategies/executors/base.py b/pytorch_lightning/strategies/executors/base.py index a90ac648b4702..76fd10157c4a2 100644 --- a/pytorch_lightning/strategies/executors/base.py +++ b/pytorch_lightning/strategies/executors/base.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod diff --git a/pytorch_lightning/strategies/executors/ddp.py b/pytorch_lightning/strategies/executors/ddp.py index 98350b61c4331..8cf18e47d6e7a 100644 --- a/pytorch_lightning/strategies/executors/ddp.py +++ b/pytorch_lightning/strategies/executors/ddp.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import subprocess import sys diff --git a/pytorch_lightning/strategies/executors/ddp_spawn.py b/pytorch_lightning/strategies/executors/ddp_spawn.py index f9086410c6419..200ec4caedeb8 100644 --- a/pytorch_lightning/strategies/executors/ddp_spawn.py +++ b/pytorch_lightning/strategies/executors/ddp_spawn.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from collections import UserList from typing import Any, NamedTuple, Optional diff --git a/pytorch_lightning/strategies/executors/single_process.py b/pytorch_lightning/strategies/executors/single_process.py index 3158187c090dd..0eb558ab9ac34 100644 --- a/pytorch_lightning/strategies/executors/single_process.py +++ b/pytorch_lightning/strategies/executors/single_process.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from pytorch_lightning.strategies.executors.base import Executor diff --git a/pytorch_lightning/strategies/executors/tpu_spawn.py b/pytorch_lightning/strategies/executors/tpu_spawn.py index f53a793e0ab93..1db5ec478359c 100644 --- a/pytorch_lightning/strategies/executors/tpu_spawn.py +++ b/pytorch_lightning/strategies/executors/tpu_spawn.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable From 9114b09d16bc2e54129327a0f77aec59c9ebb0e4 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 17:55:38 +0530 Subject: [PATCH 03/36] fix issues --- pytorch_lightning/strategies/ddp.py | 4 +-- pytorch_lightning/strategies/executors/ddp.py | 4 +-- .../strategies/executors/ddp_spawn.py | 32 +++++++++++------ .../strategies/executors/single_process.py | 2 +- .../strategies/executors/tpu_spawn.py | 36 +++++++++++++++++-- pytorch_lightning/strategies/parallel.py | 4 +-- pytorch_lightning/strategies/single_device.py | 4 +-- pytorch_lightning/strategies/tpu_spawn.py | 1 + pytorch_lightning/trainer/trainer.py | 3 +- 9 files changed, 67 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index d2f9387f11626..5a925442984b4 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -137,10 +137,10 @@ def _is_single_process_single_device(self) -> bool: return True def execute(self, trainer, function, *args, **kwargs): - executer = ( + executor = ( SingleProcessExecutor if self.cluster_environment.creates_processes_externally else DDPSubprocessExecutor )(self) - executer.execute(trainer, function, *args, **kwargs) + executor.execute(trainer, function, *args, **kwargs) def setup_environment(self) -> None: self.setup_distributed() diff --git a/pytorch_lightning/strategies/executors/ddp.py b/pytorch_lightning/strategies/executors/ddp.py index 8cf18e47d6e7a..0e1f5dac2a260 100644 --- a/pytorch_lightning/strategies/executors/ddp.py +++ b/pytorch_lightning/strategies/executors/ddp.py @@ -32,7 +32,7 @@ class DDPSubprocessExecutor(Executor): def __init__(self, strategy): - super().__init__(self, strategy) + super().__init__(strategy=strategy) self.interactive_ddp_procs = [] def execute(self, trainer, function, *args, **kwargs): @@ -48,7 +48,7 @@ def _call_children_scripts(self): # allow the user to pass the node rank os.environ["NODE_RANK"] = str(self.strategy.cluster_environment.node_rank()) - os.environ["LOCAL_RANK"] = str(self.startegy.cluster_environment.local_rank()) + os.environ["LOCAL_RANK"] = str(self.strategy.cluster_environment.local_rank()) # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` # See https://docs.python.org/3/reference/import.html#main-spec diff --git a/pytorch_lightning/strategies/executors/ddp_spawn.py b/pytorch_lightning/strategies/executors/ddp_spawn.py index 200ec4caedeb8..f93f2924a3d80 100644 --- 
a/pytorch_lightning/strategies/executors/ddp_spawn.py +++ b/pytorch_lightning/strategies/executors/ddp_spawn.py @@ -15,6 +15,7 @@ from collections import UserList from typing import Any, NamedTuple, Optional +import numpy as np import torch import torch.multiprocessing as mp @@ -28,9 +29,6 @@ class DDPSpawnExecutor(Executor): - def __init__(self, strategy): - super().__init__(strategy) - def execute(self, trainer, function, *args, **kwargs): os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) context = mp.get_context("spawn") @@ -48,7 +46,7 @@ def _wrapped_function(self, process_idx, trainer, function, args, kwargs, return self.strategy._worker_setup(process_idx) results = function(*args, **kwargs) results = self._collect_rank_zero_results(trainer, results) - if self.local_rank == 0: + if self.strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: @@ -59,22 +57,22 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer # TODO: pass also best score # load last weights if spawn_output.weights_path is not None: - ckpt = self.checkpoint_io.load_checkpoint( + ckpt = self.strategy.checkpoint_io.load_checkpoint( spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) - self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) + self.strategy.lightning_module.load_state_dict(ckpt) + self.strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) trainer.state = spawn_output.trainer_state # get the `callback_metrics` and set it to the trainer - if is_overridden("get_from_queue", self.lightning_module): + if is_overridden("get_from_queue", self.strategy.lightning_module): # only in case the user does not override it. # TODO: Remove the if in v1.7 - self.lightning_module.get_from_queue(spawn_output.extra) + self.strategy.lightning_module.get_from_queue(spawn_output.extra) self.get_from_queue(trainer, spawn_output.extra) - def _collect_ranku_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the DDP spawn environment.") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -93,7 +91,7 @@ def _collect_ranku_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op # adds the `callback_metrics` to the queue extra = _FakeQueue() - if is_overridden("add_to_queue", self.lightning_module): + if is_overridden("add_to_queue", self.strategy.lightning_module): # TODO: Remove the if in v1.7 self.strategy.lightning_module.add_to_queue(extra) self.add_to_queue(trainer, extra) @@ -113,6 +111,18 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: ) # send as numpy to avoid issues with memory sharing queue.put(callback_metrics) + def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: + """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, + we cast back the data to ``torch.Tensor``. + + Args: + trainer: reference to the Trainer. + queue: the instance of the queue from where to get the data. 
+ """ + # NOTE: `add_to_queue` needs to be called before + callback_metrics: dict = queue.get() + trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) + class _FakeQueue(UserList): """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" diff --git a/pytorch_lightning/strategies/executors/single_process.py b/pytorch_lightning/strategies/executors/single_process.py index 0eb558ab9ac34..4f391b3a71f11 100644 --- a/pytorch_lightning/strategies/executors/single_process.py +++ b/pytorch_lightning/strategies/executors/single_process.py @@ -15,5 +15,5 @@ class SingleProcessExecutor(Executor): - def execute(self, function, *args, **kwargs): + def execute(self, trainer, function, *args, **kwargs): return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/executors/tpu_spawn.py b/pytorch_lightning/strategies/executors/tpu_spawn.py index 1db5ec478359c..ffbd43eb21626 100644 --- a/pytorch_lightning/strategies/executors/tpu_spawn.py +++ b/pytorch_lightning/strategies/executors/tpu_spawn.py @@ -13,13 +13,18 @@ # limitations under the License. import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable +from typing import Any, Callable, Optional import torch.multiprocessing as mp +from build import os -from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor +import pytorch_lightning as pl +from pytorch_lightning.strategies.executors.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnExecutor +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities.model_helpers import is_overridden if _TPU_AVAILABLE: import torch_xla.distributed.xla_multiprocessing as xmp @@ -49,3 +54,30 @@ def _wrapped_function( # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 if self.local_rank == 0: time.sleep(2) + + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + rank_zero_debug("Finalizing the TPU spawn environment.") + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.strategy.lightning_module.state_dict() + + # save the last weights + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + + # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training + if self.strategy.local_rank != 0: + return + + # adds the `callback_metrics` to the queue + extra = _FakeQueue() + if is_overridden("add_to_queue", self.strategy.lightning_module): + # TODO: Remove the if in v1.7 + self.strategy.lightning_module.add_to_queue(extra) + self.strategy.add_to_queue(trainer, extra) + + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 9fcf823986ef0..cbe396b9c879f 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -127,8 +127,8 @@ def 
block_backward_sync(self): yield None def execute(self, trainer, function, *args, **kwargs): - executer = SingleProcessExecutor(self) - executer.execute(trainer, function, *args, **kwargs) + executor = SingleProcessExecutor(self) + executor.execute(trainer, function, *args, **kwargs) def teardown(self) -> None: self.cluster_environment.teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index 16c2901f27379..10f2a8f70329e 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -42,8 +42,8 @@ def __init__( self.world_size = 1 def execute(self, trainer, function, *args, **kwargs): - executer = SingleProcessExecutor(self) - executer.execute(trainer, function, *args, **kwargs) + executor = SingleProcessExecutor(self) + executor.execute(trainer, function, *args, **kwargs) def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 17b0d2bf4afa6..aa44f0943407d 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -176,6 +176,7 @@ def barrier(self, name: Optional[str] = None) -> None: rendezvous(name) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + raise MisconfigurationException("this should not run") rank_zero_debug("Finalizing the TPU spawn environment.") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1f0fa14dad0e8..854b8a792eedf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -668,7 +668,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - return trainer_fn(*args, **kwargs) + return self.strategy.execute(self, trainer_fn, *args, **kwargs) + # return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") From e054d0e3b539dd08b4f303a1b900d0de80c51074 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 18:28:14 +0530 Subject: [PATCH 04/36] return results --- pytorch_lightning/strategies/ddp.py | 2 +- pytorch_lightning/strategies/ddp_spawn.py | 2 +- pytorch_lightning/strategies/parallel.py | 2 +- pytorch_lightning/strategies/single_device.py | 2 +- pytorch_lightning/strategies/tpu_spawn.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 5a925442984b4..c7ed1c1607f82 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -140,7 +140,7 @@ def execute(self, trainer, function, *args, **kwargs): executor = ( SingleProcessExecutor if self.cluster_environment.creates_processes_externally else DDPSubprocessExecutor )(self) - executor.execute(trainer, function, *args, **kwargs) + return executor.execute(trainer, function, *args, **kwargs) def setup_environment(self) -> None: self.setup_distributed() 
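
At this point in the series ("return results") the executors propagate the return value of the executed function back through Strategy.execute, so the Trainer receives its fit/validate/test results again. For the spawn-based executors that value has to cross a process boundary: rank 0 packs it into a _SpawnOutput named tuple, puts it on a multiprocessing SimpleQueue, and the parent process unpacks it once mp.spawn has joined. The following is a trimmed-down sketch of that round-trip; the real DDPSpawnExecutor additionally saves rank-0 weights to a temporary checkpoint, ships callback metrics through _FakeQueue, and restores the trainer state in _recover_results_in_main_process. _demo_work and the two-field output tuple are made up for the example.

import torch.multiprocessing as mp
from typing import Any, NamedTuple


class _DemoSpawnOutput(NamedTuple):
    # Reduced version of the _SpawnOutput added in executors/ddp_spawn.py.
    rank: int
    result: Any


def _wrapped(process_idx, function, args, return_queue):
    # The real executor calls strategy._worker_setup(process_idx) here to set
    # world ranks and initialize the process group before running the function.
    result = function(*args)
    if process_idx == 0:
        # Only rank 0 reports back, mirroring _collect_rank_zero_results.
        return_queue.put(_DemoSpawnOutput(process_idx, result))


def execute(function, *args, nprocs=2):
    # Mirrors DDPSpawnExecutor.execute: spawn, join, then read rank 0's output.
    context = mp.get_context("spawn")
    return_queue = context.SimpleQueue()
    mp.spawn(_wrapped, args=(function, args, return_queue), nprocs=nprocs)
    return return_queue.get().result


def _demo_work(x):
    return x * 2


if __name__ == "__main__":
    print(execute(_demo_work, 21))  # prints 42 after all processes have joined
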
diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index e3ef0111a7d75..03f8b7b015602 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -150,7 +150,7 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def execute(self, trainer, function, *args, **kwargs): executor = DDPSpawnExecutor(self) - executor.execute(trainer, function, *args, **kwargs) + return executor.execute(trainer, function, *args, **kwargs) def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: """Spawn processes that run the given function. diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index cbe396b9c879f..3fd2e9d2533a9 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -128,7 +128,7 @@ def block_backward_sync(self): def execute(self, trainer, function, *args, **kwargs): executor = SingleProcessExecutor(self) - executor.execute(trainer, function, *args, **kwargs) + return executor.execute(trainer, function, *args, **kwargs) def teardown(self) -> None: self.cluster_environment.teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index 10f2a8f70329e..5b0363e7b40f3 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -43,7 +43,7 @@ def __init__( def execute(self, trainer, function, *args, **kwargs): executor = SingleProcessExecutor(self) - executor.execute(trainer, function, *args, **kwargs) + return executor.execute(trainer, function, *args, **kwargs) def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index aa44f0943407d..9bd0f2df78fc7 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -247,7 +247,7 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def execute(self, trainer, fn, *args, **kwargs): executor = TPUSpawnExecutor(self) - executor.execute(trainer, fn, *args, **kwargs) + return executor.execute(trainer, fn, *args, **kwargs) def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: raise MisconfigurationException("this should not run") From a2248b40d3822b3cb6bf74d563c6bffb6fae9138 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 18:36:17 +0530 Subject: [PATCH 05/36] fix os --- pytorch_lightning/strategies/executors/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/executors/tpu_spawn.py b/pytorch_lightning/strategies/executors/tpu_spawn.py index ffbd43eb21626..ea3588b2d84de 100644 --- a/pytorch_lightning/strategies/executors/tpu_spawn.py +++ b/pytorch_lightning/strategies/executors/tpu_spawn.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Optional import torch.multiprocessing as mp -from build import os import pytorch_lightning as pl from pytorch_lightning.strategies.executors.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnExecutor From 8751c3672ade7f2e3cde33034eec5ecace95eff5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 14:07:09 -0500 Subject: [PATCH 06/36] fix DDP --- pytorch_lightning/strategies/executors/ddp.py | 1 + pytorch_lightning/trainer/trainer.py | 1 - tests/accelerators/test_common.py | 7 ++++++- tests/strategies/test_ddp_spawn_strategy.py | 16 ++++++++++++---- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/strategies/executors/ddp.py b/pytorch_lightning/strategies/executors/ddp.py index 0e1f5dac2a260..55ea4058f8798 100644 --- a/pytorch_lightning/strategies/executors/ddp.py +++ b/pytorch_lightning/strategies/executors/ddp.py @@ -37,6 +37,7 @@ def __init__(self, strategy): def execute(self, trainer, function, *args, **kwargs): self._call_children_scripts() + return function(*args, **kwargs) def _call_children_scripts(self): # bookkeeping of spawned processes diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 854b8a792eedf..19139dccf5ec5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -669,7 +669,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return self.strategy.execute(self, trainer_fn, *args, **kwargs) - # return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 18bb04bd0ae17..ecdcc743ea822 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -40,7 +40,12 @@ def test_evaluate(tmpdir, trainer_kwargs): dm = ClassifDataModule() model = CustomClassificationModelDP() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=2, limit_train_batches=10, limit_val_batches=10, **trainer_kwargs + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + **trainer_kwargs ) trainer.fit(model, datamodule=dm) diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 2c0dff23aafd3..abf652362db90 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -20,6 +20,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.strategies import DDPSpawnStrategy +from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -82,16 +83,22 @@ def test_ddp_spawn_extra_parameters(tmpdir): assert model.test_val == "test_val" -class TestDDPSpawnStrategy(DDPSpawnStrategy): +class CustomDDPSpawnExecutor(DDPSpawnExecutor): def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) def get_from_queue(self, trainer: Trainer, queue) -> None: - self.new_test_val = queue.get() + self.strategy.new_test_val = queue.get() return 
super().get_from_queue(trainer, queue) +class TestDDPSpawnStrategy(DDPSpawnStrategy): + def execute(self, trainer, function, *args, **kwargs): + executor = CustomDDPSpawnExecutor(self) + return executor.execute(trainer, function, *args, **kwargs) + + @RunIf(skip_windows=True, skip_49370=True) def test_ddp_spawn_add_get_queue(tmpdir): """Tests add_to_queue/get_from_queue with DDPSpawnStrategy.""" @@ -148,13 +155,14 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPSpawnStrategy() + executor = DDPSpawnExecutor(strategy) strategy.model = model trainer = Trainer(default_root_dir=tmpdir) trainer.state.fn = trainer_fn # pretend we are in a particular trainer state temp_file = Path(tmpdir, ".temp.ckpt") assert not temp_file.exists() - spawn_output = strategy._collect_rank_zero_results(trainer, {}) + spawn_output = executor._collect_rank_zero_results(trainer, {}) model.state_dict.assert_called_once() if trainer_fn == TrainerFn.FITTING: @@ -165,6 +173,6 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): assert not temp_file.exists() # <-- here would normally be the multiprocessing boundary - strategy._recover_results_in_main_process(spawn_output, trainer) + executor._recover_results_in_main_process(spawn_output, trainer) assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) assert not temp_file.exists() From eea1fa7bef068248a95f7d0dd0f17f1343b05123 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 27 Jan 2022 14:12:56 -0500 Subject: [PATCH 07/36] rm redundant code --- pytorch_lightning/strategies/ddp_spawn.py | 139 +--------------------- pytorch_lightning/strategies/tpu_spawn.py | 62 +--------- 2 files changed, 6 insertions(+), 195 deletions(-) diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 03f8b7b015602..17d095edd96d5 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -13,14 +13,10 @@ # limitations under the License. 
import logging import os -from collections import UserList -from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, Optional, Union -import numpy as np import torch import torch.distributed -import torch.multiprocessing as mp from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel @@ -32,23 +28,19 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor from pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.trainer.states import TrainerFn, TrainerState +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import ( init_dist_connection, - rank_zero_debug, rank_zero_only, ReduceOp, sync_ddp_if_available, ) from pytorch_lightning.utilities.enums import _StrategyType -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -152,34 +144,6 @@ def execute(self, trainer, function, *args, **kwargs): executor = DDPSpawnExecutor(self) return executor.execute(trainer, function, *args, **kwargs) - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - """Spawn processes that run the given function. - - Args: - function: The function to spawn processes from. - *args: Optional positional arguments that will be passed to the function in addition to the process index. - These arguments must be pickleable. - **kwargs: Optional named arguments that will be passed to the function in addition to the process index. - These arguments must be pickleable. - - Return: - The output of the function of process 0. 
- """ - raise MisconfigurationException("this should not run") - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - context = mp.get_context("spawn") - return_queue = context.SimpleQueue() - mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes) - return return_queue.get() - - def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue - ) -> None: - self._worker_setup(process_idx) - result = function(*args, **kwargs) - if self.local_rank == 0: - return_queue.put(move_data_to_device(result, "cpu")) - def _worker_setup(self, process_idx: int): reset_seed() self.set_world_ranks(process_idx) @@ -225,57 +189,6 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: - raise MisconfigurationException("this should not run") - rank_zero_debug("Finalizing the DDP spawn environment.") - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - if self.global_rank != 0: - return - - # save the last weights - weights_path = None - if trainer.state.fn == TrainerFn.FITTING: - weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self.checkpoint_io.save_checkpoint(state_dict, weights_path) - - # adds the `callback_metrics` to the queue - extra = _FakeQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - self.add_to_queue(trainer, extra) - - return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) - - def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: - raise MisconfigurationException("this should not run") - # transfer back the best path to the trainer - if trainer.checkpoint_callback: - trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path - - # TODO: pass also best score - # load last weights - if spawn_output.weights_path is not None: - ckpt = self.checkpoint_io.load_checkpoint( - spawn_output.weights_path, map_location=(lambda storage, loc: storage) - ) - self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) - - trainer.state = spawn_output.trainer_state - - # get the `callback_metrics` and set it to the trainer - if is_overridden("get_from_queue", self.lightning_module): - # only in case the user does not override it. - # TODO: Remove the if in v1.7 - self.lightning_module.get_from_queue(spawn_output.extra) - self.get_from_queue(trainer, spawn_output.extra) - def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return @@ -345,31 +258,6 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: - """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory - sharing, we cast the data to numpy. - - Args: - trainer: reference to the Trainer. - queue: the instance of the queue to append the data. 
- """ - callback_metrics: dict = apply_to_collection( - trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() - ) # send as numpy to avoid issues with memory sharing - queue.put(callback_metrics) - - def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: - """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, - we cast back the data to ``torch.Tensor``. - - Args: - trainer: reference to the Trainer. - queue: the instance of the queue from where to get the data. - """ - # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.get() - trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) - @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register( @@ -392,24 +280,3 @@ def teardown(self) -> None: self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() - - -class _FakeQueue(UserList): - """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" - - def get(self) -> Any: - return self.pop(0) - - def put(self, item: Any) -> None: - self.append(item) - - def empty(self) -> bool: - return len(self) == 0 - - -class _SpawnOutput(NamedTuple): - best_model_path: Optional[_PATH] - weights_path: Optional[_PATH] - trainer_state: TrainerState - trainer_results: Any - extra: _FakeQueue diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 9bd0f2df78fc7..506605ee440a5 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -13,12 +13,9 @@ # limitations under the License. import io import os -import time -from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch -import torch.multiprocessing as mp from torch.nn import Module from torch.utils.data import DataLoader @@ -26,14 +23,12 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnStrategy +from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.strategies.executors.tpu_spawn import TPUSpawnExecutor from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters -from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.data import has_len -from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed @@ -175,34 +170,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: - raise MisconfigurationException("this should not run") - 
rank_zero_debug("Finalizing the TPU spawn environment.") - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - # save the last weights - weights_path = None - if trainer.state.fn == TrainerFn.FITTING: - weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self.checkpoint_io.save_checkpoint(state_dict, weights_path) - - # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if self.local_rank != 0: - return - - # adds the `callback_metrics` to the queue - extra = _FakeQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - self.add_to_queue(trainer, extra) - - return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) - def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj @@ -249,29 +216,6 @@ def execute(self, trainer, fn, *args, **kwargs): executor = TPUSpawnExecutor(self) return executor.execute(trainer, fn, *args, **kwargs) - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - raise MisconfigurationException("this should not run") - context = mp.get_context(self.start_method or "fork") - return_queue = context.SimpleQueue() - xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) - return return_queue.get() - - def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue - ) -> None: - self._worker_setup(process_idx) - result = function(*args, **kwargs) - if self.local_rank == 0: - return_queue.put(move_data_to_device(result, "cpu")) - - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.barrier("end-process") - - # Ensure that the rank 0 process is the one exiting last - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if self.local_rank == 0: - time.sleep(2) - def _worker_setup(self, process_idx: int): reset_seed() self.tpu_local_core_rank = xm.get_local_ordinal() From e46b8244f31774aef941c20b45ae694df75b8dd0 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 28 Jan 2022 00:51:56 +0530 Subject: [PATCH 08/36] fix import --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 05cc8d87eaac6..fb6c7c2c3fa1b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1954,7 +1954,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: + def add_to_queue(self, queue: pl.strategies.executors.ddp_spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1966,7 +1966,7 @@ def add_to_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: and will be removed in v1.7. 
""" - def get_from_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None: + def get_from_queue(self, queue: pl.strategies.executors.ddp_spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. From 4457eff11a3ed4df4f64724dc3548dd3a3d7c527 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 17:48:28 +0530 Subject: [PATCH 09/36] executor -> launcher --- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/strategies/ddp.py | 12 +++++------ pytorch_lightning/strategies/ddp_spawn.py | 8 ++++---- .../{executors => launchers}/__init__.py | 20 +++++++++---------- .../{executors => launchers}/base.py | 6 +++--- .../{executors => launchers}/ddp.py | 6 +++--- .../{executors => launchers}/ddp_spawn.py | 6 +++--- .../single_process.py | 6 +++--- .../{executors => launchers}/tpu_spawn.py | 6 +++--- pytorch_lightning/strategies/parallel.py | 8 ++++---- pytorch_lightning/strategies/single_device.py | 8 ++++---- pytorch_lightning/strategies/strategy.py | 4 ++-- pytorch_lightning/strategies/tpu_spawn.py | 8 ++++---- pytorch_lightning/trainer/trainer.py | 2 +- tests/strategies/test_ddp_spawn_strategy.py | 16 +++++++-------- 15 files changed, 60 insertions(+), 60 deletions(-) rename pytorch_lightning/strategies/{executors => launchers}/__init__.py (54%) rename pytorch_lightning/strategies/{executors => launchers}/base.py (85%) rename pytorch_lightning/strategies/{executors => launchers}/ddp.py (96%) rename pytorch_lightning/strategies/{executors => launchers}/ddp_spawn.py (97%) rename pytorch_lightning/strategies/{executors => launchers}/single_process.py (79%) rename pytorch_lightning/strategies/{executors => launchers}/tpu_spawn.py (94%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fb6c7c2c3fa1b..b38843f8e527e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1954,7 +1954,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.strategies.executors.ddp_spawn._FakeQueue) -> None: + def add_to_queue(self, queue: pl.strategies.launchers.ddp_spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1966,7 +1966,7 @@ def add_to_queue(self, queue: pl.strategies.executors.ddp_spawn._FakeQueue) -> N and will be removed in v1.7. """ - def get_from_queue(self, queue: pl.strategies.executors.ddp_spawn._FakeQueue) -> None: + def get_from_queue(self, queue: pl.strategies.launchers.ddp_spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. 
diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index c7ed1c1607f82..64df7d0b9c6d0 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -32,8 +32,8 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.executors.ddp import DDPSubprocessExecutor -from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor +from pytorch_lightning.strategies.launchers.ddp import DDPSubprocessLauncher +from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -136,11 +136,11 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def execute(self, trainer, function, *args, **kwargs): - executor = ( - SingleProcessExecutor if self.cluster_environment.creates_processes_externally else DDPSubprocessExecutor + def launch(self, trainer, function, *args, **kwargs): + launcher = ( + SingleProcessLauncher if self.cluster_environment.creates_processes_externally else DDPSubprocessLauncher )(self) - return executor.execute(trainer, function, *args, **kwargs) + return launcher.launch(trainer, function, *args, **kwargs) def setup_environment(self) -> None: self.setup_distributed() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 17d095edd96d5..637eac083a2af 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -26,7 +26,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor +from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -140,9 +140,9 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def execute(self, trainer, function, *args, **kwargs): - executor = DDPSpawnExecutor(self) - return executor.execute(trainer, function, *args, **kwargs) + def launch(self, trainer, function, *args, **kwargs): + launcher = DDPSpawnLauncher(self) + return launcher.launch(trainer, function, *args, **kwargs) def _worker_setup(self, process_idx: int): reset_seed() diff --git a/pytorch_lightning/strategies/executors/__init__.py b/pytorch_lightning/strategies/launchers/__init__.py similarity index 54% rename from pytorch_lightning/strategies/executors/__init__.py rename to pytorch_lightning/strategies/launchers/__init__.py index 127bd5258ef70..d67f81d237937 100644 --- a/pytorch_lightning/strategies/executors/__init__.py +++ b/pytorch_lightning/strategies/launchers/__init__.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.strategies.executors.base import Executor -from pytorch_lightning.strategies.executors.ddp import DDPSubprocessExecutor -from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor -from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor -from pytorch_lightning.strategies.executors.tpu_spawn import TPUSpawnExecutor +from pytorch_lightning.strategies.launchers.base import Launcher +from pytorch_lightning.strategies.launchers.ddp import DDPSubprocessLauncher +from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher +from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher +from pytorch_lightning.strategies.launchers.tpu_spawn import TPUSpawnLauncher __all__ = [ - "DDPSpawnExecutor", - "DDPSubprocessExecutor", - "Executor", - "SingleProcessExecutor", - "TPUSpawnExecutor", + "DDPSpawnLauncher", + "DDPSubprocessLauncher", + "Launcher", + "SingleProcessLauncher", + "TPUSpawnLauncher", ] diff --git a/pytorch_lightning/strategies/executors/base.py b/pytorch_lightning/strategies/launchers/base.py similarity index 85% rename from pytorch_lightning/strategies/executors/base.py rename to pytorch_lightning/strategies/launchers/base.py index 76fd10157c4a2..426171d92f2a8 100644 --- a/pytorch_lightning/strategies/executors/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -14,10 +14,10 @@ from abc import ABC, abstractmethod -class Executor(ABC): +class Launcher(ABC): def __init__(self, strategy): self.strategy = strategy @abstractmethod - def execute(self, trainer, fn, *args, **kwargs) -> bool: - """Executes the proceses.""" + def launch(self, trainer, fn, *args, **kwargs) -> bool: + """Launches the proceses.""" diff --git a/pytorch_lightning/strategies/executors/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py similarity index 96% rename from pytorch_lightning/strategies/executors/ddp.py rename to pytorch_lightning/strategies/launchers/ddp.py index 55ea4058f8798..cb0e3c7ef107b 100644 --- a/pytorch_lightning/strategies/executors/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -20,7 +20,7 @@ import __main__ import numpy as np -from pytorch_lightning.strategies.executors.base import Executor +from pytorch_lightning.strategies.launchers.base import Launcher from pytorch_lightning.utilities import _HYDRA_AVAILABLE if _HYDRA_AVAILABLE: @@ -30,12 +30,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException -class DDPSubprocessExecutor(Executor): +class DDPSubprocessLauncher(Launcher): def __init__(self, strategy): super().__init__(strategy=strategy) self.interactive_ddp_procs = [] - def execute(self, trainer, function, *args, **kwargs): + def launch(self, trainer, function, *args, **kwargs): self._call_children_scripts() return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/executors/ddp_spawn.py b/pytorch_lightning/strategies/launchers/ddp_spawn.py similarity index 97% rename from pytorch_lightning/strategies/executors/ddp_spawn.py rename to pytorch_lightning/strategies/launchers/ddp_spawn.py index f93f2924a3d80..e567c261f6d4d 100644 --- a/pytorch_lightning/strategies/executors/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/ddp_spawn.py @@ -20,7 +20,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies.executors.base import Executor +from 
pytorch_lightning.strategies.launchers.base import Launcher from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import rank_zero_debug @@ -28,8 +28,8 @@ from pytorch_lightning.utilities.types import _PATH -class DDPSpawnExecutor(Executor): - def execute(self, trainer, function, *args, **kwargs): +class DDPSpawnLauncher(Launcher): + def launch(self, trainer, function, *args, **kwargs): os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() diff --git a/pytorch_lightning/strategies/executors/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py similarity index 79% rename from pytorch_lightning/strategies/executors/single_process.py rename to pytorch_lightning/strategies/launchers/single_process.py index 4f391b3a71f11..65cbfdac51c4c 100644 --- a/pytorch_lightning/strategies/executors/single_process.py +++ b/pytorch_lightning/strategies/launchers/single_process.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.strategies.executors.base import Executor +from pytorch_lightning.strategies.launchers.base import Launcher -class SingleProcessExecutor(Executor): - def execute(self, trainer, function, *args, **kwargs): +class SingleProcessLauncher(Launcher): + def launch(self, trainer, function, *args, **kwargs): return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/executors/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py similarity index 94% rename from pytorch_lightning/strategies/executors/tpu_spawn.py rename to pytorch_lightning/strategies/launchers/tpu_spawn.py index ea3588b2d84de..66f4b6ae3f06e 100644 --- a/pytorch_lightning/strategies/executors/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -19,7 +19,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies.executors.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnExecutor +from pytorch_lightning.strategies.launchers.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device @@ -32,8 +32,8 @@ xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 -class TPUSpawnExecutor(DDPSpawnExecutor): - def execute(self, trainer, function, *args, **kwargs): +class TPUSpawnLauncher(DDPSpawnLauncher): + def launch(self, trainer, function, *args, **kwargs): context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 3fd2e9d2533a9..74a24fc78c87f 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from 
pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor +from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -126,9 +126,9 @@ def block_backward_sync(self): else: yield None - def execute(self, trainer, function, *args, **kwargs): - executor = SingleProcessExecutor(self) - return executor.execute(trainer, function, *args, **kwargs) + def launch(self, trainer, function, *args, **kwargs): + launcher = SingleProcessLauncher(self) + return launcher.launch(trainer, function, *args, **kwargs) def teardown(self) -> None: self.cluster_environment.teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index 5b0363e7b40f3..b8ae46ffd956f 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.executors.single_process import SingleProcessExecutor +from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.types import _DEVICE @@ -41,9 +41,9 @@ def __init__( self.local_rank = 0 self.world_size = 1 - def execute(self, trainer, function, *args, **kwargs): - executor = SingleProcessExecutor(self) - return executor.execute(trainer, function, *args, **kwargs) + def launch(self, trainer, function, *args, **kwargs): + launcher = SingleProcessLauncher(self) + return launcher.launch(trainer, function, *args, **kwargs) def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. 
As this plugin only diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 0ca635df34924..045fed34b98de 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -140,8 +140,8 @@ def setup_precision_plugin(self) -> None: self.lr_schedulers = schedulers @abstractmethod - def execute(self, trainer, function, *args, **kwargs): - """Executes the proceses using an Executor.""" + def launch(self, trainer, function, *args, **kwargs): + """Launch the proceses using a Launcher.""" def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the appropriate device if needed.""" diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 506605ee440a5..05219ca677839 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy -from pytorch_lightning.strategies.executors.tpu_spawn import TPUSpawnExecutor +from pytorch_lightning.strategies.launchers.tpu_spawn import TPUSpawnLauncher from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.data import has_len @@ -212,9 +212,9 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def execute(self, trainer, fn, *args, **kwargs): - executor = TPUSpawnExecutor(self) - return executor.execute(trainer, fn, *args, **kwargs) + def launch(self, trainer, fn, *args, **kwargs): + launcher = TPUSpawnLauncher(self) + return launcher.launch(trainer, fn, *args, **kwargs) def _worker_setup(self, process_idx: int): reset_seed() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 19139dccf5ec5..952b8870da47c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -668,7 +668,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - return self.strategy.execute(self, trainer_fn, *args, **kwargs) + return self.strategy.launch(self, trainer_fn, *args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index abf652362db90..3bd953b3a1d7c 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -20,7 +20,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.strategies import DDPSpawnStrategy -from pytorch_lightning.strategies.executors.ddp_spawn import DDPSpawnExecutor +from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -83,7 +83,7 @@ def test_ddp_spawn_extra_parameters(tmpdir): assert model.test_val == 
"test_val" -class CustomDDPSpawnExecutor(DDPSpawnExecutor): +class CustomDDPSpawnLauncher(DDPSpawnLauncher): def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) @@ -94,9 +94,9 @@ def get_from_queue(self, trainer: Trainer, queue) -> None: class TestDDPSpawnStrategy(DDPSpawnStrategy): - def execute(self, trainer, function, *args, **kwargs): - executor = CustomDDPSpawnExecutor(self) - return executor.execute(trainer, function, *args, **kwargs) + def launch(self, trainer, function, *args, **kwargs): + launcher = CustomDDPSpawnLauncher(self) + return launcher.launch(trainer, function, *args, **kwargs) @RunIf(skip_windows=True, skip_49370=True) @@ -155,14 +155,14 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPSpawnStrategy() - executor = DDPSpawnExecutor(strategy) + launcher = DDPSpawnLauncher(strategy) strategy.model = model trainer = Trainer(default_root_dir=tmpdir) trainer.state.fn = trainer_fn # pretend we are in a particular trainer state temp_file = Path(tmpdir, ".temp.ckpt") assert not temp_file.exists() - spawn_output = executor._collect_rank_zero_results(trainer, {}) + spawn_output = launcher._collect_rank_zero_results(trainer, {}) model.state_dict.assert_called_once() if trainer_fn == TrainerFn.FITTING: @@ -173,6 +173,6 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): assert not temp_file.exists() # <-- here would normally be the multiprocessing boundary - executor._recover_results_in_main_process(spawn_output, trainer) + launcher._recover_results_in_main_process(spawn_output, trainer) assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) assert not temp_file.exists() From 184b22224a9bf8164cd6eeca90d1ff21961ce0e3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 18:52:09 +0530 Subject: [PATCH 10/36] reduce arguments --- pytorch_lightning/strategies/ddp.py | 13 +++--- pytorch_lightning/strategies/ddp_spawn.py | 6 +-- .../strategies/launchers/base.py | 5 +-- pytorch_lightning/strategies/launchers/ddp.py | 35 ++++++--------- .../strategies/launchers/ddp_spawn.py | 44 ++++++++++++------- .../strategies/launchers/single_process.py | 3 +- .../strategies/launchers/tpu_spawn.py | 40 +++++++++++------ pytorch_lightning/strategies/parallel.py | 6 +-- pytorch_lightning/strategies/single_device.py | 6 +-- pytorch_lightning/strategies/strategy.py | 2 +- pytorch_lightning/strategies/tpu_spawn.py | 6 +-- pytorch_lightning/trainer/trainer.py | 3 +- tests/strategies/test_ddp_spawn_strategy.py | 14 +++--- 13 files changed, 100 insertions(+), 83 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 64df7d0b9c6d0..9c031dd3ed54b 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -136,13 +136,16 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def launch(self, trainer, function, *args, **kwargs): - launcher = ( - SingleProcessLauncher if self.cluster_environment.creates_processes_externally else DDPSubprocessLauncher - )(self) - return launcher.launch(trainer, function, *args, **kwargs) + def launch(self, function, *args, **kwargs): + if self.cluster_environment.creates_processes_externally: + launcher = SingleProcessLauncher() + else: + launcher = DDPSubprocessLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + return 
launcher.launch(function, *args, **kwargs) def setup_environment(self) -> None: + self._rank_0_has_called_call_children_scripts = not self.cluster_environment.creates_processes_externally self.setup_distributed() super().setup_environment() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 637eac083a2af..28d7aef21a136 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -140,9 +140,9 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def launch(self, trainer, function, *args, **kwargs): - launcher = DDPSpawnLauncher(self) - return launcher.launch(trainer, function, *args, **kwargs) + def launch(self, function, *args, **kwargs): + launcher = DDPSpawnLauncher() + return launcher.launch(function, *args, **kwargs) def _worker_setup(self, process_idx: int): reset_seed() diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 426171d92f2a8..4e8073aeb55b8 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -15,9 +15,6 @@ class Launcher(ABC): - def __init__(self, strategy): - self.strategy = strategy - @abstractmethod - def launch(self, trainer, fn, *args, **kwargs) -> bool: + def launch(self, fn, *args, **kwargs) -> bool: """Launches the proceses.""" diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py index cb0e3c7ef107b..7b919e0b7ea35 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -27,15 +27,17 @@ from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path -from pytorch_lightning.utilities.exceptions import MisconfigurationException - class DDPSubprocessLauncher(Launcher): - def __init__(self, strategy): - super().__init__(strategy=strategy) + def __init__(self, cluster_environment, num_processes, num_nodes): + super().__init__() + self.cluster_environment = cluster_environment + self.num_processes = num_processes + self.num_nodes = num_nodes self.interactive_ddp_procs = [] - def launch(self, trainer, function, *args, **kwargs): + def launch(self, function, *args, **kwargs): + kwargs.pop("trainer") self._call_children_scripts() return function(*args, **kwargs) @@ -44,12 +46,12 @@ def _call_children_scripts(self): self._check_can_spawn_children() # DDP Environment variables - os.environ["MASTER_ADDR"] = self.strategy.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) # allow the user to pass the node rank - os.environ["NODE_RANK"] = str(self.strategy.cluster_environment.node_rank()) - os.environ["LOCAL_RANK"] = str(self.strategy.cluster_environment.local_rank()) + os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) + os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` # See https://docs.python.org/3/reference/import.html#main-spec @@ -71,18 +73,11 @@ def _call_children_scripts(self): else: # Script called as `python -m a.b.c` command = [sys.executable, 
"-m", __main__.__spec__.name] + sys.argv[1:] - # the visible devices tell us how many GPUs we want to use. - # when the trainer script was called the device has already been scoped by the time - # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone - # but forward the GPUs selected via environment variables - if self.strategy.parallel_devices is None: - raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)") - - os.environ["WORLD_SIZE"] = f"{self.strategy.num_processes * self.strategy.num_nodes}" + os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" self.interactive_ddp_procs = [] - for local_rank in range(1, self.strategy.num_processes): + for local_rank in range(1, self.num_processes): env_copy = os.environ.copy() env_copy["LOCAL_RANK"] = f"{local_rank}" @@ -106,10 +101,8 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) - self.strategy._rank_0_has_called_call_children_scripts = True - def _check_can_spawn_children(self): - if self.strategy.local_rank != 0: + if self.cluster_environment.local_rank() != 0: raise RuntimeError( "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," diff --git a/pytorch_lightning/strategies/launchers/ddp_spawn.py b/pytorch_lightning/strategies/launchers/ddp_spawn.py index e567c261f6d4d..01ee443b20745 100644 --- a/pytorch_lightning/strategies/launchers/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/ddp_spawn.py @@ -13,7 +13,8 @@ # limitations under the License. import os from collections import UserList -from typing import Any, NamedTuple, Optional +from multiprocessing.queues import SimpleQueue +from typing import Any, Callable, NamedTuple, Optional import numpy as np import torch @@ -29,24 +30,33 @@ class DDPSpawnLauncher(Launcher): - def launch(self, trainer, function, *args, **kwargs): - os.environ["MASTER_PORT"] = str(self.strategy.cluster_environment.main_port) + def launch(self, function, *args, **kwargs): + trainer = kwargs.pop("trainer") + os.environ["MASTER_PORT"] = str(trainer.strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() mp.spawn( self._wrapped_function, args=(trainer, function, args, kwargs, return_queue), - nprocs=self.strategy.num_processes, + nprocs=trainer.strategy.num_processes, ) spawn_output = return_queue.get() self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results - def _wrapped_function(self, process_idx, trainer, function, args, kwargs, return_queue): - self.strategy._worker_setup(process_idx) + def _wrapped_function( + self, + process_idx: int, + trainer: "pl.Trainer", + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, + ): + trainer.strategy._worker_setup(process_idx) results = function(*args, **kwargs) results = self._collect_rank_zero_results(trainer, results) - if self.strategy.local_rank == 0: + if trainer.strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: @@ -57,19 +67,19 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer # TODO: pass also best score # load last weights if spawn_output.weights_path is not None: - ckpt = 
self.strategy.checkpoint_io.load_checkpoint( + ckpt = trainer.strategy.checkpoint_io.load_checkpoint( spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) - self.strategy.lightning_module.load_state_dict(ckpt) - self.strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) + trainer.lightning_module.load_state_dict(ckpt) + trainer.strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) trainer.state = spawn_output.trainer_state # get the `callback_metrics` and set it to the trainer - if is_overridden("get_from_queue", self.strategy.lightning_module): + if is_overridden("get_from_queue", trainer.lightning_module): # only in case the user does not override it. # TODO: Remove the if in v1.7 - self.strategy.lightning_module.get_from_queue(spawn_output.extra) + trainer.lightning_module.get_from_queue(spawn_output.extra) self.get_from_queue(trainer, spawn_output.extra) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: @@ -78,22 +88,22 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.strategy.lightning_module.state_dict() + state_dict = trainer.lightning_module.state_dict() - if self.strategy.global_rank != 0: + if trainer.strategy.global_rank != 0: return # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + trainer.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) # adds the `callback_metrics` to the queue extra = _FakeQueue() - if is_overridden("add_to_queue", self.strategy.lightning_module): + if is_overridden("add_to_queue", trainer.lightning_module): # TODO: Remove the if in v1.7 - self.strategy.lightning_module.add_to_queue(extra) + trainer.lightning_module.add_to_queue(extra) self.add_to_queue(trainer, extra) return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) diff --git a/pytorch_lightning/strategies/launchers/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py index 65cbfdac51c4c..2914cc7302913 100644 --- a/pytorch_lightning/strategies/launchers/single_process.py +++ b/pytorch_lightning/strategies/launchers/single_process.py @@ -15,5 +15,6 @@ class SingleProcessLauncher(Launcher): - def launch(self, trainer, function, *args, **kwargs): + def launch(self, function, *args, **kwargs): + kwargs.pop("trainer") return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index 66f4b6ae3f06e..85d613f0f6f0e 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -33,26 +33,38 @@ class TPUSpawnLauncher(DDPSpawnLauncher): - def launch(self, trainer, function, *args, **kwargs): + def launch(self, function, *args, **kwargs): + trainer = kwargs.pop("trainer") context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() - xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) + xmp.spawn( + self._wrapped_function, + args=(trainer, function, args, kwargs, return_queue), + 
**trainer.strategy.get_mp_spawn_kwargs() + ) return return_queue.get() def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue + self, + process_idx: int, + trainer: "pl.Trainer", + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, ) -> None: self._worker_setup(process_idx) - result = function(*args, **kwargs) - if self.local_rank == 0: - return_queue.put(move_data_to_device(result, "cpu")) + results = function(*args, **kwargs) + results = self._collect_rank_zero_results(trainer, results) + if trainer.strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.barrier("end-process") + trainer.strategy.barrier("end-process") # Ensure that the rank 0 process is the one exiting last # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if self.local_rank == 0: + if trainer.strategy.local_rank == 0: time.sleep(2) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: @@ -61,23 +73,23 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.strategy.lightning_module.state_dict() + state_dict = trainer.lightning_module.state_dict() # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - self.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + trainer.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if self.strategy.local_rank != 0: + if trainer.strategy.local_rank != 0: return # adds the `callback_metrics` to the queue extra = _FakeQueue() - if is_overridden("add_to_queue", self.strategy.lightning_module): + if is_overridden("add_to_queue", trainer.lightning_module): # TODO: Remove the if in v1.7 - self.strategy.lightning_module.add_to_queue(extra) - self.strategy.add_to_queue(trainer, extra) + trainer.lightning_module.add_to_queue(extra) + self.add_to_queue(trainer, extra) return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 74a24fc78c87f..1025ae62b9e78 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -126,9 +126,9 @@ def block_backward_sync(self): else: yield None - def launch(self, trainer, function, *args, **kwargs): - launcher = SingleProcessLauncher(self) - return launcher.launch(trainer, function, *args, **kwargs) + def launch(self, function, *args, **kwargs): + launcher = SingleProcessLauncher() + return launcher.launch(function, *args, **kwargs) def teardown(self) -> None: self.cluster_environment.teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index b8ae46ffd956f..c053ea453ecd7 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -41,9 +41,9 @@ def __init__( self.local_rank = 0 self.world_size = 1 - def launch(self, trainer, function, *args, **kwargs): - launcher = SingleProcessLauncher(self) - 
return launcher.launch(trainer, function, *args, **kwargs) + def launch(self, function, *args, **kwargs): + launcher = SingleProcessLauncher() + return launcher.launch(function, *args, **kwargs) def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 045fed34b98de..5084d319595c7 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -140,7 +140,7 @@ def setup_precision_plugin(self) -> None: self.lr_schedulers = schedulers @abstractmethod - def launch(self, trainer, function, *args, **kwargs): + def launch(self, function, *args, **kwargs): """Launch the proceses using a Launcher.""" def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 05219ca677839..79165aa12d802 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -212,9 +212,9 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def launch(self, trainer, fn, *args, **kwargs): - launcher = TPUSpawnLauncher(self) - return launcher.launch(trainer, fn, *args, **kwargs) + def launch(self, fn, *args, **kwargs): + launcher = TPUSpawnLauncher() + return launcher.launch(fn, *args, **kwargs) def _worker_setup(self, process_idx: int): reset_seed() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 952b8870da47c..8fc7b1e53baf4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -668,7 +668,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - return self.strategy.launch(self, trainer_fn, *args, **kwargs) + kwargs["trainer"] = self + return self.strategy.launch(trainer_fn, *args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 3bd953b3a1d7c..3da24ea1a7c88 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -89,14 +89,14 @@ def add_to_queue(self, trainer, queue) -> None: return super().add_to_queue(trainer, queue) def get_from_queue(self, trainer: Trainer, queue) -> None: - self.strategy.new_test_val = queue.get() + trainer.strategy.new_test_val = queue.get() return super().get_from_queue(trainer, queue) class TestDDPSpawnStrategy(DDPSpawnStrategy): - def launch(self, trainer, function, *args, **kwargs): - launcher = CustomDDPSpawnLauncher(self) - return launcher.launch(trainer, function, *args, **kwargs) + def launch(self, function, *args, **kwargs): + launcher = CustomDDPSpawnLauncher() + return launcher.launch(function, *args, **kwargs) @RunIf(skip_windows=True, skip_49370=True) @@ -155,9 +155,9 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPSpawnStrategy() - launcher = DDPSpawnLauncher(strategy) - strategy.model = model 
- trainer = Trainer(default_root_dir=tmpdir) + launcher = DDPSpawnLauncher() + trainer = Trainer(default_root_dir=tmpdir, strategy=strategy) + trainer.strategy.connect(model) trainer.state.fn = trainer_fn # pretend we are in a particular trainer state temp_file = Path(tmpdir, ".temp.ckpt") From a2f8d0a4472e0f6f7ad23857ff615497d9616455 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 21:01:45 +0530 Subject: [PATCH 11/36] fix tpu and mypy --- pytorch_lightning/strategies/launchers/ddp.py | 11 ++++++----- pytorch_lightning/strategies/launchers/ddp_spawn.py | 4 ++-- .../strategies/launchers/single_process.py | 4 +++- pytorch_lightning/strategies/launchers/tpu_spawn.py | 4 ++-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py index 7b919e0b7ea35..4c1ff1017c0f7 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -15,11 +15,12 @@ import subprocess import sys from time import sleep -from typing import Optional +from typing import Any, Callable, Optional import __main__ import numpy as np +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.strategies.launchers.base import Launcher from pytorch_lightning.utilities import _HYDRA_AVAILABLE @@ -29,19 +30,19 @@ class DDPSubprocessLauncher(Launcher): - def __init__(self, cluster_environment, num_processes, num_nodes): + def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: super().__init__() self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes self.interactive_ddp_procs = [] - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer") self._call_children_scripts() return function(*args, **kwargs) - def _call_children_scripts(self): + def _call_children_scripts(self) -> None: # bookkeeping of spawned processes self._check_can_spawn_children() @@ -101,7 +102,7 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) - def _check_can_spawn_children(self): + def _check_can_spawn_children(self) -> None: if self.cluster_environment.local_rank() != 0: raise RuntimeError( "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." 
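The launcher diffs in this commit also settle the calling convention introduced by the preceding "reduce arguments" change: the ``Trainer`` injects itself into ``kwargs`` before calling ``Strategy.launch``, and every launcher pops that reference back out (the spawn-based launchers use it to reach the strategy and collect results, while the subprocess and single-process launchers simply discard it). A minimal sketch of that flow, using simplified stand-ins rather than the real Lightning classes:

    from abc import ABC, abstractmethod
    from typing import Any, Callable


    class Launcher(ABC):
        @abstractmethod
        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            """Launches the processes."""


    class SingleProcessLauncher(Launcher):
        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            kwargs.pop("trainer")  # a single process does not need the reference
            return function(*args, **kwargs)


    class Strategy:
        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            return SingleProcessLauncher().launch(function, *args, **kwargs)


    class Trainer:
        def __init__(self) -> None:
            self.strategy = Strategy()

        def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
            kwargs["trainer"] = self  # spawn-based launchers pop this back out
            return self.strategy.launch(trainer_fn, *args, **kwargs)


    trainer = Trainer()
    assert trainer._call_and_handle_interrupt(lambda x: x + 1, 1) == 2

In the real classes, the spawn launchers keep the popped ``trainer`` and forward it to ``_wrapped_function`` so the worker can reach ``trainer.strategy``, as the surrounding diffs show.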
diff --git a/pytorch_lightning/strategies/launchers/ddp_spawn.py b/pytorch_lightning/strategies/launchers/ddp_spawn.py index 01ee443b20745..fcc8ee4e5ac89 100644 --- a/pytorch_lightning/strategies/launchers/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/ddp_spawn.py @@ -30,7 +30,7 @@ class DDPSpawnLauncher(Launcher): - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer") os.environ["MASTER_PORT"] = str(trainer.strategy.cluster_environment.main_port) context = mp.get_context("spawn") @@ -52,7 +52,7 @@ def _wrapped_function( args: Any, kwargs: Any, return_queue: SimpleQueue, - ): + ) -> None: trainer.strategy._worker_setup(process_idx) results = function(*args, **kwargs) results = self._collect_rank_zero_results(trainer, results) diff --git a/pytorch_lightning/strategies/launchers/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py index 2914cc7302913..5294359aac42f 100644 --- a/pytorch_lightning/strategies/launchers/single_process.py +++ b/pytorch_lightning/strategies/launchers/single_process.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable + from pytorch_lightning.strategies.launchers.base import Launcher class SingleProcessLauncher(Launcher): - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer") return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index 85d613f0f6f0e..a614ed1fbce42 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -33,9 +33,9 @@ class TPUSpawnLauncher(DDPSpawnLauncher): - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer") - context = mp.get_context(self.start_method or "fork") + context = mp.get_context(trainer.strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( self._wrapped_function, From 72b1bb9250b6b0a1214f4e74ad4cc035d2f17db4 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 23:20:31 +0530 Subject: [PATCH 12/36] fix tpu and mypy --- pytorch_lightning/strategies/ddp.py | 4 ++-- pytorch_lightning/strategies/ddp_spawn.py | 4 ++-- pytorch_lightning/strategies/launchers/base.py | 3 ++- pytorch_lightning/strategies/launchers/tpu_spawn.py | 2 +- pytorch_lightning/strategies/parallel.py | 4 ++-- pytorch_lightning/strategies/single_device.py | 4 ++-- pytorch_lightning/strategies/strategy.py | 2 +- pytorch_lightning/strategies/tpu_spawn.py | 4 ++-- 8 files changed, 14 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 9c031dd3ed54b..28f33e2242f96 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -18,7 +18,7 @@ import tempfile import time from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.distributed @@ -136,7 +136,7 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def launch(self, 
function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: if self.cluster_environment.creates_processes_externally: launcher = SingleProcessLauncher() else: diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 28d7aef21a136..08b8aec0d970d 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.distributed @@ -140,7 +140,7 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: launcher = DDPSpawnLauncher() return launcher.launch(function, *args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 4e8073aeb55b8..8227226c8c425 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from typing import Any, Callable class Launcher(ABC): @abstractmethod - def launch(self, fn, *args, **kwargs) -> bool: + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """Launches the proceses.""" diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index a614ed1fbce42..5272bbb1bd402 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -53,7 +53,7 @@ def _wrapped_function( kwargs: Any, return_queue: SimpleQueue, ) -> None: - self._worker_setup(process_idx) + self.trainer.strategy._worker_setup(process_idx) results = function(*args, **kwargs) results = self._collect_rank_zero_results(trainer, results) if trainer.strategy.local_rank == 0: diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 1025ae62b9e78..2e152d4495e42 100644 --- a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -14,7 +14,7 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -126,7 +126,7 @@ def block_backward_sync(self): else: yield None - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: launcher = SingleProcessLauncher() return launcher.launch(function, *args, **kwargs) diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index c053ea453ecd7..95cedf52c838a 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any +from typing import Any, Callable import torch @@ -41,7 +41,7 @@ def __init__( self.local_rank = 0 self.world_size = 1 - def launch(self, function, *args, **kwargs): + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: launcher = SingleProcessLauncher() return launcher.launch(function, *args, **kwargs) diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 5084d319595c7..71c7702c1f0d8 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -140,7 +140,7 @@ def setup_precision_plugin(self) -> None: self.lr_schedulers = schedulers @abstractmethod - def launch(self, function, *args, **kwargs): + def launch(self, function: Any, *args: Any, **kwargs: Any) -> Any: """Launch the proceses using a Launcher.""" def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 79165aa12d802..6d6e4a4537845 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -212,9 +212,9 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def launch(self, fn, *args, **kwargs): + def launch(self, function: Any, *args: Any, **kwargs: Any) -> Any: launcher = TPUSpawnLauncher() - return launcher.launch(fn, *args, **kwargs) + return launcher.launch(function, *args, **kwargs) def _worker_setup(self, process_idx: int): reset_seed() From e7bb90b475949e60065a3e70cab6f8acae4976cb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Feb 2022 23:43:48 +0530 Subject: [PATCH 13/36] fix tpu and mypy --- pytorch_lightning/strategies/launchers/ddp.py | 5 +++-- pytorch_lightning/strategies/launchers/tpu_spawn.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py index 4c1ff1017c0f7..951ff42549ccd 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -14,8 +14,9 @@ import os import subprocess import sys +from subprocess import Popen from time import sleep -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import __main__ import numpy as np @@ -35,7 +36,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.interactive_ddp_procs = [] + self.interactive_ddp_procs: List[Popen] = [] def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer") diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index 5272bbb1bd402..ca40b6dabb84c 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -53,7 +53,7 @@ def _wrapped_function( kwargs: Any, return_queue: SimpleQueue, ) -> None: - self.trainer.strategy._worker_setup(process_idx) + trainer.strategy._worker_setup(process_idx) results = function(*args, **kwargs) results = self._collect_rank_zero_results(trainer, results) if trainer.strategy.local_rank == 0: From 9857ffe1346e6fd181b6e777b0519debeb8d0cd3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 3 Feb 2022 
18:22:59 +0530 Subject: [PATCH 14/36] recover results for TPU --- pytorch_lightning/strategies/launchers/tpu_spawn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index ca40b6dabb84c..69c4cb2f66d3c 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -42,7 +42,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: args=(trainer, function, args, kwargs, return_queue), **trainer.strategy.get_mp_spawn_kwargs() ) - return return_queue.get() + spawn_output = return_queue.get() + self._recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results def _wrapped_function( self, From d2817f36d51a714bfbf0e000616abfb5510a1109 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 3 Feb 2022 10:51:10 -0500 Subject: [PATCH 15/36] lite patch --- pytorch_lightning/lite/lite.py | 13 ++++++------- pytorch_lightning/strategies/launchers/ddp.py | 2 +- pytorch_lightning/strategies/launchers/ddp_spawn.py | 5 ++++- .../strategies/launchers/single_process.py | 2 +- pytorch_lightning/strategies/launchers/tpu_spawn.py | 5 ++++- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index fb7cd80e61909..66807481dc63a 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -27,7 +27,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import PLUGIN_INPUT -from pytorch_lightning.strategies import DDPSpawnStrategy, DeepSpeedStrategy, Strategy, TPUSpawnStrategy +from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device @@ -399,15 +399,14 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) return seed_everything(seed=seed, workers=workers) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - self._strategy.setup_environment() - # apply sharded context to prevent OOM run_method = partial(self._run_with_sharded_context, run_method) + run_method = partial(self._wrapped_run_method, run_method) + return self._strategy.launch(run_method, *args, **kwargs) - if isinstance(self._strategy, DDPSpawnStrategy): - return self._strategy.spawn(run_method, *args, **kwargs) - else: - return run_method(*args, **kwargs) + def _wrapped_run_method(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + self._strategy.setup_environment() + return run_method(*args, **kwargs) def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py index 951ff42549ccd..0057f5ad2de18 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -39,7 +39,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.interactive_ddp_procs: List[Popen] 
= [] def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - kwargs.pop("trainer") + kwargs.pop("trainer", None) self._call_children_scripts() return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/ddp_spawn.py b/pytorch_lightning/strategies/launchers/ddp_spawn.py index fcc8ee4e5ac89..160b21eabfe0b 100644 --- a/pytorch_lightning/strategies/launchers/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/ddp_spawn.py @@ -31,7 +31,7 @@ class DDPSpawnLauncher(Launcher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - trainer = kwargs.pop("trainer") + trainer = kwargs.pop("trainer", None) os.environ["MASTER_PORT"] = str(trainer.strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() @@ -41,6 +41,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: nprocs=trainer.strategy.num_processes, ) spawn_output = return_queue.get() + if trainer is None: + return spawn_output + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results diff --git a/pytorch_lightning/strategies/launchers/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py index 5294359aac42f..428073c5e7cb3 100644 --- a/pytorch_lightning/strategies/launchers/single_process.py +++ b/pytorch_lightning/strategies/launchers/single_process.py @@ -18,5 +18,5 @@ class SingleProcessLauncher(Launcher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - kwargs.pop("trainer") + kwargs.pop("trainer", None) return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index 69c4cb2f66d3c..f0ec87db6f7d7 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -34,7 +34,7 @@ class TPUSpawnLauncher(DDPSpawnLauncher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - trainer = kwargs.pop("trainer") + trainer = kwargs.pop("trainer", None) context = mp.get_context(trainer.strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( @@ -43,6 +43,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: **trainer.strategy.get_mp_spawn_kwargs() ) spawn_output = return_queue.get() + if trainer is None: + return spawn_output + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results From b6bd4d8a3a1bc488d7992285946ba58515c63999 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 3 Feb 2022 13:02:57 -0500 Subject: [PATCH 16/36] fix deadlock detection --- pytorch_lightning/strategies/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 28f33e2242f96..65d0145375078 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -141,11 +141,11 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: launcher = SingleProcessLauncher() else: launcher = DDPSubprocessLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + self._rank_0_has_called_call_children_scripts = True return launcher.launch(function, *args, **kwargs) def setup_environment(self) -> None: - self._rank_0_has_called_call_children_scripts = not self.cluster_environment.creates_processes_externally self.setup_distributed() 
super().setup_environment() From ce49e156b350a3a6ed0d8a0216b9c97cdf3e2a50 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 3 Feb 2022 14:30:52 -0500 Subject: [PATCH 17/36] lite spawn patch --- pytorch_lightning/lite/lite.py | 1 + pytorch_lightning/strategies/launchers/ddp.py | 1 + .../strategies/launchers/ddp_spawn.py | 20 ++++++++++------ .../strategies/launchers/single_process.py | 1 + .../strategies/launchers/tpu_spawn.py | 24 ++++++++++++------- pytorch_lightning/trainer/trainer.py | 1 + 6 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 66807481dc63a..4490ed0f536bc 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -402,6 +402,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: # apply sharded context to prevent OOM run_method = partial(self._run_with_sharded_context, run_method) run_method = partial(self._wrapped_run_method, run_method) + kwargs["strategy"] = self._strategy return self._strategy.launch(run_method, *args, **kwargs) def _wrapped_run_method(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/ddp.py index 0057f5ad2de18..c47dc3eafb7a7 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/ddp.py @@ -40,6 +40,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer", None) + kwargs.pop("strategy") self._call_children_scripts() return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/ddp_spawn.py b/pytorch_lightning/strategies/launchers/ddp_spawn.py index 160b21eabfe0b..9f1dfd48a2963 100644 --- a/pytorch_lightning/strategies/launchers/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/ddp_spawn.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.strategies.launchers.base import Launcher +from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import rank_zero_debug @@ -32,13 +33,14 @@ class DDPSpawnLauncher(Launcher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer", None) - os.environ["MASTER_PORT"] = str(trainer.strategy.cluster_environment.main_port) + strategy = kwargs.pop("strategy") + os.environ["MASTER_PORT"] = str(strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() mp.spawn( self._wrapped_function, - args=(trainer, function, args, kwargs, return_queue), - nprocs=trainer.strategy.num_processes, + args=(strategy, trainer, function, args, kwargs, return_queue), + nprocs=strategy.num_processes, ) spawn_output = return_queue.get() if trainer is None: @@ -50,16 +52,20 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: def _wrapped_function( self, process_idx: int, - trainer: "pl.Trainer", + strategy: Strategy, + trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue, ) -> None: - trainer.strategy._worker_setup(process_idx) + strategy._worker_setup(process_idx) results = function(*args, **kwargs) - results = 
self._collect_rank_zero_results(trainer, results) - if trainer.strategy.local_rank == 0: + + if trainer is not None: + results = self._collect_rank_zero_results(trainer, results) + + if strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: diff --git a/pytorch_lightning/strategies/launchers/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py index 428073c5e7cb3..a20243c8e085f 100644 --- a/pytorch_lightning/strategies/launchers/single_process.py +++ b/pytorch_lightning/strategies/launchers/single_process.py @@ -19,4 +19,5 @@ class SingleProcessLauncher(Launcher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer", None) + kwargs.pop("strategy") return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/tpu_spawn.py index f0ec87db6f7d7..956e70867eddd 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/tpu_spawn.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.strategies.launchers.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnLauncher +from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device @@ -35,12 +36,13 @@ class TPUSpawnLauncher(DDPSpawnLauncher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer", None) - context = mp.get_context(trainer.strategy.start_method or "fork") + strategy = kwargs.pop("strategy") + context = mp.get_context(strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( self._wrapped_function, - args=(trainer, function, args, kwargs, return_queue), - **trainer.strategy.get_mp_spawn_kwargs() + args=(strategy, trainer, function, args, kwargs, return_queue), + **strategy.get_mp_spawn_kwargs() ) spawn_output = return_queue.get() if trainer is None: @@ -52,24 +54,28 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: def _wrapped_function( self, process_idx: int, - trainer: "pl.Trainer", + strategy: Strategy, + trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue, ) -> None: - trainer.strategy._worker_setup(process_idx) + strategy._worker_setup(process_idx) results = function(*args, **kwargs) - results = self._collect_rank_zero_results(trainer, results) - if trainer.strategy.local_rank == 0: + + if trainer is not None: + results = self._collect_rank_zero_results(trainer, results) + + if strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - trainer.strategy.barrier("end-process") + strategy.barrier("end-process") # Ensure that the rank 0 process is the one exiting last # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if trainer.strategy.local_rank == 0: + if strategy.local_rank == 0: time.sleep(2) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8fc7b1e53baf4..1a18a46023a34 100644 --- a/pytorch_lightning/trainer/trainer.py +++ 
b/pytorch_lightning/trainer/trainer.py @@ -669,6 +669,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: kwargs["trainer"] = self + kwargs["strategy"] = self.strategy return self.strategy.launch(trainer_fn, *args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: From 2314253328edba34229d5bd66b723d24b84841eb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 4 Feb 2022 07:13:47 -0500 Subject: [PATCH 18/36] constructor initialize --- pytorch_lightning/core/lightning.py | 4 +-- pytorch_lightning/lite/lite.py | 7 ++++-- pytorch_lightning/strategies/ddp.py | 23 +++++++---------- pytorch_lightning/strategies/ddp_spawn.py | 9 +++---- .../strategies/launchers/__init__.py | 14 +++++------ .../launchers/{ddp.py => multi_process.py} | 3 +-- .../strategies/launchers/single_process.py | 23 ----------------- .../launchers/{ddp_spawn.py => spawn.py} | 25 ++++++++++--------- .../launchers/{tpu_spawn.py => xla_spawn.py} | 25 ++++++++----------- pytorch_lightning/strategies/parallel.py | 7 +----- pytorch_lightning/strategies/single_device.py | 7 +----- pytorch_lightning/strategies/strategy.py | 10 +++++--- pytorch_lightning/strategies/tpu_spawn.py | 7 ++---- .../connectors/accelerator_connector.py | 3 +++ pytorch_lightning/trainer/trainer.py | 8 +++--- tests/strategies/test_ddp_spawn_strategy.py | 15 ++++++----- 16 files changed, 75 insertions(+), 115 deletions(-) rename pytorch_lightning/strategies/launchers/{ddp.py => multi_process.py} (98%) delete mode 100644 pytorch_lightning/strategies/launchers/single_process.py rename pytorch_lightning/strategies/launchers/{ddp_spawn.py => spawn.py} (89%) rename pytorch_lightning/strategies/launchers/{tpu_spawn.py => xla_spawn.py} (82%) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b38843f8e527e..75ed636c7be7b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1954,7 +1954,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.strategies.launchers.ddp_spawn._FakeQueue) -> None: + def add_to_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1966,7 +1966,7 @@ def add_to_queue(self, queue: pl.strategies.launchers.ddp_spawn._FakeQueue) -> N and will be removed in v1.7. """ - def get_from_queue(self, queue: pl.strategies.launchers.ddp_spawn._FakeQueue) -> None: + def get_from_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. 
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 4490ed0f536bc..c7351d92af2cb 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -402,8 +402,11 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: # apply sharded context to prevent OOM run_method = partial(self._run_with_sharded_context, run_method) run_method = partial(self._wrapped_run_method, run_method) - kwargs["strategy"] = self._strategy - return self._strategy.launch(run_method, *args, **kwargs) + + if self._strategy.launcher is not None: + return self._strategy.launcher.launch(run_method, *args, **kwargs) + else: + return run_method(*args, **kwargs) def _wrapped_run_method(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 65d0145375078..87c5f646bc687 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -18,7 +18,7 @@ import tempfile import time from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed @@ -32,8 +32,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.ddp import DDPSubprocessLauncher -from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher +from pytorch_lightning.strategies.launchers.multi_process import MultiProcessLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -106,7 +105,7 @@ def __init__( self._model_averaging_period = model_averaging_period self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None - self._rank_0_has_called_call_children_scripts: bool = False + self._rank_0_will_call_children_scripts: bool = False self.set_world_ranks() @property @@ -136,14 +135,10 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - if self.cluster_environment.creates_processes_externally: - launcher = SingleProcessLauncher() - else: - launcher = DDPSubprocessLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - self._rank_0_has_called_call_children_scripts = True - - return launcher.launch(function, *args, **kwargs) + def configure_multi_process_launcher(self): + if self.launcher is None and not self.cluster_environment.creates_processes_externally: + self._launcher = MultiProcessLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + self._rank_0_will_call_children_scripts = True def setup_environment(self) -> None: self.setup_distributed() @@ -152,7 +147,7 @@ def setup_environment(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: super().setup(trainer) # share ddp pids to all processes - self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts) + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) if self._should_run_deadlock_detection(): 
self._share_information_to_prevent_deadlock() @@ -365,7 +360,7 @@ def _should_run_deadlock_detection(self) -> bool: By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler / parent process to perform the process termination, external to Lightning. """ - return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_has_called_call_children_scripts + return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts def _share_information_to_prevent_deadlock(self) -> None: self._share_pids() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 08b8aec0d970d..cd63f5d922964 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed @@ -26,7 +26,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher +from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -82,6 +82,7 @@ def __init__( self._ddp_comm_wrapper = ddp_comm_wrapper self._local_rank = 0 self.set_world_ranks() + self._launcher = SpawnLauncher(self) @property def num_nodes(self) -> int: @@ -140,10 +141,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - launcher = DDPSpawnLauncher() - return launcher.launch(function, *args, **kwargs) - def _worker_setup(self, process_idx: int): reset_seed() self.set_world_ranks(process_idx) diff --git a/pytorch_lightning/strategies/launchers/__init__.py b/pytorch_lightning/strategies/launchers/__init__.py index d67f81d237937..ce1330c295494 100644 --- a/pytorch_lightning/strategies/launchers/__init__.py +++ b/pytorch_lightning/strategies/launchers/__init__.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.strategies.launchers.base import Launcher -from pytorch_lightning.strategies.launchers.ddp import DDPSubprocessLauncher -from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher -from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher -from pytorch_lightning.strategies.launchers.tpu_spawn import TPUSpawnLauncher +from pytorch_lightning.strategies.launchers.multi_process import MultiProcessLauncher +from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher +from pytorch_lightning.strategies.launchers.xla_spawn import XLASpawnLauncher __all__ = [ - "DDPSpawnLauncher", - "DDPSubprocessLauncher", "Launcher", - "SingleProcessLauncher", - "TPUSpawnLauncher", + "MultiProcessLauncher", + "SpawnLauncher", + "XLASpawnLauncher", ] diff --git a/pytorch_lightning/strategies/launchers/ddp.py b/pytorch_lightning/strategies/launchers/multi_process.py similarity index 98% rename from pytorch_lightning/strategies/launchers/ddp.py rename to pytorch_lightning/strategies/launchers/multi_process.py index c47dc3eafb7a7..1c4f30a134d47 100644 --- a/pytorch_lightning/strategies/launchers/ddp.py +++ b/pytorch_lightning/strategies/launchers/multi_process.py @@ -30,7 +30,7 @@ from hydra.utils import get_original_cwd, to_absolute_path -class DDPSubprocessLauncher(Launcher): +class MultiProcessLauncher(Launcher): def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: super().__init__() self.cluster_environment = cluster_environment @@ -40,7 +40,6 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: kwargs.pop("trainer", None) - kwargs.pop("strategy") self._call_children_scripts() return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/single_process.py b/pytorch_lightning/strategies/launchers/single_process.py deleted file mode 100644 index a20243c8e085f..0000000000000 --- a/pytorch_lightning/strategies/launchers/single_process.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Callable - -from pytorch_lightning.strategies.launchers.base import Launcher - - -class SingleProcessLauncher(Launcher): - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - kwargs.pop("trainer", None) - kwargs.pop("strategy") - return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/ddp_spawn.py b/pytorch_lightning/strategies/launchers/spawn.py similarity index 89% rename from pytorch_lightning/strategies/launchers/ddp_spawn.py rename to pytorch_lightning/strategies/launchers/spawn.py index 9f1dfd48a2963..e1f9c63b0a4a1 100644 --- a/pytorch_lightning/strategies/launchers/ddp_spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -30,17 +30,19 @@ from pytorch_lightning.utilities.types import _PATH -class DDPSpawnLauncher(Launcher): +class SpawnLauncher(Launcher): + def __init__(self, strategy: Strategy): + self._strategy = strategy + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer", None) - strategy = kwargs.pop("strategy") - os.environ["MASTER_PORT"] = str(strategy.cluster_environment.main_port) + os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() mp.spawn( self._wrapped_function, - args=(strategy, trainer, function, args, kwargs, return_queue), - nprocs=strategy.num_processes, + args=(trainer, function, args, kwargs, return_queue), + nprocs=self._strategy.num_processes, ) spawn_output = return_queue.get() if trainer is None: @@ -52,20 +54,19 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: def _wrapped_function( self, process_idx: int, - strategy: Strategy, trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue, ) -> None: - strategy._worker_setup(process_idx) + self._strategy._worker_setup(process_idx) results = function(*args, **kwargs) if trainer is not None: results = self._collect_rank_zero_results(trainer, results) - if strategy.local_rank == 0: + if self._strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: @@ -76,11 +77,11 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer # TODO: pass also best score # load last weights if spawn_output.weights_path is not None: - ckpt = trainer.strategy.checkpoint_io.load_checkpoint( + ckpt = self._strategy.checkpoint_io.load_checkpoint( spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) trainer.lightning_module.load_state_dict(ckpt) - trainer.strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) + self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) trainer.state = spawn_output.trainer_state @@ -99,14 +100,14 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() - if trainer.strategy.global_rank != 0: + if self._strategy.global_rank != 0: return # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - trainer.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) # adds the 
`callback_metrics` to the queue extra = _FakeQueue() diff --git a/pytorch_lightning/strategies/launchers/tpu_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py similarity index 82% rename from pytorch_lightning/strategies/launchers/tpu_spawn.py rename to pytorch_lightning/strategies/launchers/xla_spawn.py index 956e70867eddd..35982c377feab 100644 --- a/pytorch_lightning/strategies/launchers/tpu_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -19,8 +19,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies.launchers.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnLauncher -from pytorch_lightning.strategies.strategy import Strategy +from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnOutput, SpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device @@ -33,16 +32,15 @@ xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 -class TPUSpawnLauncher(DDPSpawnLauncher): +class XLASpawnLauncher(SpawnLauncher): def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: trainer = kwargs.pop("trainer", None) - strategy = kwargs.pop("strategy") - context = mp.get_context(strategy.start_method or "fork") + context = mp.get_context(self._strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( self._wrapped_function, - args=(strategy, trainer, function, args, kwargs, return_queue), - **strategy.get_mp_spawn_kwargs() + args=(trainer, function, args, kwargs, return_queue), + **self._strategy.get_mp_spawn_kwargs() ) spawn_output = return_queue.get() if trainer is None: @@ -54,28 +52,27 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: def _wrapped_function( self, process_idx: int, - strategy: Strategy, trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue, ) -> None: - strategy._worker_setup(process_idx) + self._strategy._worker_setup(process_idx) results = function(*args, **kwargs) if trainer is not None: results = self._collect_rank_zero_results(trainer, results) - if strategy.local_rank == 0: + if self._strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - strategy.barrier("end-process") + self._strategy.barrier("end-process") # Ensure that the rank 0 process is the one exiting last # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if strategy.local_rank == 0: + if self._strategy.local_rank == 0: time.sleep(2) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: @@ -90,10 +87,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt weights_path = None if trainer.state.fn == TrainerFn.FITTING: weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") - trainer.strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) + self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training - if trainer.strategy.local_rank != 0: + if self._strategy.local_rank != 0: return # adds the `callback_metrics` to the queue diff --git a/pytorch_lightning/strategies/parallel.py b/pytorch_lightning/strategies/parallel.py index 2e152d4495e42..5d7d487a214e3 100644 --- 
a/pytorch_lightning/strategies/parallel.py +++ b/pytorch_lightning/strategies/parallel.py @@ -14,7 +14,7 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Callable, List, Optional +from typing import Any, List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -24,7 +24,6 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -126,10 +125,6 @@ def block_backward_sync(self): else: yield None - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - launcher = SingleProcessLauncher() - return launcher.launch(function, *args, **kwargs) - def teardown(self) -> None: self.cluster_environment.teardown() super().teardown() diff --git a/pytorch_lightning/strategies/single_device.py b/pytorch_lightning/strategies/single_device.py index 95cedf52c838a..440c73afce8fc 100644 --- a/pytorch_lightning/strategies/single_device.py +++ b/pytorch_lightning/strategies/single_device.py @@ -13,14 +13,13 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Callable +from typing import Any import torch import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.single_process import SingleProcessLauncher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.utilities.types import _DEVICE @@ -41,10 +40,6 @@ def __init__( self.local_rank = 0 self.world_size = 1 - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - launcher = SingleProcessLauncher() - return launcher.launch(function, *args, **kwargs) - def reduce(self, tensor: Any | torch.Tensor, *args: Any, **kwargs: Any) -> Any | torch.Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates with a single device, the reduction is simply the identity. diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 71c7702c1f0d8..580d880c21bec 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -27,6 +27,7 @@ from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin +from pytorch_lightning.strategies.launchers.base import Launcher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -48,6 +49,7 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: self.accelerator = accelerator + self._launcher = None self._model: Optional[Module] = None self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin @@ -62,6 +64,10 @@ def __init__( f" Move your implementation to `{self.__class__.__name__}.teardown()` instead." 
) + @property + def launcher(self) -> Optional[Launcher]: + return self._launcher + @property def accelerator(self) -> "pl.accelerators.accelerator.Accelerator": return self._accelerator @@ -139,10 +145,6 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_schedulers = schedulers - @abstractmethod - def launch(self, function: Any, *args: Any, **kwargs: Any) -> Any: - """Launch the proceses using a Launcher.""" - def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the appropriate device if needed.""" for opt in self.optimizers: diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 6d6e4a4537845..124f2858f3a89 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy -from pytorch_lightning.strategies.launchers.tpu_spawn import TPUSpawnLauncher +from pytorch_lightning.strategies.launchers.xla_spawn import XLASpawnLauncher from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.data import has_len @@ -67,6 +67,7 @@ def __init__( self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 self.start_method = "fork" + self._launcher = XLASpawnLauncher(self) @property def global_rank(self) -> int: @@ -212,10 +213,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def launch(self, function: Any, *args: Any, **kwargs: Any) -> Any: - launcher = TPUSpawnLauncher() - return launcher.launch(function, *args, **kwargs) - def _worker_setup(self, process_idx: int): reset_seed() self.tpu_local_core_rank = xm.get_local_ordinal() diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index d476bc5f0ca6e..fb3ade05ee721 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -775,6 +775,9 @@ def resolve_strategy(self, training_type: Strategy) -> Strategy: # set sync_batchnorm for training_type from trainer setting training_type.sync_batchnorm = self.sync_batchnorm + if isinstance(training_type, DDPStrategy): + training_type.configure_multi_process_launcher() + return training_type def select_accelerator(self) -> Accelerator: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1a18a46023a34..f64bba853fecf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -668,9 +668,11 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - kwargs["trainer"] = self - kwargs["strategy"] = self.strategy - return self.strategy.launch(trainer_fn, *args, **kwargs) + if self.strategy.launcher is not None: + kwargs["trainer"] = self + return self.strategy.launcher.launch(trainer_fn, *args, **kwargs) + else: + return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as 
exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 3da24ea1a7c88..7143d79788d87 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -20,7 +20,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.strategies import DDPSpawnStrategy -from pytorch_lightning.strategies.launchers.ddp_spawn import DDPSpawnLauncher +from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -83,7 +83,7 @@ def test_ddp_spawn_extra_parameters(tmpdir): assert model.test_val == "test_val" -class CustomDDPSpawnLauncher(DDPSpawnLauncher): +class CustomSpawnLauncher(SpawnLauncher): def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) @@ -94,9 +94,9 @@ def get_from_queue(self, trainer: Trainer, queue) -> None: class TestDDPSpawnStrategy(DDPSpawnStrategy): - def launch(self, function, *args, **kwargs): - launcher = CustomDDPSpawnLauncher() - return launcher.launch(function, *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._launcher = CustomSpawnLauncher(self) @RunIf(skip_windows=True, skip_49370=True) @@ -155,14 +155,13 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPSpawnStrategy() - launcher = DDPSpawnLauncher() trainer = Trainer(default_root_dir=tmpdir, strategy=strategy) trainer.strategy.connect(model) trainer.state.fn = trainer_fn # pretend we are in a particular trainer state temp_file = Path(tmpdir, ".temp.ckpt") assert not temp_file.exists() - spawn_output = launcher._collect_rank_zero_results(trainer, {}) + spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {}) model.state_dict.assert_called_once() if trainer_fn == TrainerFn.FITTING: @@ -173,6 +172,6 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): assert not temp_file.exists() # <-- here would normally be the multiprocessing boundary - launcher._recover_results_in_main_process(spawn_output, trainer) + strategy._launcher._recover_results_in_main_process(spawn_output, trainer) assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) assert not temp_file.exists() From f094fc0eb3110c839bf0a94eb0c7783c93675559 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Feb 2022 20:44:12 +0000 Subject: [PATCH 19/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/strategies/ddp_spawn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 402ed9cd8665a..a03a3c0073866 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -28,16 +28,6 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities 
import _TORCH_GREATER_EQUAL_1_8 -from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available -from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import ( - init_dist_connection, - ReduceOp, - sync_ddp_if_available, -) -from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device From 75c073187f5974cd386d38d3283fea3844c6ba3e Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 8 Feb 2022 02:16:01 +0530 Subject: [PATCH 20/36] Apply suggestions from code review --- pytorch_lightning/strategies/launchers/spawn.py | 2 +- pytorch_lightning/strategies/launchers/xla_spawn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index e1f9c63b0a4a1..c2354669cd5b5 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -25,7 +25,7 @@ from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities.rank_zero import rank_zero_debug from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _PATH diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index 35982c377feab..f7a6c229b92c5 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -23,7 +23,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import rank_zero_debug +from pytorch_lightning.utilities.rank_zero import rank_zero_debug from pytorch_lightning.utilities.model_helpers import is_overridden if _TPU_AVAILABLE: From a3de1e326942a7387b6f0dcd9415c638418e1976 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Feb 2022 20:47:14 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/strategies/launchers/spawn.py | 2 +- pytorch_lightning/strategies/launchers/xla_spawn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index c2354669cd5b5..70adfdef67a2b 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -25,8 +25,8 @@ from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_debug from pytorch_lightning.utilities.model_helpers import is_overridden +from 
pytorch_lightning.utilities.rank_zero import rank_zero_debug from pytorch_lightning.utilities.types import _PATH diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index f7a6c229b92c5..88338c059fae3 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -23,8 +23,8 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_debug from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.rank_zero import rank_zero_debug if _TPU_AVAILABLE: import torch_xla.distributed.xla_multiprocessing as xmp From 5da734e950120358fc400fe0ebb26505f1ac64ff Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 9 Feb 2022 17:30:49 +0530 Subject: [PATCH 22/36] review comments --- pytorch_lightning/strategies/launchers/base.py | 6 ++++++ .../strategies/launchers/multi_process.py | 5 +++++ pytorch_lightning/strategies/launchers/spawn.py | 12 +++++++++--- pytorch_lightning/strategies/launchers/xla_spawn.py | 10 ++++++++-- pytorch_lightning/trainer/trainer.py | 3 +-- 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 8227226c8c425..8b2041bc927dc 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -16,6 +16,12 @@ class Launcher(ABC): + r""" + Abstract base class used to build new Launchers. + + Subclass this class and override any of the relevant methods. + """ + @abstractmethod def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """Launches the proceses.""" diff --git a/pytorch_lightning/strategies/launchers/multi_process.py b/pytorch_lightning/strategies/launchers/multi_process.py index 1c4f30a134d47..28fd9224c8214 100644 --- a/pytorch_lightning/strategies/launchers/multi_process.py +++ b/pytorch_lightning/strategies/launchers/multi_process.py @@ -31,6 +31,10 @@ class MultiProcessLauncher(Launcher): + r""" + Creates and launches subprocess scripts on each device. + """ + def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: super().__init__() self.cluster_environment = cluster_environment @@ -39,6 +43,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.interactive_ddp_procs: List[Popen] = [] def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + """Creates the children scripts.""" kwargs.pop("trainer", None) self._call_children_scripts() return function(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index 70adfdef67a2b..fb2296fb546f0 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -31,16 +31,22 @@ class SpawnLauncher(Launcher): - def __init__(self, strategy: Strategy): + r""" + Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes when it + finishes.
+ """ + + def __init__(self, strategy: Strategy) -> None: self._strategy = strategy def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + """Creates spawn processes and join them at the end.""" trainer = kwargs.pop("trainer", None) os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) context = mp.get_context("spawn") return_queue = context.SimpleQueue() mp.spawn( - self._wrapped_function, + self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), nprocs=self._strategy.num_processes, ) @@ -51,7 +57,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results - def _wrapped_function( + def _wrapping_function( self, process_idx: int, trainer: Optional["pl.Trainer"], diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index 88338c059fae3..f5f6240e60eb2 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -33,12 +33,18 @@ class XLASpawnLauncher(SpawnLauncher): + r""" + Spawns processes using the `torch_xla` :func:`xmp.spawn` method and joins processes when it + finishes. + """ + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + """Creates spawn processes and join them at the end.""" trainer = kwargs.pop("trainer", None) context = mp.get_context(self._strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( - self._wrapped_function, + self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), **self._strategy.get_mp_spawn_kwargs() ) @@ -49,7 +55,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results - def _wrapped_function( + def _wrapping_function( self, process_idx: int, trainer: Optional["pl.Trainer"], diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 62b7fdfc8948b..47fded473a643 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -670,8 +670,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: if self.strategy.launcher is not None: - kwargs["trainer"] = self - return self.strategy.launcher.launch(trainer_fn, *args, **kwargs) + return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs) else: return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 From 6df086282f533e1421f8414a50f257a9964f44de Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 9 Feb 2022 17:33:11 +0530 Subject: [PATCH 23/36] pre-commit --- pytorch_lightning/strategies/ddp_spawn.py | 4 +--- pytorch_lightning/strategies/tpu_spawn.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index a03a3c0073866..cc48627249371 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -28,14 +28,12 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.trainer.states import TrainerFn, TrainerState +from 
pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 -from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.enums import _StrategyType -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index f8e64fbd2c33f..86923ee022779 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only +from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT From 1a7013cf9913aa176b7212ad73bf5f4d305b6787 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 10 Feb 2022 12:00:52 -0500 Subject: [PATCH 24/36] configure_launcher --- pytorch_lightning/strategies/ddp.py | 8 ++++---- pytorch_lightning/strategies/ddp_spawn.py | 6 ++++-- .../strategies/launchers/__init__.py | 16 ++++++++-------- pytorch_lightning/strategies/launchers/base.py | 2 +- .../strategies/launchers/multi_process.py | 5 +++-- pytorch_lightning/strategies/launchers/spawn.py | 2 +- .../strategies/launchers/xla_spawn.py | 2 +- pytorch_lightning/strategies/strategy.py | 9 ++++++--- pytorch_lightning/strategies/tpu_spawn.py | 6 ++++-- .../trainer/connectors/accelerator_connector.py | 4 +--- tests/strategies/test_ddp_spawn_strategy.py | 2 ++ 11 files changed, 35 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index c3ad183c271a9..dcb615bb9192a 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -32,7 +32,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.multi_process import MultiProcessLauncher +from pytorch_lightning.strategies.launchers.multi_process import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -133,9 +133,9 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def configure_multi_process_launcher(self): - if self.launcher is None and not self.cluster_environment.creates_processes_externally: - self._launcher = MultiProcessLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + def _configure_launcher(self): + self._launcher = 
_SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + if not self.cluster_environment.creates_processes_externally: self._rank_0_will_call_children_scripts = True def setup_environment(self) -> None: diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index cc48627249371..9b58137d2719d 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -26,7 +26,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 @@ -77,7 +77,6 @@ def __init__( self._ddp_comm_wrapper = ddp_comm_wrapper self._local_rank = 0 self.set_world_ranks() - self._launcher = SpawnLauncher(self) @property def num_nodes(self) -> int: @@ -110,6 +109,9 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self): return True + def _configure_launcher(self): + self._launcher = _SpawnLauncher(self) + def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) super().setup(trainer) diff --git a/pytorch_lightning/strategies/launchers/__init__.py b/pytorch_lightning/strategies/launchers/__init__.py index ce1330c295494..66523cae3874b 100644 --- a/pytorch_lightning/strategies/launchers/__init__.py +++ b/pytorch_lightning/strategies/launchers/__init__.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.strategies.launchers.base import Launcher -from pytorch_lightning.strategies.launchers.multi_process import MultiProcessLauncher -from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher -from pytorch_lightning.strategies.launchers.xla_spawn import XLASpawnLauncher +from pytorch_lightning.strategies.launchers.base import _Launcher +from pytorch_lightning.strategies.launchers.multi_process import _SubprocessScriptLauncher +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher +from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher __all__ = [ - "Launcher", - "MultiProcessLauncher", - "SpawnLauncher", - "XLASpawnLauncher", + "_Launcher", + "_SpawnLauncher", + "_SubprocessScriptLauncher", + "_XLASpawnLauncher", ] diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 8b2041bc927dc..6d27786e96846 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -15,7 +15,7 @@ from typing import Any, Callable -class Launcher(ABC): +class _Launcher(ABC): r""" Abstract base class used to build new Launchers. 
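To make the new hook concrete, here is a minimal sketch of a strategy attaching a customized launcher through `_configure_launcher`; `LoggingSpawnLauncher` and `LoggingDDPSpawnStrategy` are illustrative names, not part of this patch series, and the logging is only there to show where custom behavior would go.

.. code-block:: python

    from typing import Any, Callable

    from pytorch_lightning.strategies import DDPSpawnStrategy
    from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher


    class LoggingSpawnLauncher(_SpawnLauncher):
        """Hypothetical launcher that logs before delegating to the regular spawn logic."""

        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            print(f"spawning {self._strategy.num_processes} processes")
            return super().launch(function, *args, **kwargs)


    class LoggingDDPSpawnStrategy(DDPSpawnStrategy):
        def _configure_launcher(self) -> None:
            # The accelerator connector calls this hook once the final strategy is resolved.
            self._launcher = LoggingSpawnLauncher(self)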
diff --git a/pytorch_lightning/strategies/launchers/multi_process.py b/pytorch_lightning/strategies/launchers/multi_process.py index 28fd9224c8214..1577d822df5c0 100644 --- a/pytorch_lightning/strategies/launchers/multi_process.py +++ b/pytorch_lightning/strategies/launchers/multi_process.py @@ -30,7 +30,7 @@ from hydra.utils import get_original_cwd, to_absolute_path -class MultiProcessLauncher(Launcher): +class _SubprocessScriptLauncher(Launcher): r""" Creates and launches subprocess scripts on each device. """ @@ -45,7 +45,8 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """Creates children scripts.""" kwargs.pop("trainer", None) - self._call_children_scripts() + if not self.cluster_environment.creates_processes_externally: + self._call_children_scripts() return function(*args, **kwargs) def _call_children_scripts(self) -> None: diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index fb2296fb546f0..d7327b079a6aa 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -30,7 +30,7 @@ from pytorch_lightning.utilities.types import _PATH -class SpawnLauncher(Launcher): +class _SpawnLauncher(Launcher): r""" Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes when it finishes. diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index f5f6240e60eb2..5fff9c2b0f12c 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -32,7 +32,7 @@ xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 -class XLASpawnLauncher(SpawnLauncher): +class _XLASpawnLauncher(SpawnLauncher): r""" Spawns processes using the `torch_xla` :func:`xmp.spawn` method and joins processes when it finishes. 
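The `creates_processes_externally` check above means child scripts are only created when the cluster environment does not already start one process per device (as SLURM or `torch.distributed.run` would). A rough standalone sketch of the mechanism, not the library implementation, with an assumed device count of four:

.. code-block:: python

    import os
    import subprocess
    import sys

    NUM_PROCESSES = 4  # assumed number of devices on this node


    def main() -> None:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank == 0:
            # Rank 0 re-invokes the running script once per additional device and hands
            # each child its rank through the LOCAL_RANK environment variable.
            for rank in range(1, NUM_PROCESSES):
                env = os.environ.copy()
                env["LOCAL_RANK"] = str(rank)
                subprocess.Popen([sys.executable] + sys.argv, env=env)
        print(f"running with LOCAL_RANK={local_rank}")


    if __name__ == "__main__":
        main()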
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index c0c79d13ecc5b..3df616a786b19 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -27,7 +27,7 @@ from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.base import Launcher +from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -49,7 +49,7 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: self.accelerator = accelerator - self._launcher = None + self._launcher: _Launcher = None self._model: Optional[Module] = None self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin @@ -64,7 +64,7 @@ def __init__( ) @property - def launcher(self) -> Optional[Launcher]: + def launcher(self) -> Optional[_Launcher]: return self._launcher @property @@ -106,6 +106,9 @@ def connect(self, model: Module) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" self.model = model + def _configure_launcher(self): + """Attach the launcher based on Strategy.""" + def setup_environment(self) -> None: """Setup any processes or distributed connections. diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 86923ee022779..f3d855b43f8a6 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy -from pytorch_lightning.strategies.launchers.xla_spawn import XLASpawnLauncher +from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.data import has_len @@ -68,7 +68,6 @@ def __init__( self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 self.start_method = "fork" - self._launcher = XLASpawnLauncher(self) @property def global_rank(self) -> int: @@ -117,6 +116,9 @@ def connect(self, model: "pl.LightningModule") -> None: self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) + def _configure_launcher(self): + self._launcher = _XLASpawnLauncher(self) + def setup(self, trainer: "pl.Trainer") -> None: self.start_method = "fork" self.accelerator.setup(trainer) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index d15b60433e4b9..8f770feb790c3 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -172,6 +172,7 @@ def __init__( self._set_devices_if_none() self.strategy = self.final_strategy() + self.strategy._configure_launcher() self.accelerator = self.strategy.accelerator self._check_plugin_compatibility() @@ -775,9 +776,6 
@@ def resolve_strategy(self, training_type: Strategy) -> Strategy: # set sync_batchnorm for training_type from trainer setting training_type.sync_batchnorm = self.sync_batchnorm - if isinstance(training_type, DDPStrategy): - training_type.configure_multi_process_launcher() - return training_type def select_accelerator(self) -> Accelerator: diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 7143d79788d87..1f5ef6fc76a30 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -96,6 +96,8 @@ def get_from_queue(self, trainer: Trainer, queue) -> None: class TestDDPSpawnStrategy(DDPSpawnStrategy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + def _configure_launcher(self): self._launcher = CustomSpawnLauncher(self) From 0dfbd2805c71014c5c7f743d500dcffcde3f3be5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 10 Feb 2022 12:02:36 -0500 Subject: [PATCH 25/36] rename to SubprocessScriptLauncher --- pytorch_lightning/strategies/ddp.py | 2 +- .../strategies/launchers/__init__.py | 2 +- .../strategies/launchers/multi_process.py | 118 ------------------ 3 files changed, 2 insertions(+), 120 deletions(-) delete mode 100644 pytorch_lightning/strategies/launchers/multi_process.py diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index dcb615bb9192a..4d87ebe15a5f4 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -32,7 +32,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.multi_process import _SubprocessScriptLauncher +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( diff --git a/pytorch_lightning/strategies/launchers/__init__.py b/pytorch_lightning/strategies/launchers/__init__.py index 66523cae3874b..340a2c0160b0e 100644 --- a/pytorch_lightning/strategies/launchers/__init__.py +++ b/pytorch_lightning/strategies/launchers/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.strategies.launchers.base import _Launcher -from pytorch_lightning.strategies.launchers.multi_process import _SubprocessScriptLauncher from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher +from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher __all__ = [ diff --git a/pytorch_lightning/strategies/launchers/multi_process.py b/pytorch_lightning/strategies/launchers/multi_process.py deleted file mode 100644 index 1577d822df5c0..0000000000000 --- a/pytorch_lightning/strategies/launchers/multi_process.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import subprocess -import sys -from subprocess import Popen -from time import sleep -from typing import Any, Callable, List, Optional - -import __main__ -import numpy as np - -from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.strategies.launchers.base import Launcher -from pytorch_lightning.utilities import _HYDRA_AVAILABLE - -if _HYDRA_AVAILABLE: - from hydra.core.hydra_config import HydraConfig - from hydra.utils import get_original_cwd, to_absolute_path - - -class _SubprocessScriptLauncher(Launcher): - r""" - Creates and launches subprocess scripts on each device. - """ - - def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: - super().__init__() - self.cluster_environment = cluster_environment - self.num_processes = num_processes - self.num_nodes = num_nodes - self.interactive_ddp_procs: List[Popen] = [] - - def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - """Creates children scripts.""" - kwargs.pop("trainer", None) - if not self.cluster_environment.creates_processes_externally: - self._call_children_scripts() - return function(*args, **kwargs) - - def _call_children_scripts(self) -> None: - # bookkeeping of spawned processes - self._check_can_spawn_children() - - # DDP Environment variables - os.environ["MASTER_ADDR"] = self.cluster_environment.main_address - os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - - # allow the user to pass the node rank - os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) - os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) - - # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` - # See https://docs.python.org/3/reference/import.html#main-spec - if __main__.__spec__ is None: # pragma: no-cover - # Script called as `python a/b/c.py` - # when user is using hydra find the absolute path - path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path - - # pull out the commands used to run the script and resolve the abs file path - command = sys.argv - try: - full_path = path_lib(command[0]) - except Exception: - full_path = os.path.abspath(command[0]) - - command[0] = full_path - # use the same python interpreter and actually running - command = [sys.executable] + command - else: # Script called as `python -m a.b.c` - command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] - - os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" - - self.interactive_ddp_procs = [] - - for local_rank in range(1, self.num_processes): - env_copy = os.environ.copy() - env_copy["LOCAL_RANK"] = f"{local_rank}" - - # remove env var if global seed not set - if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: - del env_copy["PL_GLOBAL_SEED"] - - # start process - # if hydra is available and initialized, make sure to set the cwd correctly - cwd: Optional[str] = None - if _HYDRA_AVAILABLE: - if HydraConfig.initialized(): - cwd = get_original_cwd() - os_cwd = 
f'"{os.getcwd()}"' - command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] - proc = subprocess.Popen(command, env=env_copy, cwd=cwd) - self.interactive_ddp_procs.append(proc) - - # starting all processes at once can cause issues - # with dataloaders delay between 1-10 seconds - delay = np.random.uniform(1, 5, 1)[0] - sleep(delay) - - def _check_can_spawn_children(self) -> None: - if self.cluster_environment.local_rank() != 0: - raise RuntimeError( - "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." - " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," - " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." - ) From 329512129b662175f8a9b6be198f570d80b72a41 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 10 Feb 2022 13:07:01 -0500 Subject: [PATCH 26/36] protected and update test --- pytorch_lightning/lite/lite.py | 6 +----- pytorch_lightning/strategies/launchers/spawn.py | 4 ++-- .../strategies/launchers/xla_spawn.py | 4 ++-- setup.cfg | 16 ++++++++-------- tests/strategies/test_ddp_spawn_strategy.py | 4 ++-- tests/trainer/test_trainer.py | 4 ++-- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index c7351d92af2cb..b3152b197fbbe 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -401,18 +401,14 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: # apply sharded context to prevent OOM run_method = partial(self._run_with_sharded_context, run_method) - run_method = partial(self._wrapped_run_method, run_method) if self._strategy.launcher is not None: return self._strategy.launcher.launch(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) - def _wrapped_run_method(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - self._strategy.setup_environment() - return run_method(*args, **kwargs) - def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + self._strategy.setup_environment() with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): return run_method(*args, **kwargs) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index d7327b079a6aa..4b1ba5c11bb82 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -21,7 +21,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies.launchers.base import Launcher +from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -30,7 +30,7 @@ from pytorch_lightning.utilities.types import _PATH -class _SpawnLauncher(Launcher): +class _SpawnLauncher(_Launcher): r""" Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes when it finishes. 
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index 5fff9c2b0f12c..a46b30ddd9dc2 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -19,7 +19,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnOutput, SpawnLauncher +from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import move_data_to_device @@ -32,7 +32,7 @@ xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 -class _XLASpawnLauncher(SpawnLauncher): +class _XLASpawnLauncher(_SpawnLauncher): r""" Spawns processes using the `torch_xla` :func:`xmp.spawn` method and joins processes when it finishes. diff --git a/setup.cfg b/setup.cfg index 79ab35616ed61..75552c5bdb7b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,14 +24,14 @@ addopts = --doctest-modules --color=yes --disable-pytest-warnings -filterwarnings = - # error out on our deprecation warnings - ensures the code and tests are kept up-to-date - error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning - error::FutureWarning - # warnings from deprecated modules on import - # TODO: remove in 1.7 - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory +# filterwarnings = +# # error out on our deprecation warnings - ensures the code and tests are kept up-to-date +# error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning +# error::FutureWarning +# # warnings from deprecated modules on import +# # TODO: remove in 1.7 +# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators +# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory junit_duration_report = call diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 1f5ef6fc76a30..1ae6dca5dc5c9 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -20,7 +20,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.strategies import DDPSpawnStrategy -from pytorch_lightning.strategies.launchers.spawn import SpawnLauncher +from pytorch_lightning.strategies.launchers.spawn import _SpawnLauncher from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf @@ -83,7 +83,7 @@ def test_ddp_spawn_extra_parameters(tmpdir): assert model.test_val == "test_val" -class CustomSpawnLauncher(SpawnLauncher): +class CustomSpawnLauncher(_SpawnLauncher): def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0e2f488b4506e..d0987411e75d3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1039,7 +1039,7 @@ def configure_gradient_clipping(self, *args, **kwargs): # test that gradient is clipped correctly parameters = self.parameters() grad_norm 
= torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - torch.testing.assert_allclose(grad_norm, torch.tensor(0.05)) + assert torch.allclose(grad_norm, torch.tensor(0.05), rtol=1e-4) self.assertion_called = True model = TestModel() @@ -1070,7 +1070,7 @@ def configure_gradient_clipping(self, *args, **kwargs): parameters = self.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - torch.testing.assert_allclose(grad_max.abs(), torch.tensor(1e-10)) + assert torch.allclose(grad_max.abs(), torch.tensor(1e-10)) self.assertion_called = True model = TestModel() From 81cf4c1035ccd94da8e81e2c41c4625df465349b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 10 Feb 2022 13:20:59 -0500 Subject: [PATCH 27/36] enable dep failure --- setup.cfg | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/setup.cfg b/setup.cfg index 75552c5bdb7b6..5ef62709a3f4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,14 +24,14 @@ addopts = --doctest-modules --color=yes --disable-pytest-warnings -# filterwarnings = -# # error out on our deprecation warnings - ensures the code and tests are kept up-to-date -# error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning -# error::FutureWarning -# # warnings from deprecated modules on import -# # TODO: remove in 1.7 -# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators -# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory +filterwarnings = + # error out on our deprecation warnings - ensures the code and tests are kept up-to-date + error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning + error::FutureWarning + # warnings from deprecated modules on import + # TODO: remove in 1.7 + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory junit_duration_report = call From 25c241738accf8b11200d2329f5c64ac695ee81d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 10 Feb 2022 13:23:49 -0500 Subject: [PATCH 28/36] add script launcher --- .../strategies/launchers/subprocess_script.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 pytorch_lightning/strategies/launchers/subprocess_script.py diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py new file mode 100644 index 0000000000000..deeff7bf755f0 --- /dev/null +++ b/pytorch_lightning/strategies/launchers/subprocess_script.py @@ -0,0 +1,118 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import subprocess +import sys +from subprocess import Popen +from time import sleep +from typing import Any, Callable, List, Optional + +import __main__ +import numpy as np + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.strategies.launchers.base import _Launcher +from pytorch_lightning.utilities import _HYDRA_AVAILABLE + +if _HYDRA_AVAILABLE: + from hydra.core.hydra_config import HydraConfig + from hydra.utils import get_original_cwd, to_absolute_path + + +class _SubprocessScriptLauncher(_Launcher): + r""" + Creates and launches subprocess scripts on each device. + """ + + def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: + super().__init__() + self.cluster_environment = cluster_environment + self.num_processes = num_processes + self.num_nodes = num_nodes + self.interactive_ddp_procs: List[Popen] = [] + + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + """Creates children scripts.""" + kwargs.pop("trainer", None) + if not self.cluster_environment.creates_processes_externally: + self._call_children_scripts() + return function(*args, **kwargs) + + def _call_children_scripts(self) -> None: + # bookkeeping of spawned processes + self._check_can_spawn_children() + + # DDP Environment variables + os.environ["MASTER_ADDR"] = self.cluster_environment.main_address + os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + + # allow the user to pass the node rank + os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank()) + os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank()) + + # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c` + # See https://docs.python.org/3/reference/import.html#main-spec + if __main__.__spec__ is None: # pragma: no-cover + # Script called as `python a/b/c.py` + # when user is using hydra find the absolute path + path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path + + # pull out the commands used to run the script and resolve the abs file path + command = sys.argv + try: + full_path = path_lib(command[0]) + except Exception: + full_path = os.path.abspath(command[0]) + + command[0] = full_path + # use the same python interpreter and actually running + command = [sys.executable] + command + else: # Script called as `python -m a.b.c` + command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] + + os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}" + + self.interactive_ddp_procs = [] + + for local_rank in range(1, self.num_processes): + env_copy = os.environ.copy() + env_copy["LOCAL_RANK"] = f"{local_rank}" + + # remove env var if global seed not set + if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy: + del env_copy["PL_GLOBAL_SEED"] + + # start process + # if hydra is available and initialized, make sure to set the cwd correctly + cwd: Optional[str] = None + if _HYDRA_AVAILABLE: + if HydraConfig.initialized(): + cwd = get_original_cwd() + os_cwd = f'"{os.getcwd()}"' + command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"] + proc = subprocess.Popen(command, env=env_copy, cwd=cwd) + self.interactive_ddp_procs.append(proc) + + # starting all processes at once can cause issues + # with dataloaders delay between 1-10 seconds + delay = np.random.uniform(1, 5, 1)[0] + sleep(delay) + + def _check_can_spawn_children(self) -> None: + if 
self.cluster_environment.local_rank() != 0: + raise RuntimeError( + "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen." + " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user," + " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented." + ) From fa3017316e7f7357ed5951f9eb63c8b575eb5010 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Feb 2022 03:37:51 +0100 Subject: [PATCH 29/36] add extensive docs --- pytorch_lightning/strategies/ddp.py | 6 +--- .../strategies/launchers/subprocess_script.py | 35 ++++++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 4d87ebe15a5f4..1e145efa8c665 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -61,11 +61,7 @@ class DDPStrategy(ParallelStrategy): - """Plugin for multi-process single-device training on one or multiple nodes. - - The main process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of - devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. - """ + """Strategy for multi-process single-device training on one or multiple nodes.""" distributed_backend = _StrategyType.DDP diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py index deeff7bf755f0..d354fd407d364 100644 --- a/pytorch_lightning/strategies/launchers/subprocess_script.py +++ b/pytorch_lightning/strategies/launchers/subprocess_script.py @@ -32,7 +32,40 @@ class _SubprocessScriptLauncher(_Launcher): r""" - Creates and launches subprocess scripts on each device. + A process laucher that invokes the current script as many times as desired in a single node. + + This launcher needs to be invoked on each node. + In its default behavior, the main process in each node then spawns N-1 child processes via :func:`subprocess.Popen`, + where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.run` + launches processes. + + For example, if the script gets invoked with the command + + .. code-block:: bash + + python train.py --devices 4 + + The launcher will create three additional subprocesses that get called like so: + + .. code-block:: bash + + LOCAL_RANK=1 python train.py --devices 4 + LOCAL_RANK=2 python train.py --devices 4 + LOCAL_RANK=3 python train.py --devices 4 + + It is implied that the main process which launched the others has ``LOCAL_RANK=0``. + Beside the local rank, the following other environment variables also get set, but unlike the local rank, these + get determined by the cluster environment: + + 1. `MASTER_ADDR`: The IP address of the main node. + 2. `MASTER_PORT`: The port number of the main node through which all processes communicate. + 3. `NODE_RANK`: The index of the node the current process is running on. Ranges from 0 to ``num_nodes - 1``. + 4. `WORLD_SIZE`: The total number of processes across all nodes, i.e., ``num_processes * num_nodes``. + + Arguments: + cluster_environment: A cluster environment that provides access to world size, node rank, etc. + num_processes: The number of processes to launch in the current node. + num_nodes: The total number of nodes that participate in this process group. 
""" def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: From e2800924ee90a4d85a22743c2b34d0957b04433f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Feb 2022 04:35:55 +0100 Subject: [PATCH 30/36] add extensive docs --- .../strategies/launchers/base.py | 8 +++-- .../strategies/launchers/spawn.py | 30 ++++++++++++++++--- .../strategies/launchers/subprocess_script.py | 9 +++++- .../strategies/launchers/xla_spawn.py | 27 ++++++++++++++--- 4 files changed, 63 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 6d27786e96846..66a34369db10b 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -17,9 +17,13 @@ class _Launcher(ABC): r""" - Abstract base class used to build new Launchers. + Abstract base class for all Launchers. - Subclass this class and override any of the relevant methods + Launchers are responsible for the creation and instrumentation of new processes so that the + :class:`~pytorch_lightning.strategies.base.Strategy` can set up communication between all them. + + Subclass this class and override any of the relevant methods to provide a custom implementation depending on + cluster environment, hardware, strategy, etc. """ @abstractmethod diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index 4b1ba5c11bb82..19f60b882229c 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -31,16 +31,38 @@ class _SpawnLauncher(_Launcher): - r""" - Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes when it - finishes. + r"""Spawns processes that run a given function in parallel, and joins them all at the end. + + The main process in which this launcher is invoked creates N so-called worker processes (using + :func:`torch.multiprocessing.spawn`) that run the given function. + Worker processes have a rank that ranges from 0 to N - 1. + + Note: + - This launcher requires all objects to be pickleable. + - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``. + + Arguments: + strategy: A reference to the strategy that is used together with this launcher. """ def __init__(self, strategy: Strategy) -> None: self._strategy = strategy def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - """Creates spawn processes and join them at the end.""" + """Spawns processes that run the given function in parallel. + + The function is allowed to have a return value. However, when all processes join, only the return value + of worker process 0 gets returned from this `launch` method in the main process. + + Arguments: + function: The entry point for all spawned processes. + *args: Optional positional arguments to be passed to the given function. + **kwargs: Optional keyword arguments to be passed to the given function. + If a keyword argument named `trainer` is present and is an instance of + :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get + restored in the main process after processes join. The `trainer` keyword argument will NOT be passed + into the function. 
+ """ trainer = kwargs.pop("trainer", None) os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) context = mp.get_context("spawn") diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py index d354fd407d364..e4b41500412d3 100644 --- a/pytorch_lightning/strategies/launchers/subprocess_script.py +++ b/pytorch_lightning/strategies/launchers/subprocess_script.py @@ -76,7 +76,14 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.interactive_ddp_procs: List[Popen] = [] def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - """Creates children scripts.""" + """Creates new processes, then calls the given function. + + Arguments: + function: A callback function to execute after all processes have been created. + It is up to the implementation of this function to synchronize the processes, e.g., with barriers. + *args: Optional positional arguments to be passed to the given function. + **kwargs: Optional keyword arguments to be passed to the given function. + """ kwargs.pop("trainer", None) if not self.cluster_environment.creates_processes_externally: self._call_children_scripts() diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index a46b30ddd9dc2..814b0447a8a0c 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -33,13 +33,32 @@ class _XLASpawnLauncher(_SpawnLauncher): - r""" - Spawns processes using the `torch_xla` :func:`xmp.spawn` method and joins processes when it - finishes. + r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end. + + The main process in which this launcher is invoked creates N so-called worker processes (using the + `torch_xla` :func:`xmp.spawn`) that run the given function. + Worker processes have a rank that ranges from 0 to N - 1. + + Note: + - This launcher requires all objects to be pickleable. + - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``. """ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - """Creates spawn processes and join them at the end.""" + """Spawns processes that run the given function in parallel. + + The function is allowed to have a return value. However, when all processes join, only the return value + of worker process 0 gets returned from this `launch` method in the main process. + + Arguments: + function: The entry point for all spawned processes. + *args: Optional positional arguments to be passed to the given function. + **kwargs: Optional keyword arguments to be passed to the given function. + If a keyword argument named `trainer` is present and is an instance of + :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get + restored in the main process after processes join. The `trainer` keyword argument will NOT be passed + into the function. 
+ """ trainer = kwargs.pop("trainer", None) context = mp.get_context(self._strategy.start_method or "fork") return_queue = context.SimpleQueue() From eb567c9458b5f1a1dd52887e684bbdadf2f04526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Feb 2022 04:41:03 +0100 Subject: [PATCH 31/36] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ananthsub Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/strategies/ddp.py | 2 +- pytorch_lightning/strategies/launchers/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 4d87ebe15a5f4..4bf3336bddb04 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -133,7 +133,7 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self) -> bool: return True - def _configure_launcher(self): + def _configure_launcher(self) -> None: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) if not self.cluster_environment.creates_processes_externally: self._rank_0_will_call_children_scripts = True diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py index 6d27786e96846..34c105dafba55 100644 --- a/pytorch_lightning/strategies/launchers/base.py +++ b/pytorch_lightning/strategies/launchers/base.py @@ -24,4 +24,4 @@ class _Launcher(ABC): @abstractmethod def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: - """Launches the proceses.""" + """Launches the processes.""" From 0930b1f7acc844479532d27238341b95ddc51531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Feb 2022 04:52:46 +0100 Subject: [PATCH 32/36] fix typing --- pytorch_lightning/strategies/launchers/spawn.py | 10 ++++------ pytorch_lightning/strategies/launchers/xla_spawn.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index 19f60b882229c..467ee793eef90 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -100,15 +100,13 @@ def _wrapping_function( def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: # transfer back the best path to the trainer if trainer.checkpoint_callback: - trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path) # TODO: pass also best score # load last weights if spawn_output.weights_path is not None: - ckpt = self._strategy.checkpoint_io.load_checkpoint( - spawn_output.weights_path, map_location=(lambda storage, loc: storage) - ) - trainer.lightning_module.load_state_dict(ckpt) + ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path) + trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type] self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path) trainer.state = spawn_output.trainer_state @@ -129,7 +127,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt state_dict = trainer.lightning_module.state_dict() if self._strategy.global_rank != 0: - return + return None # save the last weights weights_path = None diff --git 
a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index 814b0447a8a0c..8bac7888c568b 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -116,7 +116,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self._strategy.local_rank != 0: - return + return None # adds the `callback_metrics` to the queue extra = _FakeQueue() From 51b18f30432e596c3da9f66a1d0c5ce8f205a604 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 17 Feb 2022 20:53:02 +0530 Subject: [PATCH 33/36] address Adrian's comments --- pytorch_lightning/core/lightning.py | 6 ++---- pytorch_lightning/lite/lite.py | 4 ++-- .../trainer/configuration_validator.py | 6 ++---- setup.cfg | 14 +++++++------- tests/strategies/test_ddp_spawn_strategy.py | 3 --- 5 files changed, 13 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4287e885f030c..098956a703a8a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1970,8 +1970,7 @@ def add_to_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: queue: the instance of the queue to append the data. .. deprecated:: v1.5 - This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.add_to_queue` - and will be removed in v1.7. + This method was deprecated in v1.5 and will be removed in v1.7. """ def get_from_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None: @@ -1982,8 +1981,7 @@ def get_from_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> Non queue: the instance of the queue from where to get the data. .. deprecated:: v1.5 - This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.get_from_queue` - and will be removed in v1.7. + This method was deprecated in v1.5 and will be removed in v1.7. 
""" @contextmanager diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index b3152b197fbbe..52df7954a64a2 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -400,14 +400,14 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: # apply sharded context to prevent OOM - run_method = partial(self._run_with_sharded_context, run_method) + run_method = partial(self._run_with_strategy_setup, run_method) if self._strategy.launcher is not None: return self._strategy.launcher.launch(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) - def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): return run_method(*args, **kwargs) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index ac1088646d8f5..20d578fa2d46f 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -261,13 +261,11 @@ def _check_add_get_queue(model: "pl.LightningModule") -> None: """ if is_overridden("add_to_queue", model): rank_zero_deprecation( - "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7 in " - "favor of `DDPSpawnStrategy.add_to_queue`" + "The `LightningModule.add_to_queue` method was deprecated in v1.5 and will be removed in v1.7." ) if is_overridden("get_from_queue", model): rank_zero_deprecation( - "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7 in " - "favor of `DDPSpawnStrategy.get_from_queue`" + "The `LightningModule.get_from_queue` method was deprecated in v1.5 and will be removed in v1.7." 
) diff --git a/setup.cfg b/setup.cfg index 5ef62709a3f4c..79ab35616ed61 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,13 +25,13 @@ addopts = --color=yes --disable-pytest-warnings filterwarnings = - # error out on our deprecation warnings - ensures the code and tests are kept up-to-date - error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning - error::FutureWarning - # warnings from deprecated modules on import - # TODO: remove in 1.7 - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory + # error out on our deprecation warnings - ensures the code and tests are kept up-to-date + error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning + error::FutureWarning + # warnings from deprecated modules on import + # TODO: remove in 1.7 + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory junit_duration_report = call diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py index 1ae6dca5dc5c9..ef33a3ba56a43 100644 --- a/tests/strategies/test_ddp_spawn_strategy.py +++ b/tests/strategies/test_ddp_spawn_strategy.py @@ -94,9 +94,6 @@ def get_from_queue(self, trainer: Trainer, queue) -> None: class TestDDPSpawnStrategy(DDPSpawnStrategy): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def _configure_launcher(self): self._launcher = CustomSpawnLauncher(self) From 67bc8705c51fff0f13d07c20b2aef7a64dc3c886 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 17 Feb 2022 20:54:07 +0530 Subject: [PATCH 34/36] Update pytorch_lightning/strategies/strategy.py Co-authored-by: ananthsub --- pytorch_lightning/strategies/strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 3df616a786b19..37b9b435b7413 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -49,7 +49,7 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: self.accelerator = accelerator - self._launcher: _Launcher = None + self._launcher: Optional[_Launcher] = None self._model: Optional[Module] = None self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin From 8286ec85d1e0e27cb80fd61a10c11186bd9e9ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 17 Feb 2022 21:34:59 +0100 Subject: [PATCH 35/36] Update pytorch_lightning/strategies/launchers/spawn.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/strategies/launchers/spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py index 467ee793eef90..d1349fd39cd97 100644 --- a/pytorch_lightning/strategies/launchers/spawn.py +++ b/pytorch_lightning/strategies/launchers/spawn.py @@ -41,7 +41,7 @@ class _SpawnLauncher(_Launcher): - This launcher requires all objects to be pickleable. - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``. 
- Arguments: + Args: strategy: A reference to the strategy that is used together with this launcher. """ From 4397f2d7481aad5d8a476e201d6a28daea5c4aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 17 Feb 2022 21:37:43 +0100 Subject: [PATCH 36/36] revert changes in test_trainer.py --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 008ff501a7c25..587ff0b7b9f72 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1039,7 +1039,7 @@ def configure_gradient_clipping(self, *args, **kwargs): # test that gradient is clipped correctly parameters = self.parameters() grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - assert torch.allclose(grad_norm, torch.tensor(0.05), rtol=1e-4) + torch.testing.assert_allclose(grad_norm, torch.tensor(0.05)) self.assertion_called = True model = TestModel() @@ -1070,7 +1070,7 @@ def configure_gradient_clipping(self, *args, **kwargs): parameters = self.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - assert torch.allclose(grad_max.abs(), torch.tensor(1e-10)) + torch.testing.assert_allclose(grad_max.abs(), torch.tensor(1e-10)) self.assertion_called = True model = TestModel()
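Taken together, the series converges on a simple control flow: the accelerator connector asks the strategy to configure its launcher, and the trainer entry points route through that launcher when one exists. The sketch below uses stub classes to stand in for `Strategy` and a concrete launcher purely to show how the pieces fit; it is not actual Lightning code.

.. code-block:: python

    from typing import Any, Callable, Optional


    class _StubLauncher:
        """Stand-in for a concrete launcher such as _SpawnLauncher."""

        def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
            kwargs.pop("trainer", None)  # the trainer kwarg is consumed, not forwarded
            return function(*args, **kwargs)


    class _StubStrategy:
        """Stand-in for Strategy: owns an optional launcher set in _configure_launcher."""

        def __init__(self) -> None:
            self._launcher: Optional[_StubLauncher] = None

        def _configure_launcher(self) -> None:
            self._launcher = _StubLauncher()

        @property
        def launcher(self) -> Optional[_StubLauncher]:
            return self._launcher


    def call_entry_point(strategy: _StubStrategy, fn: Callable, *args: Any, **kwargs: Any) -> Any:
        # Mirrors the trainer change: go through the launcher when one is attached.
        if strategy.launcher is not None:
            return strategy.launcher.launch(fn, *args, trainer=None, **kwargs)
        return fn(*args, **kwargs)


    if __name__ == "__main__":
        strategy = _StubStrategy()
        strategy._configure_launcher()  # in Lightning, the accelerator connector does this
        print(call_entry_point(strategy, lambda x: x * 2, 21))  # prints 42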