 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim

 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):

     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
     """
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1

     class CustomBoringModel(BoringModel):

@@ -74,21 +77,28 @@ def __init__(self):
             self.val_logs = torch.randn(max_epochs, limit_val_batches)

         def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
             log_value = self.train_log_epochs[self.current_epoch, batch_idx]
             self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
             log_value = self.val_logs[self.current_epoch, batch_idx]
             self.log('val_log', log_value)
             self.log('epoch', self.current_epoch, on_epoch=True)
-            return out
+            return super().validation_step(batch, batch_idx)

         def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
             return [optimizer], [lr_scheduler]

     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
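
Note on the scheduler hunk above: Lightning's configure_optimizers can return each
scheduler either as a bare scheduler object or as a dict. The dict form is what the
ReduceLROnPlateau branch needs, because Lightning must know which logged metric to
pass to scheduler.step(metric). A minimal sketch of both forms, assuming an
illustrative LightningModule and a logged 'val_loss' metric (not part of this PR):

    from torch import optim
    import pytorch_lightning as pl

    class SketchModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=1e-1)
            # Bare scheduler: Lightning calls scheduler.step() once per epoch.
            step_lr = optim.lr_scheduler.StepLR(optimizer, step_size=1)
            # Dict form: 'monitor' names the logged metric fed to step(value);
            # 'strict': True makes Lightning raise if that metric is missing.
            plateau = {
                'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                'monitor': 'val_loss',
                'strict': True,
            }
            return [optimizer], [plateau]  # or [optimizer], [step_lr]
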
@@ -109,11 +119,15 @@ def configure_optimizers(self):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs

     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -130,9 +144,118 @@ def configure_optimizers(self):
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
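The second test below repeats these checks with val_check_interval, so validation
(and checkpointing) runs several times per training epoch. A worked example of the
counters it derives, using the parametrized values from the test (standalone
illustration, not part of the diff):

    limit_train_batches = 12
    val_check_interval = 0.25
    # Train batches between validation runs: int(12 * 0.25) = 3
    per_epoch_steps = int(limit_train_batches * val_check_interval)
    # Validation runs per epoch: 12 // 3 = 4
    per_epoch_call_count = limit_train_batches // per_epoch_steps
    assert (per_epoch_steps, per_epoch_call_count) == (3, 4)
    # With max_epochs = 3, that is 4 * 3 = 12 checkpoints and logged scores,
    # matching the `per_epoch_call_count * max_epochs` assertion below.
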
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,reduce_lr_on_plateau",
+    [
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    lr = 1e-1
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + did_update))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
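
About the StepLR assertions in both tests: the scheduler is built with step_size=1
and PyTorch's default gamma=0.1, and the tests pick lr = 1e-1 equal to that default,
which is why the expected value is written lr * (lr ** ...). The constructor itself
counts as the first scheduler step, so _step_count starts at 1. A standalone sketch
of the state being asserted, assuming only PyTorch (illustrative, not PR code):

    import torch
    from torch import optim

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.SGD(params, lr=1e-1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)  # default gamma=0.1

    assert scheduler.state_dict()['_step_count'] == 1  # construction counts as a step
    for epoch in range(3):
        optimizer.step()
        scheduler.step()  # one decay per "epoch"
        state = scheduler.state_dict()
        # After 0-based epoch k: _step_count == k + 2, lr == 0.1 * 0.1 ** (k + 1)
        assert state['_step_count'] == epoch + 2
        assert abs(state['_last_lr'][0] - 1e-1 * 0.1 ** (epoch + 1)) < 1e-12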