From 6f0f7e12025b51264348a844832620a81f5a995e Mon Sep 17 00:00:00 2001 From: Espen Haugsdal Date: Sun, 5 Jul 2020 13:20:41 +0200 Subject: [PATCH 1/2] Add failing test for bug --- tests/trainer/test_trainer_cli.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 51bbc96bd4f9b..fd381c251353b 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -102,7 +102,21 @@ def _raise(): pytest.param('--tpu_cores=8', {'tpu_cores': 8}), pytest.param("--tpu_cores=1,", - {'tpu_cores': '1,'}) + {'tpu_cores': '1,'}), + pytest.param( + "", + { + # These parameters are marked as Optional[...] in Trainer.__init__, with None as default. + # They should not be changed by the argparse interface. + "min_steps": None, + "max_steps": None, + "log_gpu_memory": None, + "distributed_backend": None, + "weights_save_path": None, + "truncated_bptt_steps": None, + "resume_from_checkpoint": None, + "profiler": None, + }), ]) def test_argparse_args_parsing(cli_args, expected): """Test multi type argument with bool.""" From c5e2f88384122ca98f28d5f54e73e65befba7ea3 Mon Sep 17 00:00:00 2001 From: Espen Haugsdal Date: Sun, 5 Jul 2020 13:29:17 +0200 Subject: [PATCH 2/2] Fix bug --- pytorch_lightning/trainer/trainer.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0b279f4e531f0..40478a533cf0b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -792,12 +792,32 @@ def _arg_default(x) -> Union[int, str]: else: return int(x) - @staticmethod - def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser - args = {k: True if v is None else v for k, v in vars(args).items()} - return Namespace(**args) + + types_default = { + arg: (arg_types, arg_default) for arg, arg_types, arg_default in cls.get_init_arguments_and_types() + } + + modified_args = {} + for k, v in vars(args).items(): + if k in types_default and v is None: + # We need to figure out if the None is due to using nargs="?" or if it comes from the default value + arg_types, arg_default = types_default[k] + if bool in arg_types and isinstance(arg_default, bool): + # Value has been passed as a flag => It is currently None, so we need to set it to True + # We always set to True, regardless of the default value. + # Users must pass False directly, but when passing nothing True is assumed. + # i.e. the only way to disable somthing that defaults to True is to use the long form: + # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, + # which then becomes True here. + + v = True + + modified_args[k] = v + return Namespace(**modified_args) @classmethod def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':