Commit 17ebb0c

Apply #11556
1 parent 9228953 commit 17ebb0c

3 files changed: +52 / -43 lines changed

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 11 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union


@@ -129,6 +130,16 @@ def reset(self) -> None:
             self.batch_progress.reset_on_restart()
             self.scheduler_progress.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
+
+            trainer = self.trainer
+            if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
+                expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
+                if self.global_step % expected_steps != 0:
+                    rank_zero_warn(
+                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
+                        " results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
+                        " fault-tolerant training."
+                    )
         else:
             self.batch_progress.reset_on_run()
             self.scheduler_progress.reset_on_run()
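
The check added to reset() is plain modular arithmetic: an epoch is expected to take ceil(num_training_batches / accumulate_grad_batches) optimizer steps, so a restored global_step that is not a multiple of that count implies the checkpoint was written mid-epoch. Below is a standalone sketch of that arithmetic with assumed numbers; none of the values come from the commit, and it is an illustration rather than Lightning code.

# Standalone sketch of the mid-epoch check (assumed values, not taken from the commit).
import math

num_training_batches = 8     # stands in for trainer.num_training_batches
accumulate_grad_batches = 2  # stands in for trainer.accumulate_grad_batches
expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)  # 4 optimizer steps per epoch

for global_step in (2, 4):
    mid_epoch = global_step % expected_steps != 0
    print(global_step, "-> warn (mid-epoch)" if mid_epoch else "-> no warning (epoch boundary)")

The new condition skips the check when fault-tolerant mode is enabled and when trainer.num_training_batches is infinite, since expected_steps cannot be computed from an unknown epoch length.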

pytorch_lightning/loops/fit_loop.py

Lines changed: 4 additions & 21 deletions
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import math
 import os
 from functools import partial
-from typing import Optional, Type
+from typing import Optional
+from typing import Type

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import GPUAccelerator

@@ -26,7 +26,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,

@@ -35,7 +34,8 @@
     InterBatchParallelDataFetcher,
 )
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

 log = logging.getLogger(__name__)

@@ -205,23 +205,6 @@ def on_run_start(self) -> None: # type: ignore[override]
         data_fetcher_cls = _select_data_fetcher(self.trainer)
         self._data_fetcher = data_fetcher_cls()

-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
-        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
-            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
-                self.trainer.current_epoch
-            )
-            expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)
-
-            # global_step is incremented during checkpointing (#11555)
-            if (self.trainer.global_step - 1) % expected_steps != 0:
-                rank_zero_warn(
-                    "You're resuming from a checkpoint that ended mid-epoch."
-                    " Training will start from the beginning of the next epoch."
-                    " This can cause unreliable results if further training is done,"
-                    " consider using an end of epoch checkpoint or use fault-tolerant training"
-                    " to restart as if training did not stop."
-                )
-
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)

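
Compared with its replacement in TrainingEpochLoop.reset, the block removed here differed in two ways: it re-read accumulate_grad_batches from the accumulation scheduler for the current epoch, and it compared against global_step - 1 because the counter had already been incremented while the checkpoint was written (#11555). The sketch below lays the two condition shapes side by side with assumed numbers; it illustrates the off-by-one handling only and makes no claim about either loop's internal state.

# Assumed numbers; this only contrasts the shape of the old and new conditions.
import math

num_training_batches = 8
accumulate_grad_batches = 1
expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)  # 8 steps per epoch

# Old location (FitLoop.on_run_start): an end-of-epoch checkpoint carried an already
# incremented counter (#11555), hence the explicit "- 1".
global_step_in_fit_loop = 9   # assumed
print((global_step_in_fit_loop - 1) % expected_steps != 0)  # False -> no warning

# New location (TrainingEpochLoop.reset): the restored counter is compared directly.
global_step_in_epoch_loop = 8  # assumed
print(global_step_in_epoch_loop % expected_steps != 0)      # False -> no warning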

tests/models/test_restore.py

Lines changed: 37 additions & 22 deletions
@@ -33,7 +33,7 @@
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
-from tests.helpers.utils import no_warning_call
+from tests.loops.test_loops import CustomException


 class ModelTrainerPropertyParity(Callback):

@@ -774,44 +774,59 @@ def test_model_pickle(tmpdir):
     cloudpickle.dumps(model)


-@pytest.mark.parametrize("stop_batch_idx", [4, 7])
-def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx):
-    """Test that a warning is raised if training is restarted from mid-epoch."""
+class ExceptionModel(BoringModel):
+    def __init__(self, stop_batch_idx):
+        super().__init__()
+        self.stop_batch_idx = stop_batch_idx

-    class CustomModel(BoringModel):
-        def __init__(self, stop_batch_idx):
-            super().__init__()
-            self.stop_batch_idx = stop_batch_idx
+    def training_step(self, batch, batch_idx):
+        if batch_idx == self.stop_batch_idx:
+            raise CustomException()
+        return super().training_step(batch, batch_idx)

-        def training_step(self, batch, batch_idx):
-            if (batch_idx + 1) == self.stop_batch_idx:
-                self.trainer.should_stop = True

-            return super().training_step(batch, batch_idx)
+class ShouldStopModel(ExceptionModel):
+    def training_step(self, batch, batch_idx):
+        if batch_idx == self.stop_batch_idx:
+            # setting should_stop is treated differently to raising an exception.
+            # checking both tests that this warning is raised in the correct loop
+            self.trainer.should_stop = True
+        return super().training_step(batch, batch_idx)

-    limit_train_batches = 7
+
+@pytest.mark.parametrize("stop_in_the_middle", (True, False))
+@pytest.mark.parametrize("model_cls", (ExceptionModel, ShouldStopModel))
+def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls):
+    """Test that a warning is raised if training is restarted from mid-epoch."""
+    limit_train_batches = 8
     trainer_kwargs = {
         "default_root_dir": tmpdir,
         "limit_train_batches": limit_train_batches,
+        "limit_val_batches": 0,
         "enable_progress_bar": False,
         "enable_model_summary": False,
     }
     trainer = Trainer(max_epochs=1, **trainer_kwargs)
-    model = CustomModel(stop_batch_idx)
-    trainer.fit(model)
+    model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1)
+
+    if stop_in_the_middle:
+        with pytest.raises(CustomException):
+            trainer.fit(model)
+    else:
+        trainer.fit(model)

     ckpt_path = str(tmpdir / "resume.ckpt")
     trainer.save_checkpoint(ckpt_path)

-    trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
+    trainer = Trainer(max_epochs=2, **trainer_kwargs)
+    model.stop_batch_idx = -1

-    warning_raised = limit_train_batches != stop_batch_idx
-    context_manager = pytest.warns if warning_raised else no_warning_call
-    with context_manager(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
+    context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call
+    with context_manager(UserWarning, match="resuming from a checkpoint that ended"):
         trainer.fit(model, ckpt_path=ckpt_path)

-    if warning_raised:
+    if stop_in_the_middle:
         with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
-            trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
-            with no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
+            trainer = Trainer(max_epochs=2, **trainer_kwargs)
+            with tutils.no_warning_call(UserWarning, match="resuming from a checkpoint that ended"):
                 trainer.fit(model, ckpt_path=ckpt_path)
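
The last block of the updated test covers the remedy the warning points to: when fault-tolerant training is enabled through the PL_FAULT_TOLERANT_TRAINING environment variable, the new check in TrainingEpochLoop.reset is skipped and no warning is emitted on resume. A minimal resume sketch along those lines follows; TinyModel and resume.ckpt are placeholders, not part of the commit, and the checkpoint is assumed to exist from an earlier interrupted run.

# Hypothetical resume script; the env var should be set before the Trainer is created.
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Placeholder model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


train_data = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=4)
trainer = Trainer(max_epochs=2, enable_progress_bar=False)
# resume.ckpt is a placeholder path to a checkpoint saved by a previous, interrupted run
trainer.fit(TinyModel(), train_dataloaders=train_data, ckpt_path="resume.ckpt")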
