Skip to content

Commit 45ca781

Browse files
jgbosc, carmocca, rohitgr7, Borda
authored
Improving Hydra+DDP support (#11617)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: Jirka <[email protected]>
1 parent dd2a1c5 commit 45ca781

File tree

2 files changed

+213
-37
lines changed

2 files changed

+213
-37
lines changed

src/pytorch_lightning/strategies/launchers/subprocess_script.py

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import subprocess
1616
import sys
1717
from time import sleep
18-
from typing import Any, Callable, Optional
18+
from typing import Any, Callable, Optional, Sequence
1919

2020
import __main__
2121
import numpy as np
@@ -25,7 +25,7 @@
2525
from lightning_lite.plugins import ClusterEnvironment
2626
from lightning_lite.strategies.launchers.base import _Launcher
2727

28-
_HYDRA_AVAILABLE = RequirementCache("hydra")
28+
_HYDRA_AVAILABLE = RequirementCache("hydra-core")
2929

3030

3131
class _SubprocessScriptLauncher(_Launcher):
@@ -101,32 +101,6 @@ def _call_children_scripts(self) -> None:
101101
# allow the user to pass the node rank
102102
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
103103
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
104-
105-
# Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
106-
# See https://docs.python.org/3/reference/import.html#main-spec
107-
if __main__.__spec__ is None: # pragma: no-cover
108-
# Script called as `python a/b/c.py`
109-
if _HYDRA_AVAILABLE:
110-
# when user is using hydra find the absolute path
111-
from hydra.utils import to_absolute_path
112-
113-
to_abs_path = to_absolute_path
114-
else:
115-
to_abs_path = os.path.abspath
116-
117-
# pull out the commands used to run the script and resolve the absolute file path
118-
command = sys.argv
119-
try:
120-
full_path = to_abs_path(command[0])
121-
except Exception:
122-
full_path = os.path.abspath(command[0])
123-
124-
command[0] = full_path
125-
# use the same python interpreter and actually running
126-
command = [sys.executable] + command
127-
else: # Script called as `python -m a.b.c`
128-
command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
129-
130104
os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"
131105

132106
for local_rank in range(1, self.num_processes):
@@ -137,18 +111,18 @@ def _call_children_scripts(self) -> None:
137111
if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
138112
del env_copy["PL_GLOBAL_SEED"]
139113

140-
# start process
141-
# if hydra is available and initialized, make sure to set the cwd correctly
142-
cwd: Optional[str] = None
114+
hydra_in_use = False
143115
if _HYDRA_AVAILABLE:
144116
from hydra.core.hydra_config import HydraConfig
145-
from hydra.utils import get_original_cwd
146117

147-
if HydraConfig.initialized():
148-
cwd = get_original_cwd()
149-
os_cwd = f'"{os.getcwd()}"'
150-
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
151-
subprocess.Popen(command, env=env_copy, cwd=cwd)
118+
hydra_in_use = HydraConfig.initialized()
119+
120+
if hydra_in_use:
121+
command = _hydra_subprocess_cmd(local_rank)
122+
else:
123+
command = _basic_subprocess_cmd(local_rank)
124+
125+
subprocess.Popen(command, env=env_copy)
152126

153127
# starting all processes at once can cause issues
154128
# with dataloaders delay between 1-10 seconds
@@ -162,3 +136,44 @@ def _check_can_spawn_children(self) -> None:
162136
" Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
163137
" 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
164138
)
139+
140+
141+
def _basic_subprocess_cmd(local_rank: int) -> Sequence[str]:
142+
if __main__.__spec__ is None: # pragma: no-cover
143+
return [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
144+
else:
145+
return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
146+
147+
148+
def _hydra_subprocess_cmd(local_rank: int) -> Sequence[str]:
    """Build the launch command for a DDP worker when Hydra drives the run.

    The returned command re-invokes the current script (or module) and points
    Hydra at the configuration already materialized for the running job, so
    every spawned worker resolves the identical config. ``local_rank`` is
    folded into the worker's ``hydra.output_subdir`` to keep outputs separate.
    """
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import to_absolute_path

    # Re-create the interpreter invocation; with Hydra in play the script path
    # must be absolute because workers may start from a different cwd.
    spec = __main__.__spec__
    if spec is None:  # pragma: no-cover
        cmd = [sys.executable, to_absolute_path(sys.argv[0])]
    else:
        cmd = [sys.executable, "-m", spec.name]

    # Extract the Hydra config of the currently running job.
    cfg = HydraConfig.get()

    # Location of the configuration files Hydra saved for the current job.
    config_dir = cfg.runtime.output_dir
    if cfg.output_subdir is not None:
        config_dir = os.path.join(config_dir, cfg.output_subdir)

    # Prefer Hydra's experimental re-run capability when a pickled job config
    # exists; otherwise fall back to the saved config.yaml, which may have issues.
    pickle_path = os.path.join(config_dir, "config.pickle")
    if os.path.exists(pickle_path):
        cmd += ["--experimental-rerun", pickle_path]
    else:
        cmd += ["-cp", config_dir, "-cn", "config.yaml"]
        cmd += [
            f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}",
            f"hydra.run.dir={cfg.runtime.output_dir}",
        ]
    return cmd
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import logging
2+
import os
3+
import sys
4+
from pathlib import Path
5+
6+
import pytest
7+
from lightning_utilities.core.imports import RequirementCache
8+
9+
from pytorch_lightning.strategies.launchers.subprocess_script import _HYDRA_AVAILABLE
10+
from tests_pytorch.helpers.runif import RunIf
11+
12+
_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2")
13+
_HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7")
14+
15+
if _HYDRA_AVAILABLE:
16+
from omegaconf import OmegaConf
17+
if _HYDRA_WITH_RUN_PROCESS:
18+
from hydra.test_utils.test_utils import run_process
19+
20+
21+
# fixture to run hydra jobs in a clean temporary directory
22+
# Hydra creates its own output directories and logs
23+
@pytest.fixture
def cleandir(tmp_path):
    """Execute the test from inside a fresh temporary directory.

    Hydra creates its output directories and log files relative to the cwd,
    so each test gets an isolated directory; the original cwd is restored and
    logging is shut down once the test finishes.
    """
    previous_cwd = os.getcwd()  # remember where we started
    os.chdir(tmp_path)  # run the test inside the temp directory
    yield tmp_path  # hand control to the test body
    os.chdir(previous_cwd)  # restore the original cwd
    logging.shutdown()  # release Hydra's log file handlers
31+
32+
33+
# Training script written to disk and executed through the CLI by the tests
# below: a Hydra-decorated task function that fits a BoringModel (fast dev
# run) and asserts each DDP worker ends up on its own GPU.
script = """
import hydra
import os
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

class BoringModelGPU(BoringModel):
    def on_train_start(self) -> None:
        # make sure that the model is on GPU when training
        assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}")

@hydra.main(config_path=None, version_base="1.1")
def task_fn(cfg):
    trainer = Trainer(accelerator="auto", devices=cfg.devices, strategy=cfg.strategy, fast_dev_run=True)
    model = BoringModelGPU()
    trainer.fit(model)
    trainer.test(model)

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    os.environ.pop("LOCAL_RANK", None)

if __name__ == "__main__":
    task_fn()
"""
62+
63+
64+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
def test_ddp_with_hydra_runjob(cleandir, subdir):
    """Launch a Hydra-driven script and verify the spawned DDP workers reuse its config."""
    # Write the training script into the (temporary) working directory.
    with open("temp.py", "w") as handle:
        handle.write(script)

    # Invoke the script through the CLI, optionally overriding the output subdir.
    num_devices = 2
    invocation = [sys.executable, "temp.py", f"+devices={num_devices}", '+strategy="ddp"']
    if subdir is not None:
        invocation += [f"hydra.output_subdir={subdir}"]
    run_process(invocation)

    # A config.yaml must have been materialized for the additional processes.
    saved_configs = list(Path.cwd().glob("**/config.yaml"))
    assert len(saved_configs) == num_devices

    # The CLI override must have been recorded and used.
    loaded = OmegaConf.load(saved_configs[0])
    assert loaded.devices == num_devices

    # PL spawned a job that is logged by Hydra — exactly one log file.
    log_files = list(Path.cwd().glob("**/*.log"))
    assert len(log_files) == 1
91+
92+
93+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("num_jobs", [1, 2])
def test_ddp_with_hydra_multirunjob(cleandir, num_jobs):
    """Run a Hydra multirun sweep and check every job spawned its DDP workers."""
    # Write the training script into the (temporary) working directory.
    with open("temp.py", "w") as handle:
        handle.write(script)

    # A sweep parameter with `num_jobs` values produces `num_jobs` multirun jobs.
    sweep_override = "+foo=" + ",".join(str(i) for i in range(num_jobs))

    run_process([sys.executable, "temp.py", "+devices=2", '+strategy="ddp"', sweep_override, "--multirun"])

    # Each job must have produced a config.yaml for its spawned worker.
    worker_configs = sorted(Path.cwd().glob("**/.pl_ddp_hydra_*/config.yaml"))
    assert len(worker_configs) == num_jobs

    # The sweep value stored in each config matches its job's directory index.
    for worker_config in worker_configs:
        loaded = OmegaConf.load(worker_config)
        job_index = int(worker_config.parent.parent.parts[-1])
        assert loaded.devices == 2
        assert loaded.foo == job_index

    log_files = list(Path.cwd().glob("**/*.log"))
    assert len(log_files) == num_jobs
120+
121+
122+
yaml_file = """
123+
hydra:
124+
callbacks:
125+
save_job_info:
126+
_target_: hydra.experimental.callbacks.PickleJobInfoCallback
127+
"""
128+
129+
130+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RERUN, reason=str(_HYDRA_WITH_RERUN))
@pytest.mark.parametrize("num_jobs", [1, 2])
def test_ddp_with_hydra_multirunjob_rerun(cleandir, num_jobs):
    """Multirun sweep with the pickle callback: workers must get pickled configs."""
    # Write the training script and the callback-enabling config locally.
    with open("temp.py", "w") as handle:
        handle.write(script)
    with open("config.yaml", "w") as handle:
        handle.write(yaml_file)

    # A sweep parameter with `num_jobs` values produces `num_jobs` multirun jobs.
    sweep_override = "+foo=" + ",".join(str(i) for i in range(num_jobs))

    # Point Hydra at the local config (-cp/-cn) so the pickle callback is active.
    run_process(
        [
            sys.executable,
            "temp.py",
            "-cp",
            ".",
            "-cn",
            "config.yaml",
            "+devices=2",
            '+strategy="ddp"',
            sweep_override,
            "--multirun",
        ]
    )

    # Every job must have pickled its config for the experimental re-run path.
    pickled_configs = sorted(Path.cwd().glob("**/.hydra/config.pickle"))
    assert len(pickled_configs) == num_jobs

0 commit comments

Comments
 (0)