From 084eeafa87a12101f8b524b5ecba9097749e7cc0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 16 Dec 2021 17:13:46 +0100
Subject: [PATCH 1/4] Initialize ModelCheckpoint state as early as possible

---
 pytorch_lightning/callbacks/model_checkpoint.py | 2 +-
 tests/models/test_restore.py                    | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 3be03daebcf1e..a229f694db4f4 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -248,7 +248,7 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """When pretrain routine starts we resolve the ckpt dir on the fly."""
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 1139e6fb5e8ad..246987763c11a 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -246,10 +246,11 @@ def get_trainer_args():
         checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
         trainer_args = dict(
             default_root_dir=tmpdir,
-            max_steps=1,
+            limit_train_batches=1,
+            limit_val_batches=2,
+            max_epochs=1,
             logger=False,
             callbacks=[checkpoint, callback_capture],
-            limit_val_batches=2,
         )
         assert checkpoint.best_model_path == ""
         assert checkpoint.best_model_score is None
@@ -257,11 +258,13 @@
 
     # initial training
     trainer = Trainer(**get_trainer_args())
+    # with pytest.deprecated_call(match="The `on_init_end` callback hook "):
    trainer.fit(model, datamodule=dm)
     callbacks_before_resume = deepcopy(trainer.callbacks)
 
     # resumed training
     trainer = Trainer(**get_trainer_args())
+    # with pytest.deprecated_call(match="The `on_init_end` callback hook "):
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
 
     assert len(callbacks_before_resume) == len(callback_capture.callbacks)

From 5777d6665b7ae589eb89ce789fbd85dfa1aa4f34 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 16 Dec 2021 17:40:31 +0100
Subject: [PATCH 2/4] fix

---
 pytorch_lightning/callbacks/model_checkpoint.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a229f694db4f4..72039edc58462 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -249,12 +249,15 @@ def state_key(self) -> str:
         )
 
     def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we resolve the ckpt dir on the fly."""
+        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
+        # because the attributes are part of the state_key which needs to be fully defined before reloading.
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
         if trainer.is_global_zero:
             self.__warn_if_dir_not_empty(self.dirpath)

From 28f5cae876430f6f564795a934968424e5cfd370 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 16 Dec 2021 23:39:22 +0100
Subject: [PATCH 3/4] move to setup hook

---
 pytorch_lightning/callbacks/model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 72039edc58462..d5f876325f696 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -248,7 +248,7 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
         # because the attributes are part of the state_key which needs to be fully defined before reloading.
         if self._save_on_train_epoch_end is None:

From 4c355656337fa6f850bb9ce240899d124e27175e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 16 Dec 2021 23:41:57 +0100
Subject: [PATCH 4/4] clean up debug statements

---
 tests/models/test_restore.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 246987763c11a..6e8b6e5926bca 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -258,13 +258,11 @@ def get_trainer_args():
 
     # initial training
     trainer = Trainer(**get_trainer_args())
-    # with pytest.deprecated_call(match="The `on_init_end` callback hook "):
     trainer.fit(model, datamodule=dm)
     callbacks_before_resume = deepcopy(trainer.callbacks)
 
     # resumed training
     trainer = Trainer(**get_trainer_args())
-    # with pytest.deprecated_call(match="The `on_init_end` callback hook "):
     trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))
 
     assert len(callbacks_before_resume) == len(callback_capture.callbacks)