Skip to content

Commit 7afb814

Browse files
committed
Comments requested by Thomas
1 parent 1be025d commit 7afb814

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
221221
trainer.fit_loop._skip_backward = False
222222

223223
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
224+
# the trainer increases the current epoch before this hook is called
224225
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
225226
# BatchNorm epoch update. Reset state
226227
trainer.accumulate_grad_batches = self._accumulate_grad_batches

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def reset(self) -> None:
127127
rank_zero_warn(
128128
"You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
129129
" results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
130-
" fault-tolerant training."
130+
" fault-tolerant training:"
131+
" https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
131132
)
132133
else:
133134
self.batch_progress.reset_on_run()

pytorch_lightning/loops/fit_loop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def done(self) -> bool:
169169
"""Evaluates when to leave the loop."""
170170
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
171171
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
172+
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
173+
# we use it here because the checkpoint data won't have `completed` increased yet
172174
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
173175

174176
should_stop = False

0 commit comments

Comments (0)