18 | 18 | from torch import optim |
19 | 19 |
20 | 20 | from pytorch_lightning import Callback, Trainer |
| 21 | +from pytorch_lightning.callbacks import ModelCheckpoint |
21 | 22 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
22 | 23 | from tests.base import EvalModelTemplate |
23 | 24 | from tests.helpers.boring_model import BoringModel |
@@ -620,3 +621,87 @@ def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch |
620 | 621 | ) |
621 | 622 | trainer.fit(model) |
622 | 623 | assert mocked_sched.call_count == expected_steps |
| 624 | + |
| 625 | + |
| 626 | +@pytest.mark.parametrize('every_n_train_steps, epoch_interval', [(None, True), (2, False), (2, True)]) |
| 627 | +def test_lr_scheduler_state_updated_before_saving(tmpdir, every_n_train_steps, epoch_interval): |
| 628 | + batches = 2 |
| 629 | + max_epochs = 1 |
| 630 | + lr, gamma = 1, 10 |
| 631 | + trainer = Trainer( |
| 632 | + default_root_dir=tmpdir, |
| 633 | + progress_bar_refresh_rate=0, |
| 634 | + logger=False, |
| 635 | + max_epochs=max_epochs, |
| 636 | + limit_train_batches=batches, |
| 637 | + limit_val_batches=1, |
| 638 | + callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=every_n_train_steps)] |
| 639 | + ) |
| 640 | + |
| 641 | + class TestModel(BoringModel): |
| 642 | + |
| 643 | + def configure_optimizers(self): |
| 644 | + optimizer = torch.optim.SGD(self.parameters(), lr=lr) |
| 645 | + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) |
| 646 | + lr_dict = {'scheduler': lr_scheduler} |
| 647 | + if not epoch_interval: |
| 648 | + lr_dict['interval'] = 'step' |
| 649 | + return [optimizer], [lr_dict] |
| 650 | + |
| 651 | + def on_save_checkpoint(self, checkpoint): |
| 652 | + lr_dict = checkpoint['lr_schedulers'][0] |
| 653 | + # 2 batches ran; the scheduler steps once per training step or once per epoch, depending on the configured interval |
| 654 | + assert self.trainer.global_step + 1 == batches # the global step hasn't been increased yet |
| 655 | + compare_to = max_epochs if epoch_interval else batches |
| 656 | + assert lr_dict['_step_count'] - 1 == compare_to # step count starts at 1 |
| 657 | + assert lr_dict['_last_lr'] == [lr * gamma**compare_to] |
| 658 | + self.on_save_checkpoint_called = True |
| 659 | + |
| 660 | + model = TestModel() |
| 661 | + trainer.fit(model) |
| 662 | + assert model.on_save_checkpoint_called |
| 663 | + |
| 664 | + |
| 665 | +def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir): |
| 666 | + batches = 4 |
| 667 | + trainer = Trainer( |
| 668 | + default_root_dir=tmpdir, |
| 669 | + progress_bar_refresh_rate=0, |
| 670 | + logger=False, |
| 671 | + max_epochs=1, |
| 672 | + limit_train_batches=batches, |
| 673 | + limit_val_batches=1, |
| 674 | + callbacks=[ModelCheckpoint(dirpath=tmpdir)] |
| 675 | + ) |
| 676 | + |
| 677 | + class TestModel(BoringModel): |
| 678 | + |
| 679 | + def training_step(self, batch, batch_idx, optimizer_idx): |
| 680 | + self.log("foo", batch_idx) |
| 681 | + return super().training_step(batch, batch_idx) |
| 682 | + |
| 683 | + def configure_optimizers(self): |
| 684 | + optimizer_1 = torch.optim.Adam(self.parameters()) |
| 685 | + optimizer_2 = torch.optim.Adam(self.parameters()) |
| 686 | + |
| 687 | + lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_1) |
| 688 | + lr_dict_1 = {'scheduler': lr_scheduler1, 'interval': 'step', 'monitor': 'foo'} |
| 689 | + |
| 690 | + lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1) |
| 691 | + lr_dict_2 = {'scheduler': lr_scheduler2, 'interval': 'step'} |
| 692 | + return [optimizer_1, optimizer_2], [lr_dict_1, lr_dict_2] |
| 693 | + |
| 694 | + def on_save_checkpoint(self, checkpoint): |
| 695 | + lr_dict_1 = checkpoint['lr_schedulers'][0] |
| 696 | + # since plateau schedulers are updated after saving checkpoint, last_epoch should be 3 |
| 697 | + assert lr_dict_1['last_epoch'] == batches - 1 # last epoch starts at 0 |
| 698 | + |
| 699 | + lr_dict_2 = checkpoint['lr_schedulers'][1] |
| 700 | + assert lr_dict_2['_step_count'] - 1 == batches # step count starts at 1 |
| 701 | + |
| 702 | + self.on_save_checkpoint_called = True |
| 703 | + |
| 704 | + model = TestModel() |
| 705 | + model.training_epoch_end = None |
| 706 | + trainer.fit(model) |
| 707 | + assert model.on_save_checkpoint_called |
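
For reference, a minimal plain-PyTorch sketch (not part of this diff) of the scheduler state_dict keys the assertions above rely on: `_step_count` and `_last_lr` for StepLR, `last_epoch` for ReduceLROnPlateau. In a Lightning checkpoint these state dicts are what end up under `checkpoint['lr_schedulers']`. The dummy model and the chosen values are illustrative only.

import torch
from torch import optim

model = torch.nn.Linear(2, 2)
optimizer = optim.SGD(model.parameters(), lr=1.0)
step_lr = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=10)
plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# the StepLR constructor performs an initial step, so `_step_count` starts at 1
assert step_lr.state_dict()["_step_count"] == 1

optimizer.step()
step_lr.step()
plateau.step(metrics=0.5)

state = step_lr.state_dict()
assert state["_step_count"] == 2                 # incremented once per `step()` call
assert state["_last_lr"] == [10.0]               # lr * gamma after one step
assert plateau.state_dict()["last_epoch"] == 1   # starts at 0, +1 per `step(metrics)` call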