Commit b156d49

Refactor global step update

1 parent 374ff75 commit b156d49

File tree: 2 files changed, 8 additions and 36 deletions

pytorch_lightning/trainer/training_loop.py (8 additions, 35 deletions)
```diff
@@ -512,28 +512,17 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True
 
-            # max steps reached, end training
-            if (
-                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
-                and self._accumulated_batches_reached()
-            ):
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.trainer.should_stop:
-                break
-
             self.trainer.total_batch_idx += 1
 
-            # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached(is_last_batch):
-                break
-
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
+            max_steps_reached = (
+                self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step
+            )
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
+                break
+
         if batch_idx is None:
             # dataloader/iterator did not produce a batch
             return
```
```diff
@@ -552,18 +541,6 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
 
-        if should_train_only:
-            self.check_checkpoint_callback(True)
-
-        if should_check_val:
-            self.trainer.validating = True
-            self.trainer.run_evaluation(on_epoch=True)
-            self.trainer.training = True
-
-        # increment the global step once
-        # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
-
     def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()
```
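Taken together, the two hunks above collapse three scattered `break` checks (max steps, `should_stop`, batch limit) into a single condition evaluated after `increment_accumulated_grad_global_step()`, and drop the duplicated end-of-epoch increment. A minimal, self-contained sketch of the resulting control flow, using a toy trainer stand-in rather than the real `pytorch_lightning` API:

```python
# Toy sketch of the consolidated exit check (assumption: ToyTrainer and
# run_epoch are illustrative stand-ins, not pytorch_lightning classes).
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTrainer:
    max_steps: Optional[int] = None
    global_step: int = 0
    should_stop: bool = False


def run_epoch(trainer: ToyTrainer, num_batches: int) -> None:
    for batch_idx in range(num_batches):
        # stands in for increment_accumulated_grad_global_step()
        trainer.global_step += 1
        is_last_batch = batch_idx == num_batches - 1

        # one combined check, evaluated *after* the step increment, so the
        # old `global_step + 1` off-by-one guard is no longer needed
        max_steps_reached = (
            trainer.max_steps is not None
            and trainer.max_steps <= trainer.global_step
        )
        if max_steps_reached or trainer.should_stop or is_last_batch:
            break


trainer = ToyTrainer(max_steps=3)
run_epoch(trainer, num_batches=10)
assert trainer.global_step == 3  # training stops exactly at max_steps
```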
```diff
@@ -863,16 +840,12 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool
         elif self.trainer.val_check_batch != float('inf'):
             is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
 
-        # Note: num_training_batches is also inf for iterable datasets with no length defined
-        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
 
         if on_epoch:
-            return (
-                is_val_check_batch and epoch_end_val_check
-            ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
+            return is_val_check_batch or self.trainer.should_stop or is_last_batch_for_infinite_dataset
         else:
-            return is_val_check_batch and not epoch_end_val_check
+            return is_val_check_batch
 
     def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
```
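The last hunk also simplifies the validation-check decision: the `epoch_end_val_check` term is gone, so the mid-epoch and on-epoch paths now share the single `val_check_batch` trigger. A hedged sketch of the new decision logic (a free function with explicit boolean arguments, not the real `_should_check_val_fx` signature, which derives them from trainer state):

```python
# Sketch of the simplified decision; the argument names mirror the local
# variables in the diff above but are passed in directly here.
def should_check_val(
    is_val_check_batch: bool,
    should_stop: bool,
    is_last_batch_for_infinite_dataset: bool,
    on_epoch: bool,
) -> bool:
    if on_epoch:
        # epoch-end checks also fire on early stop, or on the last batch of
        # an infinite (length-less) iterable dataset
        return is_val_check_batch or should_stop or is_last_batch_for_infinite_dataset
    return is_val_check_batch


# e.g. a mid-epoch batch that hits val_check_batch now triggers validation
# regardless of whether it also happens to be the final batch of the epoch:
assert should_check_val(True, False, False, on_epoch=False)
```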

tests/trainer/test_trainer.py (0 additions, 1 deletion)
```diff
@@ -800,7 +800,6 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
 
     with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
         trainer.fit(model)
-        assert trainer.global_step == model.test_step_inf_loss
 
     for param in model.parameters():
         assert torch.isfinite(param).all()
```
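Assuming the deleted assertion sat inside the `with` block (as the reconstructed indentation suggests), it could never execute: `trainer.fit(model)` raises, and control jumps straight out of the `pytest.raises` context. A small self-contained reproduction of that pattern, with toy names and assuming only that `pytest` is installed:

```python
import pytest


def boom() -> None:
    raise ValueError("The loss returned in `training_step` is nan or inf.")


def test_statement_after_raise_never_runs() -> None:
    executed = []
    with pytest.raises(ValueError, match=r".*The loss returned in `training_step` is.*"):
        boom()
        executed.append("unreachable")  # like the removed assert, never runs
    assert executed == []  # confirms the line above was dead code
```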
