Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))

- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))


## [0.7.6] - 2020-05-16

### Added
Expand Down
22 changes: 17 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,20 +730,32 @@ def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
"""create an instance from CLI arguments
"""
Create an instance from CLI arguments.

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`Trainer`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid Trainer arguments.

Example:
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args)
>>> trainer = Trainer.from_argparse_args(args, logger=False)
"""
if isinstance(args, ArgumentParser):
args = Trainer.parse_argparser(args)
args = cls.parse_argparser(args)
params = vars(args)
params.update(**kwargs)

return cls(**params)
# we only want to pass in valid Trainer args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
trainer_kwargs.update(**kwargs)

return cls(**trainer_kwargs)

@property
def num_gpus(self) -> int:
Expand Down
26 changes: 26 additions & 0 deletions tests/trainer/test_trainer_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import pickle
import sys
from argparse import ArgumentParser, Namespace
from unittest import mock

Expand Down Expand Up @@ -110,3 +111,28 @@ def test_argparse_args_parsing(cli_args, expected):
for k, v in expected.items():
assert getattr(args, k) == v
assert Trainer.from_argparse_args(args)


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
)
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
pytest.param({}, {}),
pytest.param({'logger': False}, {}),
pytest.param({'logger': False}, {'logger': True}),
pytest.param({'logger': False}, {'checkpoint_callback': True}),
])
def test_init_from_argparse_args(cli_args, extra_args):
unknown_args = dict(unknown_arg=0)

# unkown args in the argparser/namespace should be ignored
with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init:
trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
expected = dict(cli_args)
expected.update(extra_args) # extra args should override any cli arg
init.assert_called_with(trainer, **expected)

# passing in unknown manual args should throw an error
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)