|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +import os.path as osp |
| 16 | +import pytorch_lightning as pl |
15 | 17 | from distutils.version import LooseVersion |
16 | 18 | from unittest.mock import MagicMock, Mock |
17 | 19 |
|
|
30 | 32 | from pytorch_lightning.callbacks import ModelCheckpoint |
31 | 33 | from pytorch_lightning.loggers import TensorBoardLogger |
32 | 34 | from tests.base import EvalModelTemplate, BoringModel |
| 35 | +from pytorch_lightning.utilities.cloud_io import load as pl_load |
33 | 36 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
34 | 37 |
|
35 | 38 |
|
@@ -472,7 +475,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v |
472 | 475 | model.validation_step = None |
473 | 476 | trainer = Trainer( |
474 | 477 | default_root_dir=tmpdir, |
475 | | - checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last), |
| 478 | + checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, |
| 479 | + save_top_k=0, save_last=save_last), |
476 | 480 | max_epochs=max_epochs, |
477 | 481 | ) |
478 | 482 | trainer.fit(model) |
@@ -542,7 +546,7 @@ def validation_epoch_end(self, outputs): |
542 | 546 | assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1 |
543 | 547 |
|
544 | 548 |
|
545 | | -def test_checkpoint_within_callbacks_list(tmpdir): |
| 549 | +def test_checkpoint_repeated_strategy(tmpdir): |
546 | 550 | """ |
547 | 551 | This test validates that the checkpoint can be called when provided to the callbacks list |
548 | 552 | """ |
@@ -572,6 +576,159 @@ def validation_step(self, batch, batch_idx): |
572 | 576 | trainer.fit(model) |
573 | 577 | assert os.listdir(tmpdir) == ['epoch=00.ckpt'] |
574 | 578 |
|
| 579 | + def get_last_checkpoint(): |
| 580 | + ckpts = os.listdir(tmpdir) |
| 581 | + ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x} |
| 582 | + num_ckpts = len(ckpts_map) - 1 |
| 583 | + return ckpts_map[num_ckpts] |
| 584 | + |
| 585 | + for idx in range(1, 5): |
| 586 | + # load from checkpoint |
| 587 | + chk = get_last_checkpoint() |
| 588 | + model = BoringModel.load_from_checkpoint(chk) |
| 589 | + trainer = pl.Trainer(max_epochs=1, |
| 590 | + limit_train_batches=2, |
| 591 | + limit_val_batches=2, |
| 592 | + limit_test_batches=2, |
| 593 | + resume_from_checkpoint=chk) |
| 594 | + trainer.fit(model) |
| 595 | + trainer.test(model) |
| 596 | + |
| 597 | + assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']" |
| 598 | + |
| 599 | + |
def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
    """
    This test validates that repeatedly resuming from the last checkpoint, with
    ``default_root_dir`` set to ``tmpdir``, leaves a single ``epoch=00.ckpt`` file
    in ``tmpdir`` and adds one ``lightning_logs/version_N`` directory per run.
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    # Trainer limits shared by the initial run and every resumed run.
    run_limits = dict(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )
    # Expected content of tmpdir after every run (already in sorted order).
    expected_root = ['epoch=00.ckpt', 'lightning_logs']

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            # Return the loss under the key monitored by the checkpoint callback.
            prediction = self.layer(batch)
            return {"val_loss": self.loss(batch, prediction)}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[checkpoint_callback], **run_limits)

    trainer.fit(model)
    assert sorted(os.listdir(tmpdir)) == expected_root
    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
    assert sorted(os.listdir(path_to_lightning_logs)) == ['version_0']

    def get_last_checkpoint():
        # Map the epoch number parsed from each "epoch=NN.ckpt" name to its full path.
        path_by_epoch = {}
        for entry in os.listdir(tmpdir):
            if "epoch" not in entry:
                continue
            epoch = int(entry.split("=")[1].split('.')[0])
            path_by_epoch[epoch] = osp.join(tmpdir, entry)
        # Looks the checkpoint up by count, i.e. expects epochs numbered 0..len-1.
        return path_by_epoch[len(path_by_epoch) - 1]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir, resume_from_checkpoint=chk, **run_limits)

        trainer.fit(model)
        trainer.test(model)

        # Still a single checkpoint; one more version directory than the previous run.
        assert sorted(os.listdir(tmpdir)) == expected_root
        expected_versions = sorted(f'version_{i}' for i in range(idx + 1))
        assert sorted(os.listdir(path_to_lightning_logs)) == expected_versions
| 654 | + |
| 655 | + |
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates that the checkpoint can be resumed from several times
    without the trainer's global step increasing when nothing is run: after each
    resume, ``global_step`` stays at 2 and ``has_trained`` is not set.
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    # Trainer arguments shared by the initial run and every resumed run.
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
    )

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            return {"val_loss": self.loss(batch, prediction)}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = pl.Trainer(**trainer_kwargs)

    # Initial run: two training batches give global_step == 2 within epoch 0.
    assert trainer.checkpoint_connector.has_trained is not True
    assert trainer.current_epoch == 0
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained is True
    assert trainer.global_step == 2
    assert trainer.current_epoch == 0
    trainer.test(model)
    assert trainer.current_epoch == 0
    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"

    def get_last_checkpoint():
        # Last checkpoint (by sorted name) of the last version directory (by sorted name).
        logs_dir = osp.join(tmpdir, 'lightning_logs')
        newest_version = sorted(os.listdir(logs_dir))[-1]
        ckpt_dir = osp.join(logs_dir, newest_version, "checkpoints")
        newest_ckpt = sorted(os.listdir(ckpt_dir))[-1]
        return osp.join(ckpt_dir, newest_ckpt)

    def assert_checkpoint_content():
        # The stored progress counters must not move between resumed runs.
        stored = pl_load(get_last_checkpoint())
        assert stored["epoch"] == 1
        assert stored["global_step"] == 2

    assert_checkpoint_content()

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        assert_checkpoint_content()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(resume_from_checkpoint=chk, **trainer_kwargs)

        # A fresh trainer starts at step 0; after test() the step reads 2,
        # and the following fit() leaves it at 2 with has_trained still unset.
        assert trainer.checkpoint_connector.has_trained is not True
        assert trainer.global_step == 0
        trainer.test(model)
        assert trainer.global_step == 2
        trainer.fit(model)
        assert trainer.global_step == 2
        assert trainer.checkpoint_connector.has_trained is not True

        # One more version directory than on the previous iteration.
        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
        expected_versions = [f"version_{i}" for i in range(idx + 1)]
        assert sorted(os.listdir(lightning_logs_path)) == expected_versions
| 731 | + |
575 | 732 |
|
576 | 733 | @pytest.mark.parametrize( |
577 | 734 | 'filepath, dirpath, filename', |
|
0 commit comments