Merged
40 commits
9687e73  add executors for strategies (rohitgr7, Jan 27, 2022)
7f325f5  add license (rohitgr7, Jan 27, 2022)
9114b09  fix issues (rohitgr7, Jan 27, 2022)
e054d0e  return results (rohitgr7, Jan 27, 2022)
a2248b4  fix os (rohitgr7, Jan 27, 2022)
8751c36  fix DDP (rohitgr7, Jan 27, 2022)
eea1fa7  rm redundant code (rohitgr7, Jan 27, 2022)
e46b824  fix import (rohitgr7, Jan 27, 2022)
4457eff  executor -> launcher (rohitgr7, Feb 2, 2022)
184b222  reduce arguments (rohitgr7, Feb 2, 2022)
a2f8d0a  fix tpu and mypy (rohitgr7, Feb 2, 2022)
72b1bb9  fix tpu and mypy (rohitgr7, Feb 2, 2022)
e7bb90b  fix tpu and mypy (rohitgr7, Feb 2, 2022)
9857ffe  recover results for TPU (rohitgr7, Feb 3, 2022)
d2817f3  lite patch (rohitgr7, Feb 3, 2022)
b6bd4d8  fix deadlock detection (rohitgr7, Feb 3, 2022)
ce49e15  lite spawn patch (rohitgr7, Feb 3, 2022)
2314253  constructor initialize (rohitgr7, Feb 4, 2022)
02b0873  Merge branch 'master' into feat/executor (rohitgr7, Feb 7, 2022)
f094fc0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 7, 2022)
75c0731  Apply suggestions from code review (rohitgr7, Feb 7, 2022)
a3de1e3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 7, 2022)
5da734e  review comments (rohitgr7, Feb 9, 2022)
6df0862  pre-commit (rohitgr7, Feb 9, 2022)
a485da9  Merge remote-tracking branch 'origin/master' into feat/executor (rohitgr7, Feb 10, 2022)
1a7013c  configure_launcher (rohitgr7, Feb 10, 2022)
0dfbd28  rename to SubprocessScriptLauncher (rohitgr7, Feb 10, 2022)
3295121  protected and update test (rohitgr7, Feb 10, 2022)
81cf4c1  enable dep failure (rohitgr7, Feb 10, 2022)
25c2417  add script launcher (rohitgr7, Feb 10, 2022)
fa30173  add extensive docs (awaelchli, Feb 12, 2022)
e280092  add extensive docs (awaelchli, Feb 12, 2022)
eb567c9  Apply suggestions from code review (awaelchli, Feb 12, 2022)
197ff6b  Merge remote-tracking branch 'origin/feat/executor' into feat/executor (awaelchli, Feb 12, 2022)
0930b1f  fix typing (awaelchli, Feb 12, 2022)
51b18f3  address Adrian's comments (rohitgr7, Feb 17, 2022)
67bc870  Update pytorch_lightning/strategies/strategy.py (rohitgr7, Feb 17, 2022)
e3d155c  Merge remote-tracking branch 'origin/master' into feat/executor (rohitgr7, Feb 17, 2022)
8286ec8  Update pytorch_lightning/strategies/launchers/spawn.py (awaelchli, Feb 17, 2022)
4397f2d  revert changes in test_trainer.py (awaelchli, Feb 17, 2022)
10 changes: 4 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -1962,28 +1962,26 @@ def model_size(self) -> float:
)
return get_model_size_mb(self)

def add_to_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None:
def add_to_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.

Args:
queue: the instance of the queue to append the data.

.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.add_to_queue`
and will be removed in v1.7.
This method was deprecated in v1.5 and will be removed in v1.7.
"""

def get_from_queue(self, queue: pl.strategies.ddp_spawn._FakeQueue) -> None:
def get_from_queue(self, queue: pl.strategies.launchers.spawn._FakeQueue) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.

Args:
queue: the instance of the queue from where to get the data.

.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of `DDPSpawnStrategy.get_from_queue`
and will be removed in v1.7.
This method was deprecated in v1.5 and will be removed in v1.7.
"""

@contextmanager
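The two hooks above keep their deprecation notices but now reference the launcher-based queue type. As a rough illustration of what an override of these deprecated hooks does with that queue, here is a minimal sketch that casts callback metrics to numpy before queueing and back to torch.Tensor afterwards; the class name is hypothetical, the queue is assumed to expose put/get, and this is not the library's own implementation (the hooks are removed in v1.7).

import torch
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):  # hypothetical example module
    def add_to_queue(self, queue) -> None:
        # cast tensors to numpy before queueing to avoid memory-sharing issues
        metrics = {
            k: v.cpu().numpy() if isinstance(v, torch.Tensor) else v
            for k, v in self.trainer.callback_metrics.items()
        }
        queue.put(metrics)

    def get_from_queue(self, queue) -> None:
        # cast the values back to torch.Tensor to preserve consistency
        metrics = {
            k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v
            for k, v in queue.get().items()
        }
        self.trainer.callback_metrics.update(metrics)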
13 changes: 6 additions & 7 deletions pytorch_lightning/lite/lite.py
@@ -27,7 +27,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.plugins import PLUGIN_INPUT
from pytorch_lightning.strategies import DDPSpawnStrategy, DeepSpeedStrategy, Strategy, TPUSpawnStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
@@ -399,17 +399,16 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
return seed_everything(seed=seed, workers=workers)

def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()

# apply sharded context to prevent OOM
run_method = partial(self._run_with_sharded_context, run_method)
run_method = partial(self._run_with_strategy_setup, run_method)

if isinstance(self._strategy, DDPSpawnStrategy):
return self._strategy.spawn(run_method, *args, **kwargs)
if self._strategy.launcher is not None:
return self._strategy.launcher.launch(run_method, *args, **kwargs)
else:
return run_method(*args, **kwargs)

def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
return run_method(*args, **kwargs)

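The lite.py change above drops the DDPSpawnStrategy special case in favor of a generic launcher check on the strategy. A minimal sketch of that dispatch contract, with hypothetical class names standing in for the real Lightning types:

from typing import Any, Callable, Optional

class _Launcher:  # stand-in for the launcher interface assumed here
    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

class _InProcessStrategy:  # hypothetical strategy that needs no launcher
    launcher: Optional[_Launcher] = None

def run(strategy: Any, function: Callable, *args: Any, **kwargs: Any) -> Any:
    # dispatch through the launcher when the strategy provides one,
    # otherwise call the function directly in the current process
    if strategy.launcher is not None:
        return strategy.launcher.launch(function, *args, **kwargs)
    return function(*args, **kwargs)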
13 changes: 13 additions & 0 deletions pytorch_lightning/strategies/__init__.py
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
100 changes: 9 additions & 91 deletions pytorch_lightning/strategies/ddp.py
@@ -15,16 +15,11 @@
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional, Union

import __main__
import numpy as np
import torch
import torch.distributed
from torch.nn import Module
@@ -37,11 +32,11 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_FAIRSCALE_AVAILABLE,
_HYDRA_AVAILABLE,
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
@@ -58,9 +53,6 @@

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

@@ -69,11 +61,7 @@


class DDPStrategy(ParallelStrategy):
"""Plugin for multi-process single-device training on one or multiple nodes.

The main process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of
devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes.
"""
"""Strategy for multi-process single-device training on one or multiple nodes."""

distributed_backend = _StrategyType.DDP

@@ -98,7 +86,6 @@ def __init__(
precision_plugin=precision_plugin,
)
log.detail(f"{self.__class__.__name__}: initializing DDP plugin")
self.interactive_ddp_procs = []
self._num_nodes = 1
self.sync_batchnorm = False
self._ddp_kwargs = kwargs
@@ -108,7 +95,7 @@ def __init__(
self._model_averaging_period = model_averaging_period
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_has_called_call_children_scripts: bool = False
self._rank_0_will_call_children_scripts: bool = False
self.set_world_ranks()

@property
@@ -142,18 +129,19 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self) -> bool:
return True

def setup_environment(self) -> None:
# start the other scripts
def _configure_launcher(self) -> None:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
self._rank_0_will_call_children_scripts = True

def setup_environment(self) -> None:
self.setup_distributed()
super().setup_environment()

def setup(self, trainer: "pl.Trainer") -> None:
super().setup(trainer)
# share ddp pids to all processes
self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts)
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()

@@ -174,68 +162,6 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def _call_children_scripts(self):
# bookkeeping of spawned processes
self._check_can_spawn_children()

# DDP Environment variables
os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)

# allow the user to pass the node rank
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

# Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
# See https://docs.python.org/3/reference/import.html#main-spec
if __main__.__spec__ is None: # pragma: no-cover
# Script called as `python a/b/c.py`
# when user is using hydra find the absolute path
path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

# pull out the commands used to run the script and resolve the abs file path
command = sys.argv
try:
full_path = path_lib(command[0])
except Exception:
full_path = os.path.abspath(command[0])

command[0] = full_path
# use the same python interpreter and actually running
command = [sys.executable] + command
else: # Script called as `python -m a.b.c`
command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

self.interactive_ddp_procs = []

for local_rank in range(1, self.num_processes):
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{local_rank}"

# remove env var if global seed not set
if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
del env_copy["PL_GLOBAL_SEED"]

# start process
# if hydra is available and initialized, make sure to set the cwd correctly
cwd: Optional[str] = None
if _HYDRA_AVAILABLE:
if HydraConfig.initialized():
cwd = get_original_cwd()
os_cwd = f'"{os.getcwd()}"'
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
self.interactive_ddp_procs.append(proc)

# starting all processes at once can cause issues
# with dataloaders delay between 1-10 seconds
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

self._rank_0_has_called_call_children_scripts = True

def setup_distributed(self):
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()
@@ -251,14 +177,6 @@ def setup_distributed(self):
# where to store ip_table
init_dist_connection(self.cluster_environment, self.torch_distributed_backend)

def _check_can_spawn_children(self):
if self.local_rank != 0:
raise RuntimeError(
"Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
" Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
" 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
)

def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
@@ -436,7 +354,7 @@ def _should_run_deadlock_detection(self) -> bool:
By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
parent process to perform the process termination, external to Lightning.
"""
return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_has_called_call_children_scripts
return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts

def _share_information_to_prevent_deadlock(self) -> None:
self._share_pids()
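The roughly eighty lines removed from DDPStrategy above (command reconstruction, per-rank environment copies, subprocess.Popen) now sit behind the _SubprocessScriptLauncher that _configure_launcher creates. A condensed sketch of that launch pattern under simplifying assumptions: Hydra support, the MASTER_ADDR/NODE_RANK bookkeeping, and the randomized start delay are omitted, and the function name below is illustrative only.

import os
import subprocess
import sys

import __main__

def launch_children(num_processes: int) -> None:
    # rebuild the command that started this process, covering both
    # `python path/to/script.py ...` and `python -m package.module ...`
    if __main__.__spec__ is None:
        command = [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
    else:
        command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

    # rank 0 re-invokes its own command once per remaining local rank
    for local_rank in range(1, num_processes):
        env = os.environ.copy()
        env["LOCAL_RANK"] = str(local_rank)
        subprocess.Popen(command, env=env)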