diff --git a/CHANGELOG.md b/CHANGELOG.md index b203856de1ec4..4709ca317080b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -429,6 +429,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155)) +- Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532)) + + - Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6466fb95674eb..609a681d2e6a7 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -699,8 +699,8 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback] config["callbacks"].append(self.trainer_defaults["callbacks"]) if self.save_config_callback and not config["fast_dev_run"]: config_callback = self.save_config_callback( - self.parser, - self.config, + self._parser(self.subcommand), + self.config.get(str(self.subcommand), self.config), self.save_config_filename, overwrite=self.save_config_overwrite, multifile=self.save_config_multifile, @@ -798,9 +798,7 @@ def get_automatic( def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any: """Utility to get a config value which might be inside a subcommand.""" - if self.subcommand is not None: - return config[self.subcommand].get(key, default) - return config.get(key, default) + return config.get(str(self.subcommand), config).get(key, default) def _run_subcommand(self, subcommand: str) -> None: """Run the chosen subcommand.""" diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index fac6c89b5fe76..a1608510aacb3 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -349,9 +349,7 @@ def test_lightning_cli_args(tmpdir): with open(config_path) as f: loaded_config = yaml.safe_load(f.read()) - loaded_config = loaded_config["fit"] cli_config = cli.config["fit"].as_dict() - assert cli_config["seed_everything"] == 1234 assert "model" not in loaded_config and "model" not in cli_config # no arguments to include assert loaded_config["data"] == cli_config["data"] @@ -405,9 +403,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): with open(config_path) as f: loaded_config = yaml.safe_load(f.read()) - loaded_config = loaded_config["fit"] cli_config = cli.config["fit"].as_dict() - assert loaded_config["model"] == cli_config["model"] assert loaded_config["data"] == cli_config["data"] assert loaded_config["trainer"] == cli_config["trainer"] @@ -1251,6 +1247,10 @@ def test_lightning_cli_config_before_subcommand(): test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") assert cli.trainer.limit_test_batches == 1 + save_config_callback = cli.trainer.callbacks[0] + assert save_config_callback.config.trainer.limit_test_batches == 1 + assert save_config_callback.parser.subcommand == "test" + with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch( "pytorch_lightning.Trainer.validate", autospec=True ) as validate_mock: @@ -1259,6 +1259,10 @@ def test_lightning_cli_config_before_subcommand(): validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") assert cli.trainer.limit_val_batches == 1 + save_config_callback = cli.trainer.callbacks[0] + assert save_config_callback.config.trainer.limit_val_batches == 1 + assert save_config_callback.parser.subcommand == "validate" + def test_lightning_cli_config_before_subcommand_two_configs(): config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}