
Commit 398f122

carmocca and s-rog authored
Improve some tests (#5049)
* Improve some tests
* Add TrainerState asserts

Co-authored-by: Roger Shieh <[email protected]>
1 parent a49291d commit 398f122
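
The "Add TrainerState asserts" half of this commit checks the trainer's bookkeeping before and after fit. A minimal sketch of that pattern, not taken from the diff, assuming pytorch_lightning at this commit's version (the BoringModel import path is an assumption about the repo's test helpers):

# Sketch only: mirrors the asserts added below, not the actual test code.
import pytorch_lightning as pl
from tests.base import BoringModel  # hypothetical import path for the test helper

def test_trainer_state_after_fit(tmpdir):
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
    )
    # a fresh Trainer has not trained anything yet
    assert not trainer.checkpoint_connector.has_trained
    assert trainer.global_step == 0
    assert trainer.current_epoch == 0

    trainer.fit(BoringModel())

    # after fit: one epoch of two optimizer steps
    assert trainer.checkpoint_connector.has_trained
    assert trainer.global_step == 2
    assert trainer.current_epoch == 0  # zero-indexed; max_epochs=1 finishes at epoch 0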

2 files changed: +144 -259 lines changed

tests/checkpointing/test_model_checkpoint.py
30 additions, 112 deletions
@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import os.path as osp
 import pickle
 import platform
 import re
 from argparse import Namespace
-from distutils.version import LooseVersion
 from pathlib import Path
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
 
 import cloudpickle
 import pytest
@@ -641,20 +639,17 @@ def validation_epoch_end(self, outputs):
 @pytest.mark.parametrize("enable_pl_optimizer", [False, True])
 def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
     """
-    This test validates that the checkpoint can be called when provided to callacks list
+    This test validates that the checkpoint can be called when provided to callbacks list
     """
-
     checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")
 
     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
     model = ExtendedBoringModel()
-    model.validation_step_end = None
     model.validation_epoch_end = None
     trainer = Trainer(
         max_epochs=1,
@@ -663,92 +658,30 @@ def validation_step(self, batch, batch_idx):
         limit_test_batches=2,
         callbacks=[checkpoint_callback],
         enable_pl_optimizer=enable_pl_optimizer,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
     )
-
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']
 
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
+    for idx in range(4):
         # load from checkpoint
-        chk = get_last_checkpoint()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(
-            max_epochs=1,
-            limit_train_batches=2,
-            limit_val_batches=2,
-            limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-        trainer.fit(model)
-        trainer.test(model)
-
-    assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
-
-
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
-def test_checkpoint_repeated_strategy_tmpdir(enable_pl_optimizer, tmpdir):
-    """
-    This test validates that the checkpoint can be called when provided to callacks list
-    """
-
-    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
-
-    class ExtendedBoringModel(BoringModel):
-
-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"val_loss": loss}
-
-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        limit_test_batches=2,
-        callbacks=[checkpoint_callback],
-        enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer.fit(model)
-    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
-    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
-
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
-
-        # load from checkpoint
-        chk = get_last_checkpoint()
-        model = LogInTwoMethods.load_from_checkpoint(chk)
+        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
         trainer = pl.Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
             limit_train_batches=2,
             limit_val_batches=2,
             limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-
+            resume_from_checkpoint=checkpoint_callback.best_model_path,
+            enable_pl_optimizer=enable_pl_optimizer,
+            weights_summary=None,
+            progress_bar_refresh_rate=0,
+        )
         trainer.fit(model)
-        trainer.test(model)
-        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+        trainer.test(model, verbose=False)
+        assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
+        assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
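
The rewritten loop above resumes from checkpoint_callback.best_model_path instead of reconstructing the newest checkpoint path by hand; ModelCheckpoint records the path of the best (or, without a monitor, the last) checkpoint it wrote. A minimal sketch of that resume pattern, under the same assumptions as the sketch above:

# Sketch only: resume training from ModelCheckpoint.best_model_path.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import BoringModel  # hypothetical import path for the test helper

def fit_then_resume(tmpdir):
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir)  # no monitor: tracks the last checkpoint
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[ckpt_cb],
    )
    trainer.fit(BoringModel())

    # the callback already knows the checkpoint path, so no directory scan is needed
    model = BoringModel.load_from_checkpoint(ckpt_cb.best_model_path)
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        resume_from_checkpoint=ckpt_cb.best_model_path,
    )
    trainer.fit(model)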
@@ -760,86 +693,71 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir):
     """
 
     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
+        def validation_epoch_end(self, *_):
+            ...
+
     def assert_trainer_init(trainer):
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == 0
         assert trainer.current_epoch == 0
 
     def get_last_checkpoint(ckpt_dir):
-        ckpts = os.listdir(ckpt_dir)
-        ckpts.sort()
-        return osp.join(ckpt_dir, ckpts[-1])
+        last = ckpt_dir.listdir(sort=True)[-1]
+        return str(last)
 
     def assert_checkpoint_content(ckpt_dir):
         chk = pl_load(get_last_checkpoint(ckpt_dir))
         assert chk["epoch"] == epochs
         assert chk["global_step"] == 4
 
     def assert_checkpoint_log_dir(idx):
-        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
-        assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
-        assert len(os.listdir(ckpt_dir)) == epochs
-
-    def get_model():
-        model = ExtendedBoringModel()
-        model.validation_step_end = None
-        model.validation_epoch_end = None
-        return model
+        lightning_logs = tmpdir / 'lightning_logs'
+        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
+        assert actual == [f'version_{i}' for i in range(idx + 1)]
+        assert len(ckpt_dir.listdir()) == epochs
 
-    ckpt_dir = osp.join(tmpdir, 'checkpoints')
+    ckpt_dir = tmpdir / 'checkpoints'
     checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
     epochs = 2
     limit_train_batches = 2
-
-    model = get_model()
-
     trainer_config = dict(
         default_root_dir=tmpdir,
         max_epochs=epochs,
         limit_train_batches=limit_train_batches,
         limit_val_batches=3,
         limit_test_batches=4,
         enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer = pl.Trainer(
-        **trainer_config,
         callbacks=[checkpoint_cb],
     )
+    trainer = pl.Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
+    model = ExtendedBoringModel()
    trainer.fit(model)
     assert trainer.checkpoint_connector.has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
+    assert_checkpoint_content(ckpt_dir)
 
     trainer.test(model)
     assert trainer.current_epoch == epochs - 1
 
-    assert_checkpoint_content(ckpt_dir)
-
     for idx in range(1, 5):
         chk = get_last_checkpoint(ckpt_dir)
         assert_checkpoint_content(ckpt_dir)
 
-        checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
-        model = get_model()
-
         # load from checkpoint
-        trainer = pl.Trainer(
-            **trainer_config,
-            resume_from_checkpoint=chk,
-            callbacks=[checkpoint_cb],
-        )
+        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
+        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
         assert_trainer_init(trainer)
 
+        model = ExtendedBoringModel()
         trainer.test(model)
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == epochs * limit_train_batches
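
Several of the rewritten helpers above lean on the py.path.local object that pytest's tmpdir fixture provides: the / operator joins paths, listdir(sort=True) returns sorted entries, .basename strips the directory, and str() yields the full path. A minimal sketch of that API, with a throwaway directory standing in for the fixture:

# Sketch only: the py.path.local calls used by the updated helpers.
import py

tmpdir = py.path.local("/tmp/pl-example").ensure(dir=True)  # pytest injects this as `tmpdir`
ckpt_dir = tmpdir / "checkpoints"         # `/` joins path segments
ckpt_dir.ensure(dir=True)                 # create the directory, like mkdir -p
(ckpt_dir / "epoch=00.ckpt").write("")    # write an empty placeholder file
entries = ckpt_dir.listdir(sort=True)     # sorted list of py.path.local objects
assert [p.basename for p in entries] == ["epoch=00.ckpt"]
assert str(entries[-1]).endswith("epoch=00.ckpt")  # str() gives the absolute path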
