
Commit 65e7742

Fix current_epoch value on training end
1 parent 75cf898 commit 65e7742

22 files changed: +100 -106 lines
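The net effect of the commit: trainer.current_epoch is no longer decremented when training ends, so after fit() it equals the number of epochs that actually ran. A minimal sketch of the new behavior (TinyModel and the Trainer flags are illustrative, not part of this diff):

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):  # hypothetical minimal module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(8, 4), batch_size=4)

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.1)

trainer = Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
trainer.fit(TinyModel())

# The counter is no longer decremented in FitLoop.on_run_end, so after fit()
# it equals the number of epochs that ran (it used to read 2 here):
assert trainer.current_epoch == 3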

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 2 additions & 2 deletions
@@ -221,13 +221,13 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
         trainer.fit_loop._skip_backward = False

     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
+        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.num_training_batches -= 1
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch == self.swa_end:
+        elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)
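Note: on_train_end now sees the undecremented counter, so the callback compares current_epoch - 1 (the index of the last epoch that ran) against swa_end. A worked example with illustrative numbers, not taken from the diff:

# Suppose swa_end = 3, so the extra BatchNorm-update epoch has index 4 and is the last to run.
swa_end = 3
current_epoch = 5  # counts completed epochs after this commit; it was decremented to 4 before
assert current_epoch - 1 == swa_end + 1  # the new on_train_end condition holds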

pytorch_lightning/loops/base.py

Lines changed: 1 addition & 4 deletions
@@ -21,7 +21,6 @@
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import BaseProgress
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`
@@ -288,11 +287,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:

         destination[prefix + "state_dict"] = self.on_save_checkpoint()

-        # do not get the mode from `self.trainer` because it might not have been attached yet
-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if ft_enabled and isinstance(v, BaseProgress):
+            if isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 11 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union

@@ -35,6 +36,9 @@
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]


+log = logging.getLogger(__name__)
+
+
 class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     """Runs over all batches in a dataloader (one epoch).

@@ -99,6 +103,13 @@ def _is_validation_done(self) -> bool:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.should_stop and self.min_steps:
+            self.trainer.should_stop = self.global_step >= self.min_steps
+            if not self.trainer.should_stop:
+                log.info(
+                    f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met."
+                    " Training will continue..."
+                )
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop

     def connect(  # type: ignore[override]
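Note: this mirrors the min_epochs handling added to FitLoop.done below; a stop signal is now re-evaluated against min_steps inside the epoch loop instead of being consumed in the fit loop. A rough sketch of the intended behavior (the values and the monitored metric are assumptions):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Even if EarlyStopping (or anything else) sets trainer.should_stop early, the epoch
# loop keeps resetting the flag until at least 50 optimizer steps have run.
trainer = Trainer(
    max_epochs=100,
    min_steps=50,
    callbacks=[EarlyStopping(monitor="val_loss", patience=0)],  # assumes the model logs "val_loss"
)
# trainer.fit(model)
# assert trainer.global_step >= 50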

pytorch_lightning/loops/fit_loop.py

Lines changed: 26 additions & 24 deletions
@@ -38,7 +38,7 @@ class FitLoop(Loop[None]):

     def __init__(
         self,
-        min_epochs: Optional[int] = 1,
+        min_epochs: int = 0,
         max_epochs: int = 1000,
     ) -> None:
         super().__init__()
@@ -119,6 +119,21 @@ def running_loss(self) -> TensorRunningAccum:
         """Returns the running loss."""
         return self.epoch_loop.batch_loop.running_loss

+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
+        if finished_before_on_train_end:
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+        restarting &= finished_before_on_train_end
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -142,31 +157,23 @@ def done(self) -> bool:
         """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
-
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+
+        if self.trainer.should_stop and self.min_epochs:
+            self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs
+            if not self.trainer.should_stop:
                 log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                    f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
+                    " Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
-
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0

     @property
     def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
         # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
         # until `on_run_start`, we use `limit_train_batches` instead
-        return self.done or self.trainer.limit_train_batches == 0
+        return self.trainer.limit_train_batches == 0

     def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
         """Connects a training epoch loop to this fit loop."""
@@ -205,7 +212,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
@@ -289,11 +296,6 @@ def on_advance_end(self) -> None:
     def on_run_end(self) -> None:
         """Calls the ``on_train_end`` hook."""
         log.detail(f"{self.__class__.__name__}: train run ended")
-        # NOTE: the current_epoch is already incremented
-        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
-        # To simulate that current behavior, we decrement here.
-        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

         # hook
         self.trainer._call_callback_hooks("on_train_end")
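Note: the new restarting setter uses the standard idiom for overriding only the setter of an inherited property (@Loop.restarting.setter plus Loop.restarting.fset(self, ...)). A stripped-down illustration of that pattern with hypothetical names:

class Base:
    _flag = False

    @property
    def flag(self) -> bool:
        return self._flag

    @flag.setter
    def flag(self, value: bool) -> None:
        self._flag = value


class Child(Base):
    @Base.flag.setter
    def flag(self, value: bool) -> None:
        value &= self._extra_condition()  # adjust the value first
        Base.flag.fset(self, value)  # then run the parent setter

    def _extra_condition(self) -> bool:
        return True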

pytorch_lightning/loops/utilities.py

Lines changed: 7 additions & 2 deletions
@@ -71,7 +71,7 @@ def _parse_loop_limits(
     min_epochs: Optional[int],
     max_epochs: int,
     max_time: Optional[Union[str, timedelta, Dict[str, int]]],
-) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+) -> Tuple[Optional[int], int, int, int, Optional[Union[str, timedelta, Dict[str, int]]]]:
     """This utility computes the default values for the minimum and maximum number of steps and epochs given the
     values the user has selected.

@@ -95,7 +95,12 @@
             max_epochs = 1000
         else:
             max_epochs = -1
-    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    if min_epochs is None and min_steps is not None:
+        # setting this allows FitLoop.done to re-evaluate should_stop when it gets triggered `on_fit_start`
+        min_epochs = 1
+    if min_epochs is None:
+        # the default value is 0 so no training will be done when should_stop is triggered `on_fit_start`
+        min_epochs = 0
     return min_steps, max_steps, min_epochs, max_epochs, max_time
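Note: the resolved min_epochs is now always an int, which is why the return annotation here and the Trainer.min_epochs property below drop the Optional. How the branches above resolve it, derived from the diff:

# min_epochs=None, min_steps=None -> min_epochs = 0  (no training happens once should_stop is set on_fit_start)
# min_epochs=None, min_steps=100  -> min_epochs = 1  (lets FitLoop.done re-evaluate should_stop)
# min_epochs=3,    min_steps=any  -> min_epochs = 3  (an explicit user value is kept)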

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 14 deletions
@@ -21,7 +21,6 @@
 from torchmetrics import Metric

 import pytorch_lightning as pl
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -225,7 +224,7 @@ def restore_loops(self) -> None:
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
-            if self.trainer.state.fn == TrainerFn.FITTING:
+            if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
                 self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
             elif self.trainer.state.fn == TrainerFn.VALIDATING:
                 self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
@@ -336,21 +335,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             LightningDataModule.__class__.__name__: pl DataModule's state
         }
         """
-
-        # dump epoch/global_step/pytorch-lightning_version
-        current_epoch = self.trainer.current_epoch
-        global_step = self.trainer.global_step
-        has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)
-
-        global_step += 1
-        if not has_reached_max_steps:
-            current_epoch += 1
-
         model = self.trainer.lightning_module

         checkpoint = {
-            "epoch": current_epoch,
-            "global_step": global_step,
+            # the epoch is saved for compatibility but it's not relevant for restoration
+            "epoch": self.trainer.current_epoch,
+            "global_step": self.trainer.global_step + 1,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
             "loops": self._get_loops_state_dict(),

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -2356,7 +2356,7 @@ def max_epochs(self) -> int:
         return self.fit_loop.max_epochs

     @property
-    def min_epochs(self) -> Optional[int]:
+    def min_epochs(self) -> int:
         return self.fit_loop.min_epochs

     @property

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 2 deletions
@@ -60,10 +60,8 @@ def scale_batch_size(

     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __scale_batch_dump_params(trainer)

pytorch_lightning/tuner/lr_finder.py

Lines changed: 0 additions & 2 deletions
@@ -204,10 +204,8 @@ def lr_find(

     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __lr_finder_dump_params(trainer)

tests/callbacks/test_early_stopping.py

Lines changed: 7 additions & 7 deletions
@@ -80,7 +80,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # ensure state is persisted properly
     checkpoint = torch.load(checkpoint_filepath)
     # the checkpoint saves "epoch + 1"
-    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
+    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]]
     assert 4 == len(early_stop_callback.saved_states)
     es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
     assert checkpoint["callbacks"][es_name] == early_stop_callback_state
@@ -143,7 +143,7 @@ def validation_epoch_end(self, outputs):
         enable_progress_bar=False,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == expected_stop_epoch
+    assert trainer.current_epoch - 1 == expected_stop_epoch


 @pytest.mark.parametrize("validation_step_none", [True, False])
@@ -179,7 +179,7 @@ def training_epoch_end(self, outputs):
         enable_progress_bar=False,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == expected_stop_epoch
+    assert trainer.current_epoch - 1 == expected_stop_epoch


 def test_pickling(tmpdir):
@@ -236,7 +236,7 @@ def validation_epoch_end(self, outputs):
         max_epochs=20,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == expected_epoch, "early_stopping failed"
+    assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed"


 @pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)])
@@ -260,7 +260,7 @@ def validation_epoch_end(self, outputs):
         max_epochs=10,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == expected_stop_epoch
+    assert trainer.current_epoch - 1 == expected_stop_epoch
     assert early_stopping.stopped_epoch == expected_stop_epoch


@@ -388,7 +388,7 @@ def validation_epoch_end(self, outputs):
         self._epoch_end()

     def on_train_end(self) -> None:
-        assert self.trainer.current_epoch == self.expected_end_epoch, "Early Stopping Failed"
+        assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed"


 _ES_CHECK = dict(check_on_train_epoch_end=True)
@@ -481,7 +481,7 @@ def validation_step(self, batch, batch_idx):
     if case == "val_check_interval":
         assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
     else:
-        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
+        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch


 def test_early_stopping_squeezes():
