Skip to content

Commit b4811b3

Browse files
committed
fix the call for scheduler
1 parent 37ec90c commit b4811b3

File tree

5 files changed

+25
-30
lines changed

5 files changed

+25
-30
lines changed

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,6 @@ def get_evaluation_dataloaders(self, max_batches):
7272

7373
return dataloaders, max_batches
7474

75-
def should_skip_evaluation(self, dataloaders, max_batches):
76-
# skip when dataloaders aren't defined
77-
if dataloaders is None:
78-
return True
79-
80-
# enable disabling validation step with limit_val_batches = 0
81-
should_skip = sum(max_batches) == 0
82-
if should_skip:
83-
return True
84-
85-
return False
86-
8775
def on_evaluation_start(self, *args, **kwargs):
8876
if self.testing:
8977
self.trainer.call_hook('on_test_start', *args, **kwargs)

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def train(self):
548548
# hook
549549
self.train_loop.on_train_end()
550550

551-
def run_evaluation(self, test_mode: bool = False, max_batches=None):
551+
def run_evaluation(self, test_mode: bool = False, max_batches=None, on_epoch=False):
552552

553553
# used to know if we are logging for val, test + reset cached results
554554
self.logger_connector.set_stage(test_mode, reset=True)
@@ -560,7 +560,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
560560
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
561561

562562
# check if we want to skip this evaluation
563-
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
563+
if sum(max_batches) == 0:
564564
return [], []
565565

566566
# ref model
@@ -621,6 +621,10 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
621621
# hook
622622
self.evaluation_loop.on_evaluation_epoch_end()
623623

624+
# update epoch-level lr_schedulers
625+
if on_epoch:
626+
self.optimizer_connector.update_learning_rates(interval='epoch')
627+
624628
# hook
625629
self.evaluation_loop.on_evaluation_end()
626630

pytorch_lightning/trainer/training_loop.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -626,19 +626,13 @@ def run_training_epoch(self):
626626
self.trainer.total_batch_idx += 1
627627

628628
# stop epoch if we limited the number of training batches
629-
if self._num_training_batches_reached():
629+
if self._num_training_batches_reached(is_last_batch):
630630
break
631631

632632
# progress global step according to grads progress
633633
self.increment_accumulated_grad_global_step()
634634

635635
# epoch end hook
636-
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
637-
if should_check_val:
638-
self.trainer.run_evaluation(test_mode=False)
639-
# reset stage to train
640-
self.trainer.logger_connector.set_stage("train")
641-
642636
self.run_on_epoch_end_hook(epoch_output)
643637

644638
# log epoch metrics
@@ -649,10 +643,19 @@ def run_training_epoch(self):
649643
self.num_optimizers
650644
)
651645

652-
# update LR schedulers
653-
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
646+
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
647+
if should_check_val:
648+
self.trainer.run_evaluation(test_mode=False, on_epoch=True)
649+
# reset stage to train
650+
self.trainer.logger_connector.set_stage("train")
651+
652+
should_skip_eval = sum(self.trainer.num_val_batches) == 0
653+
should_train_only_check = not self.trainer.enable_validation and should_skip_eval
654+
655+
if should_skip_eval or should_train_only_check:
656+
# update epoch level lr_schedulers
657+
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
654658

655-
should_train_only_check = not self.trainer.enable_validation and (sum(self.trainer.num_val_batches) == 0)
656659
self.check_checkpoint_callback(should_train_only_check)
657660
self.check_early_stopping_callback(should_train_only_check)
658661

@@ -890,8 +893,8 @@ def increment_accumulated_grad_global_step(self):
890893
def _accumulated_batches_reached(self):
891894
return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
892895

893-
def _num_training_batches_reached(self):
894-
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
896+
def _num_training_batches_reached(self, is_last_batch=False):
897+
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
895898

896899
def should_accumulate(self):
897900
# checks if backward or backward + optimizer step (via closure)

tests/callbacks/test_callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ def test_trainer_callback_system(torch_save):
8888
call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
8989
call.on_batch_end(trainer, model),
9090
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
91+
call.on_epoch_end(trainer, model),
92+
call.on_train_epoch_end(trainer, model, ANY),
9193
call.on_validation_start(trainer, model),
9294
call.on_validation_epoch_start(trainer, model),
9395
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
9496
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
9597
call.on_validation_epoch_end(trainer, model),
9698
call.on_validation_end(trainer, model),
9799
call.on_save_checkpoint(trainer, model),
98-
call.on_epoch_end(trainer, model),
99-
call.on_train_epoch_end(trainer, model, ANY),
100100
call.on_train_end(trainer, model),
101101
call.on_fit_end(trainer, model),
102102
call.teardown(trainer, model, 'fit'),

tests/models/test_hooks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,15 @@ def on_test_model_train(self):
328328
'on_after_backward',
329329
'on_before_zero_grad',
330330
'on_train_batch_end',
331+
'on_epoch_end',
332+
'on_train_epoch_end',
331333
'on_validation_model_eval',
332334
'on_validation_epoch_start',
333335
'on_validation_batch_start',
334336
'on_validation_batch_end',
335337
'on_validation_epoch_end',
336338
'on_save_checkpoint',
337339
'on_validation_model_train',
338-
'on_epoch_end',
339-
'on_train_epoch_end',
340340
'on_train_end',
341341
'on_fit_end',
342342
]

0 commit comments

Comments
 (0)