Skip to content

Commit 34237cf

Browse files
handle unknown args passed to Trainer.from_argparse_args (#1932)
* filter valid args * error on unknown manual args * added test * changelog * update docs and doctest * simplify * doctest * doctest * doctest * better test with mock check for init call * fstring * extend test * skip test on 3.6 not working Co-authored-by: William Falcon <[email protected]>
1 parent f46a7ba commit 34237cf

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3838

3939
- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
4040

41+
- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
42+
4143
- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
4244

45+
4346
## [0.7.6] - 2020-05-16
4447

4548
### Added

pytorch_lightning/trainer/trainer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -725,20 +725,32 @@ def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
725725

726726
@classmethod
727727
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
728-
"""create an instance from CLI arguments
728+
"""
729+
Create an instance from CLI arguments.
730+
731+
Args:
732+
args: The parser or namespace to take arguments from. Only known arguments will be
733+
parsed and passed to the :class:`Trainer`.
734+
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
735+
These must be valid Trainer arguments.
729736
730737
Example:
731738
>>> parser = ArgumentParser(add_help=False)
732739
>>> parser = Trainer.add_argparse_args(parser)
740+
>>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP
733741
>>> args = Trainer.parse_argparser(parser.parse_args(""))
734-
>>> trainer = Trainer.from_argparse_args(args)
742+
>>> trainer = Trainer.from_argparse_args(args, logger=False)
735743
"""
736744
if isinstance(args, ArgumentParser):
737-
args = Trainer.parse_argparser(args)
745+
args = cls.parse_argparser(args)
738746
params = vars(args)
739-
params.update(**kwargs)
740747

741-
return cls(**params)
748+
# we only want to pass in valid Trainer args, the rest may be user specific
749+
valid_kwargs = inspect.signature(cls.__init__).parameters
750+
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
751+
trainer_kwargs.update(**kwargs)
752+
753+
return cls(**trainer_kwargs)
742754

743755
@property
744756
def num_gpus(self) -> int:

tests/trainer/test_trainer_cli.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import pickle
3+
import sys
34
from argparse import ArgumentParser, Namespace
45
from unittest import mock
56

@@ -110,3 +111,28 @@ def test_argparse_args_parsing(cli_args, expected):
110111
for k, v in expected.items():
111112
assert getattr(args, k) == v
112113
assert Trainer.from_argparse_args(args)
114+
115+
116+
@pytest.mark.skipif(
117+
sys.version_info < (3, 7),
118+
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
119+
)
120+
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
121+
pytest.param({}, {}),
122+
pytest.param({'logger': False}, {}),
123+
pytest.param({'logger': False}, {'logger': True}),
124+
pytest.param({'logger': False}, {'checkpoint_callback': True}),
125+
])
126+
def test_init_from_argparse_args(cli_args, extra_args):
127+
unknown_args = dict(unknown_arg=0)
128+
129+
# unkown args in the argparser/namespace should be ignored
130+
with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init:
131+
trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
132+
expected = dict(cli_args)
133+
expected.update(extra_args) # extra args should override any cli arg
134+
init.assert_called_with(trainer, **expected)
135+
136+
# passing in unknown manual args should throw an error
137+
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
138+
Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)

0 commit comments

Comments
 (0)