Skip to content

Commit 29486ab

Browse files
nisheethlahotiBorda
authored andcommitted
Use hydra.run.dir (not os.getcwd) for DDP subprocesses' run dir (#18145)
Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit ebbd538)
1 parent 2e828e2 commit 29486ab

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Fixed
1111

12-
-
12+
- Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))
1313

1414

1515
## [2.0.6] - 2023-07-20

src/lightning/fabric/strategies/launchers/subprocess_script.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _basic_subprocess_cmd() -> Sequence[str]:
143143

144144
def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
145145
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
146+
from hydra.core.hydra_config import HydraConfig
146147
from hydra.utils import get_original_cwd, to_absolute_path
147148

148149
# when user is using hydra find the absolute path
@@ -154,6 +155,7 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
154155
command += sys.argv[1:]
155156

156157
cwd = get_original_cwd()
157-
os_cwd = f'"{os.getcwd()}"'
158-
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
158+
rundir = f'"{HydraConfig.get().run.dir}"'
159+
# Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
160+
command += [f"hydra.run.dir={rundir}", f"hydra.job.name=train_ddp_process_{local_rank}", "hydra.output_subdir=null"]
159161
return command, cwd

tests/tests_pytorch/strategies/launchers/test_subprocess_script.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
99
from tests_pytorch.helpers.runif import RunIf
1010

11-
_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2")
1211
_HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7")
1312

1413
if _HYDRA_WITH_RUN_PROCESS:
1514
from hydra.test_utils.test_utils import run_process
15+
from omegaconf import OmegaConf
1616

1717

1818
# Script to run from command line
@@ -48,21 +48,34 @@ def task_fn(cfg):
4848

4949
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
5050
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
51-
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
52-
def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
53-
monkeypatch.chdir(tmpdir)
51+
@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"])
52+
def test_ddp_with_hydra_runjob(subdir, tmp_path, monkeypatch):
53+
monkeypatch.chdir(tmp_path)
5454

5555
# Save script locally
5656
with open("temp.py", "w") as fn:
5757
fn.write(script)
5858

5959
# Run CLI
6060
devices = 2
61-
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"']
61+
run_dir = tmp_path / "hydra_output"
62+
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"]
6263
if subdir is not None:
6364
cmd += [f"hydra.output_subdir={subdir}"]
6465
run_process(cmd)
6566

67+
# Make sure no config.yaml was created for additional processes
68+
saved_confs = list(run_dir.glob("**/config.yaml"))
69+
assert len(saved_confs) == (0 if subdir == "null" else 1) # Main process has config.yaml iff subdir!="null"
70+
71+
if saved_confs: # Make sure the parameter was set and used
72+
cfg = OmegaConf.load(saved_confs[0])
73+
assert cfg.devices == devices
74+
75+
# Make sure PL spawned jobs that are logged by Hydra
76+
logs = list(run_dir.glob("**/*.log"))
77+
assert len(logs) == devices
78+
6679

6780
def test_kill():
6881
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)

0 commit comments

Comments
 (0)