Commit 84950b9

feat: add test
1 parent e295c9c commit 84950b9

File tree

1 file changed: +42 -1 lines changed


tests/checkpointing/test_trainer_checkpoint.py

Lines changed: 42 additions & 1 deletion
@@ -13,13 +13,17 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from pathlib import Path
 
+import pytest
 import torch
+from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from tests.helpers import BoringModel, RandomDataset
 
 
 def test_finetuning_with_resume_from_checkpoint(tmpdir):
@@ -84,3 +88,40 @@ def validation_step(self, batch, batch_idx):
         assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
     else:
         assert f"epoch={idx + 1}" in best_model_path
+
+
+@pytest.mark.parametrize(['max_epochs', 'data_length'], [(1, 64), (2, 64), (3, 32)])
+def test_lr_schedulers_step_count(tmpdir, max_epochs, data_length):
+    """
+    Validate that a checkpoint is always saved after the lr_scheduler has been updated during training.
+    """
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            lr_scheduler_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
+            return [optimizer], [lr_scheduler_dict]
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, data_length))
+
+    train_step_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_train_step", every_n_train_steps=1)
+    val_epoch_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_val_epoch", every_n_val_epochs=1)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        callbacks=[train_step_checkpoint_callback, val_epoch_checkpoint_callback]
+    )
+    trainer.fit(model)
+    step_idx = data_length * max_epochs - 1
+    train_step_lr_scheduler = pl_load(
+        f"{tmpdir}/every_train_step/epoch={max_epochs-1}-step={step_idx}.ckpt")['lr_schedulers'][0]
+    val_epoch_lr_scheduler = pl_load(
+        f"{tmpdir}/every_val_epoch/epoch={max_epochs-1}-step={step_idx}.ckpt")['lr_schedulers'][0]
+    # last_epoch counts scheduler steps taken during training; _step_count also counts the implicit step() in the constructor
+    assert train_step_lr_scheduler['last_epoch'] == val_epoch_lr_scheduler['last_epoch'] == step_idx + 1
+    assert train_step_lr_scheduler['_step_count'] == val_epoch_lr_scheduler['_step_count'] == step_idx + 2
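Side note on the two assertions at the end of the new test: torch's _LRScheduler runs one implicit step() inside its constructor, which is why _step_count stays one ahead of last_epoch. A minimal standalone sketch of that counter arithmetic (illustration only, assuming stock PyTorch behavior; not part of this commit):

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# The constructor already called step() once:
assert scheduler.last_epoch == 0
assert scheduler._step_count == 1

for _ in range(10):  # simulate 10 training steps with interval='step'
    optimizer.step()
    scheduler.step()

# After N training steps: last_epoch == N and _step_count == N + 1.
# With N = data_length * max_epochs and step_idx = N - 1, this is exactly
# step_idx + 1 and step_idx + 2, matching the assertions in the test.
assert scheduler.last_epoch == 10
assert scheduler._step_count == 11

To run just the new test: pytest tests/checkpointing/test_trainer_checkpoint.py::test_lr_schedulers_step_count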
