
Commit 789fae8

Fix current_epoch value on training end (#8578)

carmocca and awaelchli authored
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 5e78f42 · commit 789fae8

24 files changed · +145 −143 lines changed

CHANGELOG.md

Lines changed: 8 additions & 3 deletions

@@ -243,9 +243,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


-- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
-
-
 - Moved `Strategy` classes to the `strategies` directory ([#11226](https://github.com/PyTorchLightning/pytorch-lightning/pull/11226))


@@ -266,6 +263,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))

+
+- The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578))
+
+
 - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521))


@@ -288,8 +289,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Deprecated

+- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
+
+
 - Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))

+
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))

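For context on the #8578 entry, a minimal sketch of the user-visible change (the `TinyModel` module and `PrintEpoch` callback below are hypothetical, not code from this commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    class PrintEpoch(pl.Callback):
        def on_train_end(self, trainer, pl_module):
            # after this commit, current_epoch has already advanced past the
            # last epoch index, so it equals max_epochs here
            print(trainer.current_epoch)  # 3 (previously 2)

    train = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = pl.Trainer(max_epochs=3, callbacks=[PrintEpoch()], logger=False, enable_checkpointing=False)
    trainer.fit(TinyModel(), train)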

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 3 additions & 2 deletions

@@ -221,13 +221,14 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
         trainer.fit_loop._skip_backward = False

     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
+        # the trainer increases the current epoch before this hook is called
+        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.num_training_batches -= 1
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch == self.swa_end:
+        elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)
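The off-by-one being compensated above, as a worked example (values assumed for illustration only):

    # with max_epochs = 5, epoch indices 0..4 run during training
    swa_end = 4                 # hypothetical last SWA epoch index
    current_epoch_in_hook = 5   # trainer.current_epoch inside on_train_end after this commit
    assert current_epoch_in_hook - 1 == swa_end  # hence the `- 1` in the comparisons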

pytorch_lightning/loops/base.py

Lines changed: 1 addition & 4 deletions

@@ -21,7 +21,6 @@
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import BaseProgress
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`

@@ -288,11 +287,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
         destination[prefix + "state_dict"] = self.on_save_checkpoint()


-        # do not get the mode from `self.trainer` because it might not have been attached yet
-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if ft_enabled and isinstance(v, BaseProgress):
+            if isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")
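The serialization pattern this change simplifies (progress counters now always serialized, not only under fault-tolerant training), as a self-contained sketch with simplified stand-ins for `Loop` and `BaseProgress`, not the actual Lightning classes:

    from dataclasses import dataclass, asdict
    from typing import Dict, Optional

    @dataclass
    class BaseProgress:  # simplified stand-in
        ready: int = 0
        completed: int = 0

        def state_dict(self) -> Dict:
            return asdict(self)

    class Loop:  # simplified stand-in
        def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
            if destination is None:
                destination = {}
            for k, v in self.__dict__.items():
                key = prefix + k
                if isinstance(v, BaseProgress):  # always serialized after this commit
                    destination[key] = v.state_dict()
                elif isinstance(v, Loop):
                    v.state_dict(destination, key + ".")
            return destination

    class EpochLoop(Loop):
        def __init__(self) -> None:
            self.batch_progress = BaseProgress()

    class FitLoop(Loop):
        def __init__(self) -> None:
            self.epoch_progress = BaseProgress()
            self.epoch_loop = EpochLoop()

    print(FitLoop().state_dict())
    # {'epoch_progress': {...}, 'epoch_loop.batch_progress': {...}}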

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 12 additions & 5 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union

@@ -118,6 +119,17 @@ def reset(self) -> None:
             self.batch_progress.reset_on_restart()
             self.scheduler_progress.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
+
+            trainer = self.trainer
+            if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
+                expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
+                if self.global_step % expected_steps != 0:
+                    rank_zero_warn(
+                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
+                        " results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
+                        " fault-tolerant training:"
+                        " https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
+                    )
         else:
             self.batch_progress.reset_on_run()
             self.scheduler_progress.reset_on_run()

@@ -479,11 +491,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:

         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
-
-        # while restarting with no fault-tolerant, batch_progress.current.ready is -1
-        if batch_idx == -1:
-            return False
-
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
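The mid-epoch resume check added to `reset()` reduces to simple modular arithmetic; a worked example with assumed numbers (illustrative only):

    import math

    num_training_batches = 100   # assumed batches per epoch
    accumulate_grad_batches = 4  # assumed accumulation factor
    global_step = 60             # checkpoint saved mid-epoch

    expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)  # 25 optimizer steps per epoch
    resumed_mid_epoch = global_step % expected_steps != 0  # 60 % 25 == 10 -> True, so the warning fires
    print(resumed_mid_epoch)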

pytorch_lightning/loops/fit_loop.py

Lines changed: 21 additions & 28 deletions

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import math
 import os
 from functools import partial
 from typing import Optional, Type

@@ -26,7 +25,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,

@@ -51,7 +49,7 @@ class FitLoop(Loop[None]):

     def __init__(
         self,
-        min_epochs: Optional[int] = 1,
+        min_epochs: int = 0,
         max_epochs: int = 1000,
     ) -> None:
         super().__init__()

@@ -133,6 +131,21 @@ def running_loss(self) -> TensorRunningAccum:
         """Returns the running loss."""
         return self.epoch_loop.batch_loop.running_loss

+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
+        if finished_before_on_train_end:
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+        restarting &= finished_before_on_train_end
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""

@@ -156,12 +169,14 @@ def done(self) -> bool:
         """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
+        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
+        # we use it here because the checkpoint data won't have `completed` increased yet
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)

         should_stop = False
         if self.trainer.should_stop:
             # early stopping
-            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
+            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
             met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
             if met_min_epochs and met_min_steps:
                 should_stop = True

@@ -198,23 +213,6 @@ def on_run_start(self) -> None:  # type: ignore[override]
         data_fetcher_cls = _select_data_fetcher(self.trainer)
         self._data_fetcher = data_fetcher_cls()

-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
-        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
-            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
-                self.trainer.current_epoch
-            )
-            expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)
-
-            # global_step is incremented during checkpointing (#11555)
-            if (self.trainer.global_step - 1) % expected_steps != 0:
-                rank_zero_warn(
-                    "You're resuming from a checkpoint that ended mid-epoch."
-                    " Training will start from the beginning of the next epoch."
-                    " This can cause unreliable results if further training is done,"
-                    " consider using an end of epoch checkpoint or use fault-tolerant training"
-                    " to restart as if training did not stop."
-                )
-
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)

@@ -240,7 +238,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

@@ -325,11 +323,6 @@ def on_advance_end(self) -> None:
     def on_run_end(self) -> None:
         """Calls the ``on_train_end`` hook."""
         log.detail(f"{self.__class__.__name__}: train run ended")
-        # NOTE: the current_epoch is already incremented
-        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
-        # To simulate that current behavior, we decrement here.
-        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

         # hook
         self.trainer._call_callback_hooks("on_train_end")
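The new `restarting` setter is the crux of the fix: a run that stopped cleanly after `on_train_end` should not be treated as a restart. A self-contained sketch of that decision (simplified stand-in for the progress dataclass, illustrative values):

    from dataclasses import dataclass

    @dataclass
    class EpochProgress:  # simplified stand-in for epoch_progress.current
        ready: int
        started: int
        processed: int
        completed: int

    def should_restart(p: EpochProgress) -> bool:
        # mirrors FitLoop.restarting: if ready/started/processed all equal
        # completed, the previous run finished cleanly, so this is not a restart
        mid_run = any(v != p.completed for v in (p.ready, p.started, p.processed))
        if mid_run:
            p.completed = p.processed  # heal the counter before resuming
        return mid_run

    print(should_restart(EpochProgress(3, 3, 3, 3)))  # False: clean end-of-training checkpoint
    print(should_restart(EpochProgress(3, 3, 3, 2)))  # True: stopped before `completed` advanced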

pytorch_lightning/loops/utilities.py

Lines changed: 7 additions & 2 deletions

@@ -71,7 +71,7 @@ def _parse_loop_limits(
     min_epochs: Optional[int],
     max_epochs: int,
     max_time: Optional[Union[str, timedelta, Dict[str, int]]],
-) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+) -> Tuple[Optional[int], int, int, int, Optional[Union[str, timedelta, Dict[str, int]]]]:
     """This utility computes the default values for the minimum and maximum number of steps and epochs given the
     values the user has selected.

@@ -95,7 +95,12 @@ def _parse_loop_limits(
             max_epochs = 1000
         else:
             max_epochs = -1
-    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    if min_epochs is None and min_steps is not None:
+        # setting this allows FitLoop.done to re-evaluate should_stop when it gets triggered `on_fit_start`
+        min_epochs = 1
+    if min_epochs is None:
+        # the default value is 0 so no training will be done when should_stop is triggered `on_fit_start`
+        min_epochs = 0
     return min_steps, max_steps, min_epochs, max_epochs, max_time
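Just the `min_epochs` defaulting, re-stated as a hedged standalone sketch (the helper name is hypothetical) with the cases spelled out:

    from typing import Optional

    def default_min_epochs(min_epochs: Optional[int], min_steps: Optional[int]) -> int:
        if min_epochs is None and min_steps is not None:
            return 1  # keep the loop alive so min_steps can be honored
        if min_epochs is None:
            return 0  # new default: allow stopping immediately if should_stop is already set
        return min_epochs

    print(default_min_epochs(None, None))  # 0 (was 1 before this commit)
    print(default_min_epochs(None, 100))   # 1
    print(default_min_epochs(5, None))     # 5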

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 14 deletions

@@ -21,7 +21,6 @@
 from torchmetrics import Metric

 import pytorch_lightning as pl
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE

@@ -230,7 +229,7 @@ def restore_loops(self) -> None:
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
-            if self.trainer.state.fn == TrainerFn.FITTING:
+            if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
                 self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
             elif self.trainer.state.fn == TrainerFn.VALIDATING:
                 self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])

@@ -329,21 +328,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             LightningDataModule.__class__.__qualname__: pl DataModule's state
         }
         """
-
-        # dump epoch/global_step/pytorch-lightning_version
-        current_epoch = self.trainer.current_epoch
-        global_step = self.trainer.global_step
-        has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)
-
-        global_step += 1
-        if not has_reached_max_steps:
-            current_epoch += 1
-
         model = self.trainer.lightning_module

         checkpoint = {
-            "epoch": current_epoch,
-            "global_step": global_step,
+            # the epoch is saved for compatibility but it's not relevant for restoration
+            "epoch": self.trainer.current_epoch,
+            "global_step": self.trainer.global_step + 1,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
             "loops": self._get_loops_state_dict(),

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion

@@ -2438,7 +2438,7 @@ def max_epochs(self) -> int:
         return self.fit_loop.max_epochs

     @property
-    def min_epochs(self) -> Optional[int]:
+    def min_epochs(self) -> int:
         return self.fit_loop.min_epochs

     @property

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 2 deletions

@@ -60,10 +60,8 @@ def scale_batch_size(

     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __scale_batch_dump_params(trainer)


pytorch_lightning/tuner/lr_finder.py

Lines changed: 0 additions & 2 deletions

@@ -204,10 +204,8 @@ def lr_find(

     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __lr_finder_dump_params(trainer)

