Skip to content

Commit df77f4e

Browse files
gnadathurgnadathur
andauthored
Fill missing options in toml file wih argparse defaults (#91)
Summary: Summary: Follow up on config unification, options not available in config file are picked from command line defaults. Test Plan: ============================= test session starts ============================== platform linux -- Python 3.10.13, pytest-8.0.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python cachedir: .pytest_cache rootdir: /data/users/gnadathur/a/torchtrain configfile: pyproject.toml plugins: cov-4.1.0 collecting ... collected 3 items test/test_job_config.py::TestJobConfig::test_command_line_args PASSED [ 33%] test/test_job_config.py::TestJobConfig::test_job_config_file PASSED [ 66%] test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED [100%] ---------- coverage: platform linux, python 3.10.13-final-0 ---------- Coverage XML written to file coverage.xml ============================= slowest 20 durations ============================= 0.00s call test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s call test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s call test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist 0.00s setup test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s setup test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s setup test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist 0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist ============================== 3 passed in 0.06s =============================== Test Plan: Reviewers: Subscribers: Tasks: Tags: --------- Co-authored-by: gnadathur <[email protected]>
1 parent 3b48039 commit df77f4e

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

test/test_job_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ class TestJobConfig:
99
def test_command_line_args(self):
1010
config = JobConfig()
1111
config.parse_args([])
12-
assert config.model.name == "llama"
12+
assert config.training.steps == -1
1313

1414
def test_job_config_file(self):
1515
config = JobConfig()
1616
config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"])
17-
assert config.model.name == "llama"
17+
assert config.training.steps == 10
1818

1919
def test_job_file_does_not_exist(self):
2020
with pytest.raises(FileNotFoundError):

torchtrain/config_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3+
14
# Copyright (c) Meta Platforms, Inc. and affiliates.
25
# All rights reserved.
36
import argparse
@@ -17,16 +20,16 @@ class JobConfig:
1720
Semantics:
1821
- Default config is loaded from a toml file. If no toml file is provided,
1922
then the default config is loaded from argparse defaults.
23+
- if toml file has missing keys, they are filled with argparse defaults.
2024
"""
2125

2226
def parse_args(self, args_list: list = sys.argv[1:]):
2327
args = JobConfig.init_args_from_command_line(args_list)
2428
config_file = getattr(args, "job.config_file", None)
25-
if config_file is None:
26-
args_dict = self._args_to_two_level_dict(args)
27-
else:
29+
args_dict = self._args_to_two_level_dict(args)
30+
if config_file is not None:
2831
with open(config_file, "rb") as f:
29-
args_dict = tomllib.load(f)
32+
args_dict |= tomllib.load(f)
3033
for k, v in args_dict.items():
3134
class_type = type(k.title(), (), v)
3235
setattr(self, k, class_type())

0 commit comments

Comments
 (0)