|
8 | 8 | from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher |
9 | 9 | from tests_pytorch.helpers.runif import RunIf |
10 | 10 |
|
11 | | -_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2") |
12 | 11 | _HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7") |
13 | 12 |
|
14 | 13 | if _HYDRA_WITH_RUN_PROCESS: |
15 | 14 | from hydra.test_utils.test_utils import run_process |
| 15 | + from omegaconf import OmegaConf |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # Script to run from command line |
@@ -48,21 +48,34 @@ def task_fn(cfg): |
48 | 48 |
|
49 | 49 | @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) |
50 | 50 | @pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS)) |
51 | | -@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"]) |
52 | | -def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch): |
53 | | - monkeypatch.chdir(tmpdir) |
| 51 | +@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"]) |
| 52 | +def test_ddp_with_hydra_runjob(subdir, tmp_path, monkeypatch): |
| 53 | + monkeypatch.chdir(tmp_path) |
54 | 54 |
|
55 | 55 | # Save script locally |
56 | 56 | with open("temp.py", "w") as fn: |
57 | 57 | fn.write(script) |
58 | 58 |
|
59 | 59 | # Run CLI |
60 | 60 | devices = 2 |
61 | | - cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"'] |
| 61 | + run_dir = tmp_path / "hydra_output" |
| 62 | + cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"] |
62 | 63 | if subdir is not None: |
63 | 64 | cmd += [f"hydra.output_subdir={subdir}"] |
64 | 65 | run_process(cmd) |
65 | 66 |
|
| 67 | + # Make sure no config.yaml was created for additional processes |
| 68 | + saved_confs = list(run_dir.glob("**/config.yaml")) |
| 69 | + assert len(saved_confs) == (0 if subdir == "null" else 1) # Main process has config.yaml iff subdir!="null" |
| 70 | + |
| 71 | + if saved_confs: # Make sure the parameter was set and used |
| 72 | + cfg = OmegaConf.load(saved_confs[0]) |
| 73 | + assert cfg.devices == devices |
| 74 | + |
| 75 | + # Make sure PL spawned jobs that are logged by Hydra |
| 76 | + logs = list(run_dir.glob("**/*.log")) |
| 77 | + assert len(logs) == devices |
| 78 | + |
66 | 79 |
|
67 | 80 | def test_kill(): |
68 | 81 | launcher = _SubprocessScriptLauncher(Mock(), 1, 1) |
|
0 commit comments