
Commit 0cb64fb

rohitgr7, carmocca, and Borda authored
Fix mid-epoch warning call while resuming (#11556)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka <[email protected]>
1 parent d43fd0d commit 0cb64fb

File tree

5 files changed (+71, -17 lines changed)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -497,6 +497,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
 
+- Fixed the mid-epoch warning call while resuming training ([#11556](https://github.com/PyTorchLightning/pytorch-lightning/pull/11556))
+
+
 - Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))
 
 
pytorch_lightning/loops/fit_loop.py

Lines changed: 21 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 from typing import Optional
 
 from pytorch_lightning.loops import Loop
@@ -22,8 +23,10 @@
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.warnings import rank_zero_warn
 
 log = logging.getLogger(__name__)
 
@@ -181,6 +184,24 @@ def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
+
+        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
+        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
+            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
+                self.trainer.current_epoch
+            )
+            expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)
+
+            # global_step is incremented during checkpointing (#11555)
+            if (self.trainer.global_step - 1) % expected_steps != 0:
+                rank_zero_warn(
+                    "You're resuming from a checkpoint that ended mid-epoch."
+                    " Training will start from the beginning of the next epoch."
+                    " This can cause unreliable results if further training is done,"
+                    " consider using an end of epoch checkpoint or use fault-tolerant training"
+                    " to restart as if training did not stop."
+                )
+
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")

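For reference, the condition added to ``FitLoop.on_run_start`` boils down to a single modulo test on the restored ``global_step``. Below is a minimal standalone sketch of that check; the helper name and the example numbers (mirroring the test added in this commit) are illustrative, not part of the Lightning API.

import math


def ended_mid_epoch(global_step: int, num_training_batches: int, accumulate_grad_batches: int) -> bool:
    """Rough equivalent of the new check in FitLoop.on_run_start (illustrative helper)."""
    if num_training_batches in (0, float("inf")):
        # dataloader length unknown or empty: nothing meaningful to compare against
        return False
    expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)
    # global_step is incremented once during checkpointing (#11555), hence the -1
    return (global_step - 1) % expected_steps != 0


# 7 batches per epoch, no gradient accumulation, global_step already bumped at save time:
print(ended_mid_epoch(5, num_training_batches=7, accumulate_grad_batches=1))  # True  -> mid-epoch, warn
print(ended_mid_epoch(8, num_training_batches=7, accumulate_grad_batches=1))  # False -> end-of-epoch checkpoint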
pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 16 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -248,21 +248,6 @@ def restore_loops(self) -> None:
                 f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
             )
 
-        # Division deals with global step stepping once per accumulated batch
-        # Inequality deals with different global step for odd vs even num_training_batches
-        self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
-            self.trainer.current_epoch
-        )
-        n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
-        expected_steps = self.trainer.num_training_batches / n_accum
-        if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
-            rank_zero_warn(
-                "You're resuming from a checkpoint that ended mid-epoch."
-                " Training will start from the beginning of the next epoch."
-                " This can cause unreliable results if further training is done,"
-                " consider using an end of epoch checkpoint."
-            )
-
     def restore_optimizers_and_schedulers(self) -> None:
         """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
        if not self._loaded_checkpoint:

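One thing the relocated check guards against is an undetermined dataloader length. A quick sketch of why a bare modulo test is meaningless in that case; it assumes ``num_training_batches`` still holds the Trainer's initial ``float("inf")`` before the dataloaders are loaded, and the concrete values are illustrative:

# assumed initial state before reset_train_val_dataloaders() has run
num_training_batches = float("inf")
expected_steps = num_training_batches / 1             # old check: plain division
print(5 % expected_steps > 1)                         # True  -> the old inequality could fire spuriously
print(num_training_batches in (0, float("inf")))      # True  -> the relocated check bails out instead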
pytorch_lightning/utilities/cloud_io.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 
 
 def load(
-    path_or_url: Union[str, IO, Path],
+    path_or_url: Union[IO, _PATH],
     map_location: Optional[
         Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
     ] = None,

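The ``load`` signature change only widens the accepted path types. A small usage sketch, assuming ``_PATH`` covers plain strings and ``pathlib.Path`` objects (its definition lives elsewhere in Lightning):

from io import BytesIO

import torch

from pytorch_lightning.utilities.cloud_io import load

buffer = BytesIO()
torch.save({"epoch": 3}, buffer)
buffer.seek(0)
print(load(buffer))                 # file-like object (IO)
# print(load("path/to/last.ckpt"))  # str or pathlib.Path also accepted (file must exist)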
tests/models/test_restore.py

Lines changed: 45 additions & 0 deletions
@@ -17,6 +17,7 @@
 import pickle
 from copy import deepcopy
 from typing import Generic, Mapping, TypeVar
+from unittest import mock
 
 import cloudpickle
 import pytest
@@ -32,6 +33,7 @@
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
+from tests.helpers.utils import no_warning_call
 
 
 class ModelTrainerPropertyParity(Callback):
@@ -776,3 +778,46 @@ def test_model_pickle(tmpdir):
     model = BoringModel()
     pickle.dumps(model)
     cloudpickle.dumps(model)
+
+
+@pytest.mark.parametrize("stop_batch_idx", [4, 7])
+def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx):
+    """Test that a warning is raised if training is restarted from mid-epoch."""
+
+    class CustomModel(BoringModel):
+        def __init__(self, stop_batch_idx):
+            super().__init__()
+            self.stop_batch_idx = stop_batch_idx
+
+        def training_step(self, batch, batch_idx):
+            if (batch_idx + 1) == self.stop_batch_idx:
+                self.trainer.should_stop = True
+
+            return super().training_step(batch, batch_idx)
+
+    limit_train_batches = 7
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "limit_train_batches": limit_train_batches,
+        "enable_progress_bar": False,
+        "enable_model_summary": False,
+    }
+    trainer = Trainer(max_epochs=1, **trainer_kwargs)
+    model = CustomModel(stop_batch_idx)
+    trainer.fit(model)
+
+    ckpt_path = str(tmpdir / "resume.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
+
+    warning_raised = limit_train_batches != stop_batch_idx
+    context_manager = pytest.warns if warning_raised else no_warning_call
+    with context_manager(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
+        trainer.fit(model, ckpt_path=ckpt_path)
+
+    if warning_raised:
+        with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
+            trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
+            with no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
+                trainer.fit(model, ckpt_path=ckpt_path)

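As the new test also exercises, the warning is skipped entirely when fault-tolerant training is enabled; a minimal sketch of opting in via the environment variable the test patches:

import os

# Enable fault-tolerant training before constructing the Trainer, so resuming can
# "restart as if training did not stop" instead of warning (as mocked in the test above).
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"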