 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import os.path as osp
 import pickle
 import platform
 import re
 from argparse import Namespace
-from distutils.version import LooseVersion
 from pathlib import Path
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
 
 import cloudpickle
 import pytest
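
Side note on the dropped import above: distutils.version.LooseVersion is deprecated along with the rest of distutils, and version gates of this kind are typically rewritten against the packaging library. A minimal sketch of that replacement, assuming packaging and torch are installed; the 1.6.0 threshold is illustrative, not taken from this commit:

import torch
from packaging.version import Version

# Hypothetical sketch: the kind of version gate LooseVersion was used for.
# packaging.version.Version replaces the deprecated distutils.version.LooseVersion.
if Version(torch.__version__) >= Version("1.6.0"):
    print("native AMP available")
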
@@ -641,20 +639,17 @@ def validation_epoch_end(self, outputs):
 @pytest.mark.parametrize("enable_pl_optimizer", [False, True])
 def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
     """
-    This test validates that the checkpoint can be called when provided to callacks list
+    This test validates that the checkpoint can be called when provided to callbacks list
     """
-
     checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")
 
     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
     model = ExtendedBoringModel()
-    model.validation_step_end = None
     model.validation_epoch_end = None
     trainer = Trainer(
         max_epochs=1,
@@ -663,92 +658,30 @@ def validation_step(self, batch, batch_idx):
         limit_test_batches=2,
         callbacks=[checkpoint_callback],
         enable_pl_optimizer=enable_pl_optimizer,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
     )
-
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']
 
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
+    for idx in range(4):
         # load from checkpoint
-        chk = get_last_checkpoint()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(
-            max_epochs=1,
-            limit_train_batches=2,
-            limit_val_batches=2,
-            limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-        trainer.fit(model)
-        trainer.test(model)
-
-    assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
-
-
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
-def test_checkpoint_repeated_strategy_tmpdir(enable_pl_optimizer, tmpdir):
-    """
-    This test validates that the checkpoint can be called when provided to callacks list
-    """
-
-    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
-
-    class ExtendedBoringModel(BoringModel):
-
-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"val_loss": loss}
-
-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        limit_test_batches=2,
-        callbacks=[checkpoint_callback],
-        enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer.fit(model)
-    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
-    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
-
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
-
-        # load from checkpoint
-        chk = get_last_checkpoint()
-        model = LogInTwoMethods.load_from_checkpoint(chk)
+        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
         trainer = pl.Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
             limit_train_batches=2,
             limit_val_batches=2,
             limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-
+            resume_from_checkpoint=checkpoint_callback.best_model_path,
+            enable_pl_optimizer=enable_pl_optimizer,
+            weights_summary=None,
+            progress_bar_refresh_rate=0,
+        )
         trainer.fit(model)
-        trainer.test(model)
-        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+        trainer.test(model, verbose=False)
+    assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
+    assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
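
The hunk above replaces the hand-rolled get_last_checkpoint() filename parsing with the path ModelCheckpoint already tracks. A minimal sketch of that resume pattern under the Lightning API used in this file; MyModel is a hypothetical LightningModule standing in for the test's BoringModel subclass:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_cb = ModelCheckpoint(monitor="val_loss", dirpath="checkpoints/")
trainer = Trainer(max_epochs=1, callbacks=[ckpt_cb])
trainer.fit(MyModel())  # MyModel: hypothetical placeholder LightningModule

# The callback records where its best checkpoint landed, so no directory
# listing or filename parsing is needed to locate it again:
model = MyModel.load_from_checkpoint(ckpt_cb.best_model_path)
trainer = Trainer(max_epochs=1, resume_from_checkpoint=ckpt_cb.best_model_path)
trainer.fit(model)
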
@@ -760,86 +693,71 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir):
760693 """
761694
762695 class ExtendedBoringModel (BoringModel ):
763-
764696 def validation_step (self , batch , batch_idx ):
765697 output = self .layer (batch )
766698 loss = self .loss (batch , output )
767699 return {"val_loss" : loss }
768700
701+ def validation_epoch_end (self , * _ ):
702+ ...
703+
769704 def assert_trainer_init (trainer ):
770705 assert not trainer .checkpoint_connector .has_trained
771706 assert trainer .global_step == 0
772707 assert trainer .current_epoch == 0
773708
774709 def get_last_checkpoint (ckpt_dir ):
775- ckpts = os .listdir (ckpt_dir )
776- ckpts .sort ()
777- return osp .join (ckpt_dir , ckpts [- 1 ])
710+ last = ckpt_dir .listdir (sort = True )[- 1 ]
711+ return str (last )
778712
779713 def assert_checkpoint_content (ckpt_dir ):
780714 chk = pl_load (get_last_checkpoint (ckpt_dir ))
781715 assert chk ["epoch" ] == epochs
782716 assert chk ["global_step" ] == 4
783717
784718 def assert_checkpoint_log_dir (idx ):
785- lightning_logs_path = osp .join (tmpdir , 'lightning_logs' )
786- assert sorted (os .listdir (lightning_logs_path )) == [f'version_{ i } ' for i in range (idx + 1 )]
787- assert len (os .listdir (ckpt_dir )) == epochs
788-
789- def get_model ():
790- model = ExtendedBoringModel ()
791- model .validation_step_end = None
792- model .validation_epoch_end = None
793- return model
719+ lightning_logs = tmpdir / 'lightning_logs'
720+ actual = [d .basename for d in lightning_logs .listdir (sort = True )]
721+ assert actual == [f'version_{ i } ' for i in range (idx + 1 )]
722+ assert len (ckpt_dir .listdir ()) == epochs
794723
795- ckpt_dir = osp . join ( tmpdir , 'checkpoints' )
724+ ckpt_dir = tmpdir / 'checkpoints'
796725 checkpoint_cb = ModelCheckpoint (dirpath = ckpt_dir , save_top_k = - 1 )
797726 epochs = 2
798727 limit_train_batches = 2
799-
800- model = get_model ()
801-
802728 trainer_config = dict (
803729 default_root_dir = tmpdir ,
804730 max_epochs = epochs ,
805731 limit_train_batches = limit_train_batches ,
806732 limit_val_batches = 3 ,
807733 limit_test_batches = 4 ,
808734 enable_pl_optimizer = enable_pl_optimizer ,
809- )
810-
811- trainer = pl .Trainer (
812- ** trainer_config ,
813735 callbacks = [checkpoint_cb ],
814736 )
737+ trainer = pl .Trainer (** trainer_config )
815738 assert_trainer_init (trainer )
816739
740+ model = ExtendedBoringModel ()
817741 trainer .fit (model )
818742 assert trainer .checkpoint_connector .has_trained
819743 assert trainer .global_step == epochs * limit_train_batches
820744 assert trainer .current_epoch == epochs - 1
821745 assert_checkpoint_log_dir (0 )
746+ assert_checkpoint_content (ckpt_dir )
822747
823748 trainer .test (model )
824749 assert trainer .current_epoch == epochs - 1
825750
826- assert_checkpoint_content (ckpt_dir )
827-
828751 for idx in range (1 , 5 ):
829752 chk = get_last_checkpoint (ckpt_dir )
830753 assert_checkpoint_content (ckpt_dir )
831754
832- checkpoint_cb = ModelCheckpoint (dirpath = ckpt_dir , save_top_k = - 1 )
833- model = get_model ()
834-
835755 # load from checkpoint
836- trainer = pl .Trainer (
837- ** trainer_config ,
838- resume_from_checkpoint = chk ,
839- callbacks = [checkpoint_cb ],
840- )
756+ trainer_config ["callbacks" ] = [ModelCheckpoint (dirpath = ckpt_dir , save_top_k = - 1 )]
757+ trainer = pl .Trainer (** trainer_config , resume_from_checkpoint = chk )
841758 assert_trainer_init (trainer )
842759
760+ model = ExtendedBoringModel ()
843761 trainer .test (model )
844762 assert not trainer .checkpoint_connector .has_trained
845763 assert trainer .global_step == epochs * limit_train_batches
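
For context on the assert_checkpoint_content helper above: a Lightning checkpoint is an ordinary torch-serialized dict, so the counters the test asserts on can be inspected directly. A hedged sketch; the path is illustrative:

import torch

# A Lightning checkpoint loads as a plain dict; "epoch" and "global_step"
# are the trainer progress counters saved alongside the weights.
ckpt = torch.load("checkpoints/epoch=1.ckpt", map_location="cpu")
print(ckpt["epoch"], ckpt["global_step"])
print(sorted(ckpt))  # also: state_dict, optimizer_states, callbacks, ...
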