
Commit f9bb47e

Fix current_epoch value on training end
1 parent 43a89eb commit f9bb47e

24 files changed: +114 / -113 lines changed

CHANGELOG.md

Lines changed: 10 additions & 3 deletions
@@ -224,9 +224,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


-
-- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
-
-
 - Moved `Strategy` classes to the `strategies` directory ([#11226](https://github.com/PyTorchLightning/pytorch-lightning/pull/11226))


@@ -247,6 +244,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))

+
+- The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578))
+
+
 - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521))


@@ -261,11 +262,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - When using DP (data-parallel), Lightning will no longer automatically reduce all tensors returned in training_step; it will only reduce the loss unless `training_step_end` is overridden ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594))

+
 - When using DP (data-parallel), the `training_epoch_end` hook will no longer receive reduced outputs from `training_step` and instead get the full tensor of results from all GPUs ([#11594](https://github.com/PyTorchLightning/pytorch-lightning/pull/11594))
+
 ### Deprecated

+- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
+
+
 - Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))

+
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))


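The CHANGELOG entry for #8578 above is the user-visible part of this commit. A minimal sketch of the effect, assuming a toy LightningModule and dataset that are not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # Illustrative module, not part of the commit.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(max_epochs=3, limit_val_batches=0, logger=False, enable_checkpointing=False)
trainer.fit(TinyModule(), data)
# Before this commit, current_epoch here still reported the last epoch index (2);
# with the change it is incremented during/after `on_train_end` and reports 3.
print(trainer.current_epoch)
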
pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 2 additions & 2 deletions
@@ -221,13 +221,13 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
         trainer.fit_loop._skip_backward = False

     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
+        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.num_training_batches -= 1
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch == self.swa_end:
+        elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)

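For context on the two hunks above: inside `on_train_end`, `trainer.current_epoch` now holds the number of finished epochs rather than the index of the last one, so the SWA callback subtracts 1 to keep comparing against an epoch index. A small arithmetic illustration with made-up values, not taken from the commit:

swa_end = 3                        # hypothetical index of the last SWA epoch
current_epoch_in_on_train_end = 5  # with this commit: count of finished epochs (indices 0..4)

# old check compared an epoch index:    trainer.current_epoch == swa_end + 1
# new check recovers that index first:  trainer.current_epoch - 1 == swa_end + 1
assert current_epoch_in_on_train_end - 1 == swa_end + 1
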
pytorch_lightning/loops/base.py

Lines changed: 1 addition & 4 deletions
@@ -21,7 +21,6 @@
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import BaseProgress
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`
@@ -288,11 +287,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di

         destination[prefix + "state_dict"] = self.on_save_checkpoint()

-        # do not get the mode from `self.trainer` because it might not have been attached yet
-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if ft_enabled and isinstance(v, BaseProgress):
+            if isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 11 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union

@@ -35,6 +36,9 @@
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]


+log = logging.getLogger(__name__)
+
+
 class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     """Runs over all batches in a dataloader (one epoch).

@@ -99,6 +103,13 @@ def _is_validation_done(self) -> bool:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.should_stop and self.min_steps:
+            self.trainer.should_stop = self.global_step >= self.min_steps
+            if not self.trainer.should_stop:
+                log.info(
+                    f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met."
+                    " Training will continue..."
+                )
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop

     def connect(  # type: ignore[override]

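The block added to `TrainingEpochLoop.done` re-evaluates a stop request against `min_steps`. A simplified, standalone restatement of that rule (illustrative only, not the library function):

from typing import Optional


def stop_overridden_by_min_steps(should_stop: bool, global_step: int, min_steps: Optional[int]) -> bool:
    # A stop signal is ignored while fewer than `min_steps` optimizer steps have run.
    return bool(should_stop and min_steps and global_step < min_steps)


assert stop_overridden_by_min_steps(True, global_step=10, min_steps=100)       # training continues
assert not stop_overridden_by_min_steps(True, global_step=100, min_steps=100)  # stop is honoured
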
pytorch_lightning/loops/fit_loop.py

Lines changed: 26 additions & 24 deletions
@@ -40,7 +40,7 @@ class FitLoop(Loop[None]):

     def __init__(
         self,
-        min_epochs: Optional[int] = 1,
+        min_epochs: int = 0,
         max_epochs: int = 1000,
     ) -> None:
         super().__init__()
@@ -121,6 +121,21 @@ def running_loss(self) -> TensorRunningAccum:
         """Returns the running loss."""
         return self.epoch_loop.batch_loop.running_loss

+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
+        if finished_before_on_train_end:
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+        restarting &= finished_before_on_train_end
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -144,31 +159,23 @@ def done(self) -> bool:
         """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
-
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+
+        if self.trainer.should_stop and self.min_epochs:
+            self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs
+            if not self.trainer.should_stop:
                 log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                    f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
+                    " Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
-
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0

     @property
     def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
         # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
         # until `on_run_start`, we use `limit_train_batches` instead
-        return self.done or self.trainer.limit_train_batches == 0
+        return self.trainer.limit_train_batches == 0

     def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
         """Connects a training epoch loop to this fit loop."""
@@ -225,7 +232,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
@@ -309,11 +316,6 @@ def on_advance_end(self) -> None:
     def on_run_end(self) -> None:
         """Calls the ``on_train_end`` hook."""
         log.detail(f"{self.__class__.__name__}: train run ended")
-        # NOTE: the current_epoch is already incremented
-        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
-        # To simulate that current behavior, we decrement here.
-        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

         # hook
         self.trainer._call_callback_hooks("on_train_end")

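The new `restarting` setter in `FitLoop` inspects the epoch progress counters to decide whether a checkpoint was taken mid-epoch. A rough standalone sketch of that heuristic, using a toy counter type and made-up values:

from dataclasses import dataclass


@dataclass
class EpochCounters:
    # Toy stand-in for the progress counters used above; values below are illustrative.
    ready: int
    started: int
    processed: int
    completed: int


def checkpointed_before_on_train_end(c: EpochCounters) -> bool:
    # Mirrors `finished_before_on_train_end`: True when `completed` has not yet caught up
    # with the other counters, i.e. the state was saved before the final epoch bookkeeping.
    return any(v != c.completed for v in (c.ready, c.started, c.processed))


assert checkpointed_before_on_train_end(EpochCounters(ready=3, started=3, processed=3, completed=2))
assert not checkpointed_before_on_train_end(EpochCounters(ready=3, started=3, processed=3, completed=3))
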
pytorch_lightning/loops/utilities.py

Lines changed: 7 additions & 2 deletions
@@ -71,7 +71,7 @@ def _parse_loop_limits(
     min_epochs: Optional[int],
     max_epochs: int,
     max_time: Optional[Union[str, timedelta, Dict[str, int]]],
-) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+) -> Tuple[Optional[int], int, int, int, Optional[Union[str, timedelta, Dict[str, int]]]]:
     """This utility computes the default values for the minimum and maximum number of steps and epochs given the
     values the user has selected.

@@ -95,7 +95,12 @@
             max_epochs = 1000
         else:
             max_epochs = -1
-    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    if min_epochs is None and min_steps is not None:
+        # setting this allows FitLoop.done to re-evaluate should_stop when it gets triggered `on_fit_start`
+        min_epochs = 1
+    if min_epochs is None:
+        # the default value is 0 so no training will be done when should_stop is triggered `on_fit_start`
+        min_epochs = 0
     return min_steps, max_steps, min_epochs, max_epochs, max_time


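The rewritten defaulting of `min_epochs` in `_parse_loop_limits` can be restated on its own; this sketch mirrors the branch added above and is illustrative only:

from typing import Optional


def resolve_min_epochs(min_epochs: Optional[int], min_steps: Optional[int]) -> int:
    # The default becomes 1 only when `min_steps` is set, otherwise 0.
    if min_epochs is None and min_steps is not None:
        return 1
    return 0 if min_epochs is None else min_epochs


assert resolve_min_epochs(None, None) == 0   # plain default: no mandatory epochs
assert resolve_min_epochs(None, 50) == 1     # min_steps forces at least one pass through the epoch loop
assert resolve_min_epochs(5, None) == 5      # an explicit value is kept
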
pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 14 deletions
@@ -21,7 +21,6 @@
 from torchmetrics import Metric

 import pytorch_lightning as pl
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
@@ -228,7 +227,7 @@ def restore_loops(self) -> None:
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
-            if self.trainer.state.fn == TrainerFn.FITTING:
+            if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
                 self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
             elif self.trainer.state.fn == TrainerFn.VALIDATING:
                 self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
@@ -327,21 +326,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             LightningDataModule.__class__.__name__: pl DataModule's state
         }
         """
-
-        # dump epoch/global_step/pytorch-lightning_version
-        current_epoch = self.trainer.current_epoch
-        global_step = self.trainer.global_step
-        has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)
-
-        global_step += 1
-        if not has_reached_max_steps:
-            current_epoch += 1
-
         model = self.trainer.lightning_module

         checkpoint = {
-            "epoch": current_epoch,
-            "global_step": global_step,
+            # the epoch is saved for compatibility but it's not relevant for restoration
+            "epoch": self.trainer.current_epoch,
+            "global_step": self.trainer.global_step + 1,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
             "loops": self._get_loops_state_dict(),

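With the bookkeeping in `dump_checkpoint` simplified above, the saved `epoch` and `global_step` fields can be inspected directly. A hedged sketch, assuming `last.ckpt` is a checkpoint previously written by `trainer.save_checkpoint` (the path is hypothetical):

import torch

ckpt = torch.load("last.ckpt", map_location="cpu")

# "epoch" is now trainer.current_epoch at save time, kept for backward compatibility;
# it is no longer bumped by one when max_steps has not been reached.
# "global_step" is still stored as trainer.global_step + 1.
print(ckpt["epoch"], ckpt["global_step"])

# Restoration goes through the "loops" state dict (see restore_loops above), not the
# top-level "epoch" value.
print(sorted(ckpt["loops"]))
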
pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -2439,7 +2439,7 @@ def max_epochs(self) -> int:
         return self.fit_loop.max_epochs

     @property
-    def min_epochs(self) -> Optional[int]:
+    def min_epochs(self) -> int:
         return self.fit_loop.min_epochs

     @property

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 2 deletions
@@ -60,10 +60,8 @@ def scale_batch_size(

     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __scale_batch_dump_params(trainer)

pytorch_lightning/tuner/lr_finder.py

Lines changed: 0 additions & 2 deletions
@@ -204,10 +204,8 @@ def lr_find(

     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __lr_finder_dump_params(trainer)
