Commit 9979454

Apply #11556
1 parent 3849964 commit 9979454

File tree

3 files changed: +49 -41 lines changed

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 12 additions & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union

@@ -137,6 +138,17 @@ def reset(self) -> None:
         # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
         self.val_loop.epoch_loop.batch_progress.total.reset()

+        ft_enabled = self.trainer.state._fault_tolerant_mode.is_enabled
+        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
+            expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)
+            if self.global_step % expected_steps != 0:
+                rank_zero_warn(
+                    "You're resuming from a checkpoint that ended mid-epoch."
+                    " This can cause unreliable results if further training is done,"
+                    " consider using an end of epoch checkpoint or use fault-tolerant training"
+                    " to restart as if training did not stop."
+                )
+
         self._outputs = []

     def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]

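For reference, the arithmetic behind the added check: with gradient accumulation, an epoch of num_training_batches batches produces ceil(num_training_batches / accumulate_grad_batches) optimizer steps, so a global_step that is not a multiple of that count means the checkpoint was written mid-epoch. A minimal standalone sketch of that logic (the helper name and the example numbers are illustrative, not part of this commit):

import math

def ended_mid_epoch(global_step: int, num_training_batches: int, accumulate_grad_batches: int) -> bool:
    # one optimizer step per `accumulate_grad_batches` batches, rounded up for a trailing partial group
    expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)
    return global_step % expected_steps != 0

# 8 batches, no accumulation: 8 optimizer steps per epoch, so global_step=8 is an epoch boundary
assert not ended_mid_epoch(global_step=8, num_training_batches=8, accumulate_grad_batches=1)
# stopping after 4 of those 8 batches leaves global_step=4, which is not a multiple of 8
assert ended_mid_epoch(global_step=4, num_training_batches=8, accumulate_grad_batches=1)
# with accumulate_grad_batches=2 the same epoch only takes ceil(8 / 2) = 4 steps
assert not ended_mid_epoch(global_step=4, num_training_batches=8, accumulate_grad_batches=2)
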
pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 20 deletions

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import math
 from typing import Optional

 from pytorch_lightning.loops import Loop
@@ -22,10 +21,9 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

 log = logging.getLogger(__name__)

@@ -191,23 +189,6 @@ def on_run_start(self) -> None:  # type: ignore[override]
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)

-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
-        if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
-            self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
-                self.trainer.current_epoch
-            )
-            expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)
-
-            # global_step is incremented during checkpointing (#11555)
-            if (self.trainer.global_step - 1) % expected_steps != 0:
-                rank_zero_warn(
-                    "You're resuming from a checkpoint that ended mid-epoch."
-                    " Training will start from the beginning of the next epoch."
-                    " This can cause unreliable results if further training is done,"
-                    " consider using an end of epoch checkpoint or use fault-tolerant training"
-                    " to restart as if training did not stop."
-                )
-
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")

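The warning text points to fault-tolerant training as the way to resume a mid-epoch checkpoint without losing state. A hedged usage sketch, assuming only what the test below exercises (the PL_FAULT_TOLERANT_TRAINING environment variable and fit(..., ckpt_path=...)); the function name, model and checkpoint path are placeholders:

import os

from pytorch_lightning import Trainer

def resume_fault_tolerant(model, ckpt_path: str) -> Trainer:
    # Set the flag before constructing the Trainer, mirroring the test below
    # (which patches it via mock.patch.dict before building a new Trainer).
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"
    trainer = Trainer(max_epochs=2)
    # With fault tolerance enabled, resuming mid-epoch does not emit the warning added above.
    trainer.fit(model, ckpt_path=ckpt_path)
    return trainer
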
tests/models/test_restore.py

Lines changed: 36 additions & 21 deletions

@@ -33,7 +33,7 @@
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
-from tests.helpers.utils import no_warning_call
+from tests.loops.test_loops import CustomException


 class ModelTrainerPropertyParity(Callback):
@@ -774,44 +774,59 @@ def test_model_pickle(tmpdir):
     cloudpickle.dumps(model)


-@pytest.mark.parametrize("stop_batch_idx", [4, 7])
-def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx):
-    """Test that a warning is raised if training is restarted from mid-epoch."""
+class ExceptionModel(BoringModel):
+    def __init__(self, stop_batch_idx):
+        super().__init__()
+        self.stop_batch_idx = stop_batch_idx

-    class CustomModel(BoringModel):
-        def __init__(self, stop_batch_idx):
-            super().__init__()
-            self.stop_batch_idx = stop_batch_idx
+    def training_step(self, batch, batch_idx):
+        if batch_idx == self.stop_batch_idx:
+            raise CustomException()
+        return super().training_step(batch, batch_idx)

-        def training_step(self, batch, batch_idx):
-            if (batch_idx + 1) == self.stop_batch_idx:
-                self.trainer.should_stop = True

-            return super().training_step(batch, batch_idx)
+class ShouldStopModel(ExceptionModel):
+    def training_step(self, batch, batch_idx):
+        if batch_idx == self.stop_batch_idx:
+            # setting should_stop is treated differently to raising an exception.
+            # checking both tests that this warning is raised in the correct loop
+            self.trainer.should_stop = True
+        return super().training_step(batch, batch_idx)

-    limit_train_batches = 7
+
+@pytest.mark.parametrize("stop_in_the_middle", (True, False))
+@pytest.mark.parametrize("model_cls", (ExceptionModel, ShouldStopModel))
+def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls):
+    """Test that a warning is raised if training is restarted from mid-epoch."""
+    limit_train_batches = 8
     trainer_kwargs = {
         "default_root_dir": tmpdir,
         "limit_train_batches": limit_train_batches,
+        "limit_val_batches": 0,
         "enable_progress_bar": False,
         "enable_model_summary": False,
     }
     trainer = Trainer(max_epochs=1, **trainer_kwargs)
-    model = CustomModel(stop_batch_idx)
-    trainer.fit(model)
+    model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1)
+
+    if stop_in_the_middle:
+        with pytest.raises(CustomException):
+            trainer.fit(model)
+    else:
+        trainer.fit(model)

     ckpt_path = str(tmpdir / "resume.ckpt")
     trainer.save_checkpoint(ckpt_path)

-    trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
+    trainer = Trainer(max_epochs=2, **trainer_kwargs)
+    model.stop_batch_idx = -1

-    warning_raised = limit_train_batches != stop_batch_idx
-    context_manager = pytest.warns if warning_raised else no_warning_call
+    context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call
     with context_manager(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
         trainer.fit(model, ckpt_path=ckpt_path)

-    if warning_raised:
+    if stop_in_the_middle:
         with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
-            trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs)
-            with no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
+            trainer = Trainer(max_epochs=2, **trainer_kwargs)
+            with tutils.no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"):
                 trainer.fit(model, ckpt_path=ckpt_path)

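Outside the test helpers, the same warning can be reproduced with a minimal LightningModule that stops itself mid-epoch, mirroring ShouldStopModel above. A condensed end-to-end sketch (TinyModel, the batch counts and the checkpoint path are illustrative, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):
    def __init__(self, stop_batch_idx: int = -1):
        super().__init__()
        self.stop_batch_idx = stop_batch_idx
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        if batch_idx == self.stop_batch_idx:
            # ask the trainer to stop partway through the epoch
            self.trainer.should_stop = True
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # 64 samples with batch_size=8 -> 8 training batches per epoch
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer_kwargs = dict(limit_train_batches=8, enable_progress_bar=False, enable_model_summary=False)

model = TinyModel(stop_batch_idx=4)  # stop halfway through the first epoch
trainer = Trainer(max_epochs=1, **trainer_kwargs)
trainer.fit(model)
trainer.save_checkpoint("resume.ckpt")

model.stop_batch_idx = -1  # let the resumed run finish
# Resuming from the mid-epoch checkpoint emits the UserWarning added in this commit.
Trainer(max_epochs=2, **trainer_kwargs).fit(model, ckpt_path="resume.ckpt")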