Skip to content

Commit a475010

Browse files
authored
[CLI] Support custom trainers without callbacks (#13138)
1 parent ec3c963 commit a475010

File tree

3 files changed

+45
-20
lines changed

3 files changed

+45
-20
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7373
- Added `teardown()` method to `Accelerator` ([#11935](https://github.com/PyTorchLightning/pytorch-lightning/pull/11935))
7474

7575

76+
- Added support for using custom Trainers that don't include callbacks using the CLI ([#13138](https://github.com/PyTorchLightning/pytorch-lightning/pull/13138))
77+
78+
7679
- Added a `timeout` argument to `DDPStrategy`. ([#13244](https://github.com/PyTorchLightning/pytorch-lightning/pull/13244))
7780

7881

src/pytorch_lightning/utilities/cli.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -602,30 +602,36 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer:
602602
kwargs: Any custom trainer arguments.
603603
"""
604604
extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
605-
trainer_config = {**self._get(self.config_init, "trainer"), **kwargs}
605+
trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs}
606606
return self._instantiate_trainer(trainer_config, extra_callbacks)
607607

608608
def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
609-
if config["callbacks"] is None:
610-
config["callbacks"] = []
611-
elif not isinstance(config["callbacks"], list):
612-
config["callbacks"] = [config["callbacks"]]
613-
assert isinstance(config["callbacks"], list) # to handle mypy false positive
614-
config["callbacks"].extend(callbacks)
615-
if "callbacks" in self.trainer_defaults:
616-
if isinstance(self.trainer_defaults["callbacks"], list):
617-
config["callbacks"].extend(self.trainer_defaults["callbacks"])
618-
else:
619-
config["callbacks"].append(self.trainer_defaults["callbacks"])
620-
if self.save_config_callback and not config["fast_dev_run"]:
621-
config_callback = self.save_config_callback(
622-
self._parser(self.subcommand),
623-
self.config.get(str(self.subcommand), self.config),
624-
self.save_config_filename,
625-
overwrite=self.save_config_overwrite,
626-
multifile=self.save_config_multifile,
609+
key = "callbacks"
610+
if key in config:
611+
if config[key] is None:
612+
config[key] = []
613+
elif not isinstance(config[key], list):
614+
config[key] = [config[key]]
615+
config[key].extend(callbacks)
616+
if key in self.trainer_defaults:
617+
if isinstance(self.trainer_defaults[key], list):
618+
config[key].extend(self.trainer_defaults[key])
619+
else:
620+
config[key].append(self.trainer_defaults[key])
621+
if self.save_config_callback and not config.get("fast_dev_run", False):
622+
config_callback = self.save_config_callback(
623+
self._parser(self.subcommand),
624+
self.config.get(str(self.subcommand), self.config),
625+
self.save_config_filename,
626+
overwrite=self.save_config_overwrite,
627+
multifile=self.save_config_multifile,
628+
)
629+
config[key].append(config_callback)
630+
else:
631+
rank_zero_warn(
632+
f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
633+
" not be included."
627634
)
628-
config["callbacks"].append(config_callback)
629635
return self.trainer_class(**config)
630636

631637
def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:

tests/tests_pytorch/utilities/test_cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,22 @@ def test_cli_auto_seeding():
14831483
assert cli.config["seed_everything"] == 123 # the original seed is kept
14841484

14851485

1486+
def test_cli_trainer_no_callbacks():
1487+
class MyTrainer(Trainer):
1488+
def __init__(self):
1489+
super().__init__()
1490+
1491+
class MyCallback(Callback):
1492+
...
1493+
1494+
match = "MyTrainer` class does not expose the `callbacks"
1495+
with mock.patch("sys.argv", ["any.py"]), pytest.warns(UserWarning, match=match):
1496+
cli = LightningCLI(
1497+
BoringModel, run=False, trainer_class=MyTrainer, trainer_defaults={"callbacks": MyCallback()}
1498+
)
1499+
assert not any(isinstance(cb, MyCallback) for cb in cli.trainer.callbacks)
1500+
1501+
14861502
def test_unresolvable_import_paths():
14871503
class TestModel(BoringModel):
14881504
def __init__(self, a_func: Callable = torch.softmax):

0 commit comments

Comments
 (0)