Monitor on training_epoch_end with ModelCheckpoint #5084
Changes from all commits
```diff
@@ -25,15 +25,16 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
 from torch.utils.data import DataLoader, Dataset, random_split

 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import LightningModule, Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel
+from tests.base import BoringModel, RandomDataset


 class LogInTwoMethods(BoringModel):
@@ -702,7 +703,7 @@ def validation_epoch_end(self, *_):
     ...


 def assert_trainer_init(trainer):
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == 0
     assert trainer.current_epoch == 0
@@ -739,7 +740,7 @@ def assert_checkpoint_log_dir(idx):
     model = ExtendedBoringModel()
     trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained
+    assert trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
@@ -759,12 +760,12 @@ def assert_checkpoint_log_dir(idx):
     model = ExtendedBoringModel()
     trainer.test(model)
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs

     trainer.fit(model)
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs
     assert_checkpoint_log_dir(idx)
@@ -940,6 +941,41 @@ def __init__(self, hparams):
     assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


+def test_model_checkpoint_with_training_epoch_end(tmpdir):
+    """
+    This test ensures ModelCheckpoint issues a warning when the monitor is logged on training_epoch_end.
+    """
+    class TestedModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss)
+            return {"loss": loss}
+
+        def training_epoch_end(self, outputs) -> None:
+            avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+            self.log('epoch_end_train_loss', avg_loss)
+
+    model = TestedModel()
+
+    chk = ModelCheckpoint(dirpath=tmpdir, monitor='epoch_end_train_loss', save_top_k=-1)
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=4,
+        progress_bar_refresh_rate=1,
+        callbacks=[chk],
+    )
+    trainer.current_epoch = 2
+    trainer.fit(model)
+
+    chks = os.listdir(tmpdir)
+    assert 'epoch=4.ckpt' not in chks
+    assert 'epoch=3.ckpt' not in chks
+    assert 'epoch=2.ckpt' not in chks
```
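As an aside on the new test above: besides listing the files in `tmpdir`, the saved checkpoints can also be inspected through the callback's own bookkeeping after `trainer.fit(model)`. A minimal sketch, assuming the standard `ModelCheckpoint` attributes (`best_model_path`, `best_model_score`, `best_k_models`) available in this Lightning version:

```python
# After fitting, ModelCheckpoint records every checkpoint it kept together with
# the monitored value ('epoch_end_train_loss') attached to it.
print(chk.best_model_path)    # path of the best checkpoint so far
print(chk.best_model_score)   # monitored value for that checkpoint
for path, score in chk.best_k_models.items():
    # with save_top_k=-1, every saved epoch shows up here
    print(path, float(score))
```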
Contributor

This test is missing a […]. Also, it might be better to test […] instead of testing that the rest don't exist.
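For illustration, a sketch of the positive assertion the reviewer appears to be suggesting, following the `assert set(ckpt_files) == set(expected)` pattern used later in this file; the expected filenames below are placeholders, not taken from the PR:

```python
import os

# Hypothetical positive check: compare the exact set of checkpoints written by
# ModelCheckpoint against an expected set, instead of asserting that certain
# names are absent.
ckpt_files = set(os.listdir(tmpdir))
expected = {"epoch=0.ckpt", "epoch=1.ckpt"}  # placeholder names, for illustration only
assert ckpt_files == expected
```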
Contributor (Author)

[…]

Contributor

yes, let's close this one for now.
```diff
 @pytest.mark.parametrize('max_epochs', [3, 4])
 @pytest.mark.parametrize(
     'save_top_k, expected',
@@ -976,4 +1012,4 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, ex
     assert set(ckpt_files) == set(expected)

     epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
-    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
+    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
```
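As a worked instance of the final assertion, under the assumption that the kept checkpoints correspond to the last `save_top_k` training epochs (illustrative values, not taken from the parametrization above):

```python
# Illustrative values: with max_epochs=4 and save_top_k=2, the surviving
# checkpoints should come from 0-indexed epochs 2 and 3, which is exactly
# what range(max_epochs - save_top_k, max_epochs) produces.
max_epochs, save_top_k = 4, 2
assert list(range(max_epochs - save_top_k, max_epochs)) == [2, 3]
```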