|
13 | 13 | # limitations under the License. |
14 | 14 | import os |
15 | 15 | from copy import deepcopy |
| 16 | +from pathlib import Path |
16 | 17 |
|
| 18 | +import pytest |
17 | 19 | import torch |
| 20 | +from torch.utils.data import DataLoader |
18 | 21 |
|
19 | 22 | import pytorch_lightning as pl |
20 | 23 | from pytorch_lightning import seed_everything, Trainer |
21 | 24 | from pytorch_lightning.callbacks import ModelCheckpoint |
22 | | -from tests.helpers import BoringModel |
| 25 | +from pytorch_lightning.utilities.cloud_io import load as pl_load |
| 26 | +from tests.helpers import BoringModel, RandomDataset |
23 | 27 |
|
24 | 28 |
|
25 | 29 | def test_finetuning_with_resume_from_checkpoint(tmpdir): |
@@ -84,3 +88,40 @@ def validation_step(self, batch, batch_idx): |
84 | 88 | assert best_model_path.endswith(f"epoch=0{idx}.ckpt") |
85 | 89 | else: |
86 | 90 | assert f"epoch={idx + 1}" in best_model_path |
| 91 | + |
| 92 | + |
| 93 | +@pytest.mark.parametrize(['max_epochs', 'data_length'], [(1, 64), (2, 64), (3, 32)]) |
| 94 | +def test_lr_schedulers_step_count(tmpdir, max_epochs, data_length): |
| 95 | +    """ |
| 96 | +    This test validates that a checkpoint is always saved after the lr_scheduler has been updated during training. |
| 97 | +    """ |
| 98 | + |
| 99 | +    class TestModel(BoringModel): |
| 100 | + |
| 101 | +        def configure_optimizers(self): |
| 102 | +            optimizer = torch.optim.SGD(self.parameters(), lr=0.001) |
| 103 | +            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) |
| 104 | +            lr_scheduler_dict = {'scheduler': lr_scheduler, 'interval': 'step'}  # step the scheduler every training batch instead of every epoch |
| 105 | +            return [optimizer], [lr_scheduler_dict] |
| 106 | + |
| 107 | +        def train_dataloader(self): |
| 108 | +            return DataLoader(RandomDataset(32, data_length)) |
| 109 | + |
| 110 | +    train_step_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_train_step", every_n_train_steps=1) |
| 111 | +    val_epoch_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_val_epoch", every_n_val_epochs=1) |
| 112 | + |
| 113 | +    model = TestModel() |
| 114 | +    trainer = Trainer( |
| 115 | +        default_root_dir=tmpdir, |
| 116 | +        max_epochs=max_epochs, |
| 117 | +        callbacks=[train_step_checkpoint_callback, val_epoch_checkpoint_callback] |
| 118 | +    ) |
| 119 | +    trainer.fit(model) |
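| 120 | +    # with every_n_train_steps=1 and every_n_val_epochs=1, both callbacks write their final checkpoint at the last training step |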
| 121 | +    step_idx = data_length * max_epochs - 1 |
| 122 | +    ckpt_name = f"epoch={max_epochs-1}-step={step_idx}.ckpt" |
| 123 | +    train_step_lr_scheduler = pl_load(f"{tmpdir}/every_train_step/{ckpt_name}")['lr_schedulers'][0] |
| 124 | +    val_epoch_lr_scheduler = pl_load(f"{tmpdir}/every_val_epoch/{ckpt_name}")['lr_schedulers'][0] |
| 125 | +    # StepLR takes an initial step on construction, so its _step_count runs one ahead of last_epoch |
| 126 | +    assert train_step_lr_scheduler['last_epoch'] == val_epoch_lr_scheduler['last_epoch'] == step_idx + 1 |
| 127 | +    assert train_step_lr_scheduler['_step_count'] == val_epoch_lr_scheduler['_step_count'] == step_idx + 2 |