From abd6895b4d136788f5e326cf9199de73406c50af Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Feb 2022 19:38:00 +0100 Subject: [PATCH 1/7] Fix current_epoch value on training end --- CHANGELOG.md | 11 ++-- .../callbacks/stochastic_weight_avg.py | 4 +- pytorch_lightning/loops/base.py | 5 +- .../loops/epoch/training_epoch_loop.py | 11 ++++ pytorch_lightning/loops/fit_loop.py | 50 ++++++++++--------- pytorch_lightning/loops/utilities.py | 9 +++- .../connectors/checkpoint_connector.py | 18 ++----- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 - pytorch_lightning/tuner/lr_finder.py | 2 - tests/callbacks/test_early_stopping.py | 14 +++--- tests/callbacks/test_timer.py | 12 ++--- tests/checkpointing/test_model_checkpoint.py | 23 ++++++--- tests/loops/test_loop_state_dict.py | 3 -- tests/loops/test_loops.py | 3 -- tests/loops/test_training_loop.py | 3 +- tests/models/test_hooks.py | 16 +++--- tests/models/test_restore.py | 13 ++--- .../connectors/test_checkpoint_connector.py | 1 - .../logging_/test_train_loop_logging.py | 4 +- tests/trainer/optimization/test_optimizers.py | 1 + tests/trainer/test_dataloaders.py | 5 +- tests/trainer/test_trainer.py | 12 ++--- tests/utilities/test_auto_restart.py | 1 - 24 files changed, 112 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a1ad4da6a820..2782a4cb1d9f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -243,9 +243,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655)) -- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141)) - - - Moved `Strategy` classes to the `strategies` directory ([#11226](https://github.com/PyTorchLightning/pytorch-lightning/pull/11226)) @@ -266,6 +263,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) + +- The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578)) + + - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521)) @@ -288,8 +289,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141)) + + - Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696)) + - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103)) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index d18c9dcffd4b1..b8fc9b399044e 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -221,13 +221,13 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args): trainer.fit_loop._skip_backward = False def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: + if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: # BatchNorm epoch update. Reset state trainer.accumulate_grad_batches = self._accumulate_grad_batches trainer.num_training_batches -= 1 trainer.fit_loop.max_epochs -= 1 self.reset_momenta() - elif trainer.current_epoch == self.swa_end: + elif trainer.current_epoch - 1 == self.swa_end: # Last SWA epoch. Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8581fc0ae2bf2..65d0200cb77e9 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -21,7 +21,6 @@ import pytorch_lightning as pl from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BaseProgress -from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException T = TypeVar("T") # the output type of `run` @@ -288,11 +287,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di destination[prefix + "state_dict"] = self.on_save_checkpoint() - # do not get the mode from `self.trainer` because it might not have been attached yet - ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled for k, v in self.__dict__.items(): key = prefix + k - if ft_enabled and isinstance(v, BaseProgress): + if isinstance(v, BaseProgress): destination[key] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, key + ".") diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 5f0a7f4666689..8311bcac655d5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections import defaultdict from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union @@ -35,6 +36,9 @@ _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] +log = logging.getLogger(__name__) + + class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): """Runs over all batches in a dataloader (one epoch). @@ -99,6 +103,13 @@ def _is_validation_done(self) -> bool: @property def done(self) -> bool: """Evaluates when to leave the loop.""" + if self.trainer.should_stop and self.min_steps: + self.trainer.should_stop = self.global_step >= self.min_steps + if not self.trainer.should_stop: + log.info( + f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met." + " Training will continue..." + ) return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop def connect( # type: ignore[override] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8bde81ea7018b..6f6a9dcb61fb8 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -51,7 +51,7 @@ class FitLoop(Loop[None]): def __init__( self, - min_epochs: Optional[int] = 1, + min_epochs: int = 0, max_epochs: int = 1000, ) -> None: super().__init__() @@ -133,6 +133,21 @@ def running_loss(self) -> TensorRunningAccum: """Returns the running loss.""" return self.epoch_loop.batch_loop.running_loss + @Loop.restarting.setter + def restarting(self, restarting: bool) -> None: + # if the last epoch completely finished, we are not actually restarting, we can check this to see if all + # current values are equal + values = ( + self.epoch_progress.current.ready, + self.epoch_progress.current.started, + self.epoch_progress.current.processed, + ) + finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values) + if finished_before_on_train_end: + self.epoch_progress.current.completed = self.epoch_progress.current.processed + restarting &= finished_before_on_train_end + Loop.restarting.fset(self, restarting) # call the parent setter + @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -156,31 +171,23 @@ def done(self) -> bool: """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) - stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs) - - should_stop = False - if self.trainer.should_stop: - # early stopping - met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True - if met_min_epochs and met_min_steps: - should_stop = True - else: + stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) + + if self.trainer.should_stop and self.min_epochs: + self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs + if not self.trainer.should_stop: log.info( - "Trainer was signaled to stop but required minimum epochs" - f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" - " not been met. Training will continue..." + f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met." + " Training will continue..." ) - self.trainer.should_stop = should_stop - - return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0 + return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0 @property def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called # until `on_run_start`, we use `limit_train_batches` instead - return self.done or self.trainer.limit_train_batches == 0 + return self.trainer.limit_train_batches == 0 def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override] """Connects a training epoch loop to this fit loop.""" @@ -240,7 +247,7 @@ def on_advance_start(self) -> None: # type: ignore[override] getattr(self.trainer.train_dataloader.sampler, "set_epoch", None) ): # set seed for distributed sampler (enables shuffling for each epoch) - self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed) + self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed) # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) @@ -325,11 +332,6 @@ def on_advance_end(self) -> None: def on_run_end(self) -> None: """Calls the ``on_train_end`` hook.""" log.detail(f"{self.__class__.__name__}: train run ended") - # NOTE: the current_epoch is already incremented - # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit - # To simulate that current behavior, we decrement here. - # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007 - self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0) # hook self.trainer._call_callback_hooks("on_train_end") diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 545e5839dff1a..13ae87fad50d1 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -71,7 +71,7 @@ def _parse_loop_limits( min_epochs: Optional[int], max_epochs: int, max_time: Optional[Union[str, timedelta, Dict[str, int]]], -) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]: +) -> Tuple[Optional[int], int, int, int, Optional[Union[str, timedelta, Dict[str, int]]]]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values the user has selected. @@ -95,7 +95,12 @@ def _parse_loop_limits( max_epochs = 1000 else: max_epochs = -1 - min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs + if min_epochs is None and min_steps is not None: + # setting this allows FitLoop.done to re-evaluate should_stop when it gets triggered `on_fit_start` + min_epochs = 1 + if min_epochs is None: + # the default value is 0 so no training will be done when should_stop is triggered `on_fit_start` + min_epochs = 0 return min_steps, max_steps, min_epochs, max_epochs, max_time diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3560678baaa02..bb6889d797f5c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,6 @@ from torchmetrics import Metric import pytorch_lightning as pl -from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE @@ -230,7 +229,7 @@ def restore_loops(self) -> None: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: - if self.trainer.state.fn == TrainerFn.FITTING: + if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -329,21 +328,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: LightningDataModule.__class__.__qualname__: pl DataModule's state } """ - - # dump epoch/global_step/pytorch-lightning_version - current_epoch = self.trainer.current_epoch - global_step = self.trainer.global_step - has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps) - - global_step += 1 - if not has_reached_max_steps: - current_epoch += 1 - model = self.trainer.lightning_module checkpoint = { - "epoch": current_epoch, - "global_step": global_step, + # the epoch is saved for compatibility but it's not relevant for restoration + "epoch": self.trainer.current_epoch, + "global_step": self.trainer.global_step + 1, "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), "loops": self._get_loops_state_dict(), diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f73e16604dcba..3663c211e4445 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2438,7 +2438,7 @@ def max_epochs(self) -> int: return self.fit_loop.max_epochs @property - def min_epochs(self) -> Optional[int]: + def min_epochs(self) -> int: return self.fit_loop.min_epochs @property diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index bf37e4357d8f7..1526e570dabe5 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -60,10 +60,8 @@ def scale_batch_size( # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") - trainer.fit_loop.epoch_progress.current.completed -= 1 trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.epoch_progress.current.completed += 1 trainer.fit_loop.global_step += 1 params = __scale_batch_dump_params(trainer) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index ebfa9a1dd54b0..876ff7823b2dc 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -204,10 +204,8 @@ def lr_find( # Save initial model, that is loaded after learning rate is found ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") - trainer.fit_loop.epoch_progress.current.completed -= 1 trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.epoch_progress.current.completed += 1 trainer.fit_loop.global_step += 1 params = __lr_finder_dump_params(trainer) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 1f6bb158d7a93..60f1317019292 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -80,7 +80,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # ensure state is persisted properly checkpoint = torch.load(checkpoint_filepath) # the checkpoint saves "epoch + 1" - early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] + early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]] assert 4 == len(early_stop_callback.saved_states) es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}" assert checkpoint["callbacks"][es_name] == early_stop_callback_state @@ -143,7 +143,7 @@ def validation_epoch_end(self, outputs): enable_progress_bar=False, ) trainer.fit(model) - assert trainer.current_epoch == expected_stop_epoch + assert trainer.current_epoch - 1 == expected_stop_epoch @pytest.mark.parametrize("validation_step_none", [True, False]) @@ -179,7 +179,7 @@ def training_epoch_end(self, outputs): enable_progress_bar=False, ) trainer.fit(model) - assert trainer.current_epoch == expected_stop_epoch + assert trainer.current_epoch - 1 == expected_stop_epoch def test_pickling(tmpdir): @@ -236,7 +236,7 @@ def validation_epoch_end(self, outputs): max_epochs=20, ) trainer.fit(model) - assert trainer.current_epoch == expected_epoch, "early_stopping failed" + assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed" @pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)]) @@ -260,7 +260,7 @@ def validation_epoch_end(self, outputs): max_epochs=10, ) trainer.fit(model) - assert trainer.current_epoch == expected_stop_epoch + assert trainer.current_epoch - 1 == expected_stop_epoch assert early_stopping.stopped_epoch == expected_stop_epoch @@ -388,7 +388,7 @@ def validation_epoch_end(self, outputs): self._epoch_end() def on_train_end(self) -> None: - assert self.trainer.current_epoch == self.expected_end_epoch, "Early Stopping Failed" + assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed" _ES_CHECK = dict(check_on_train_epoch_end=True) @@ -481,7 +481,7 @@ def validation_step(self, batch, batch_idx): if case == "val_check_interval": assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval) else: - assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1 + assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch def test_early_stopping_squeezes(): diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index c7c4e0458ee12..21792b430a35f 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -126,14 +126,8 @@ def test_timer_zero_duration_stop(tmpdir, interval): timer = Timer(duration=duration, interval=interval) trainer = Trainer(default_root_dir=tmpdir, callbacks=[timer]) trainer.fit(model) - if interval == "step": - # timer triggers stop on step end - assert trainer.global_step == 1 - assert trainer.current_epoch == 0 - else: - # timer triggers stop on epoch end - assert trainer.global_step == len(trainer.train_dataloader) - assert trainer.current_epoch == 0 + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 @pytest.mark.parametrize("min_steps,min_epochs", [(None, 2), (3, None), (3, 2)]) @@ -144,7 +138,7 @@ def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs): trainer = Trainer(default_root_dir=tmpdir, callbacks=[timer], min_steps=min_steps, min_epochs=min_epochs) trainer.fit(model) if min_epochs: - assert trainer.current_epoch >= min_epochs - 1 + assert trainer.current_epoch >= min_epochs if min_steps: assert trainer.global_step >= min_steps - 1 assert timer.time_elapsed() > duration.total_seconds() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 942f304bf9af5..252d992bd8db6 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -166,7 +166,7 @@ def on_validation_epoch_end(self): expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt" chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) - assert chk["epoch"] == epoch + 1 + assert chk["epoch"] == epoch assert chk["global_step"] == limit_train_batches * (epoch + 1) mc_specific_data = chk["callbacks"][ @@ -266,7 +266,7 @@ def _make_assertions(epoch, ix): assert math.isclose(score, expected_score, rel_tol=1e-4) chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) - assert chk["epoch"] == epoch + 1 + assert chk["epoch"] == epoch expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch) assert chk["global_step"] == expected_global_step @@ -821,7 +821,9 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) - assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] + # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so the + # `current_epoch` count has not been increased yet + assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] - 1 assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] ckpt_id = ( @@ -923,7 +925,9 @@ def get_last_checkpoint(ckpt_dir): def assert_checkpoint_content(ckpt_dir): chk = pl_load(get_last_checkpoint(ckpt_dir)) - assert chk["epoch"] == epochs + # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so the + # `current_epoch` count has not been increased yet + assert chk["epoch"] == epochs - 1 assert chk["global_step"] == 4 def assert_checkpoint_log_dir(idx): @@ -951,15 +955,15 @@ def assert_checkpoint_log_dir(idx): model = ExtendedBoringModel() trainer.fit(model) assert trainer.global_step == epochs * limit_train_batches - assert trainer.current_epoch == epochs - 1 + assert trainer.current_epoch == epochs assert_checkpoint_log_dir(0) assert_checkpoint_content(ckpt_dir) trainer.validate(model) - assert trainer.current_epoch == epochs - 1 + assert trainer.current_epoch == epochs trainer.test(model) - assert trainer.current_epoch == epochs - 1 + assert trainer.current_epoch == epochs for idx in range(1, 5): chk = get_last_checkpoint(ckpt_dir) @@ -977,14 +981,17 @@ def assert_checkpoint_log_dir(idx): trainer.fit(model, ckpt_path=chk) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs + assert trainer.fit_loop.epoch_progress.current.processed == epochs trainer.validate(model) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs + assert trainer.fit_loop.epoch_progress.current.processed == epochs trainer.fit(model) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs + assert trainer.fit_loop.epoch_progress.current.processed == epochs assert_checkpoint_log_dir(idx) @@ -1171,7 +1178,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir): trainer.fit_loop.max_epochs = 4 trainer.fit(BoringModel()) - ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION) + ckpt_range = range(mc.STARTING_VERSION, trainer.max_epochs + mc.STARTING_VERSION - 1) expected = {"test.ckpt", *(f"test-v{i}.ckpt" for i in ckpt_range)} # check best_k_models state assert {Path(f).name for f in mc.best_k_models} == expected diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 28d96b324435f..47ccda50de7f8 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from unittest import mock from unittest.mock import Mock import pytest @@ -39,7 +37,6 @@ def test_loops_state_dict(): assert fit_loop.state_dict() == new_fit_loop.state_dict() -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index b6f0c05d3fd72..3303e259692a5 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -230,7 +230,6 @@ def load_state_dict(self, state_dict: Dict) -> None: assert loop.outputs == list(range(10)) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loop_hierarchy(): @dataclass class SimpleProgress(BaseProgress): @@ -585,7 +584,6 @@ def configure_optimizers_multiple(self): assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.parametrize("n_optimizers", (1, 3, 5)) def test_loop_state_on_complete_run(n_optimizers, tmpdir): n_epochs = 3 @@ -720,7 +718,6 @@ def train_dataloader(self): assert checkpoint["loops"]["fit_loop"] == expected -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_fit_loop_reset(tmpdir): """Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed loop or from a mid-epoch checkpoint.""" diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 86801f56266c6..8329389a93944 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -121,7 +121,8 @@ def validation_step(self, *args): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=10, limit_val_batches=1) trainer.fit(model) - assert trainer.current_epoch == 0 + # even though we stopped mid epoch, the fit loop finished normally and the current epoch was increased + assert trainer.current_epoch == 1 assert trainer.global_step == 5 assert model.validation_called_at == (0, 4) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 00ccaa3ec7c6c..0ffa1b7d1987a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -279,11 +279,13 @@ def _train_batch(self, *args, **kwargs): return self._manual_train_batch(*args, **kwargs) @staticmethod - def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs): + def _auto_train_batch( + trainer, model, batches, device=torch.device("cpu"), current_epoch=0, current_batch=0, **kwargs + ): using_native_amp = kwargs.get("amp_backend") == "native" using_deepspeed = kwargs.get("strategy") == "deepspeed" out = [] - for i in range(batches): + for i in range(current_batch, batches): out.extend( [ dict(name="on_before_batch_transfer", args=(ANY, 0)), @@ -483,7 +485,7 @@ def training_step(self, batch, batch_idx): trainer.fit(model) saved_ckpt = { "callbacks": ANY, - "epoch": 1, + "epoch": 0, "global_step": train_batches, "lr_schedulers": ANY, "optimizer_states": ANY, @@ -607,7 +609,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): trainer.fit(model, ckpt_path=best_model_path) loaded_ckpt = { "callbacks": ANY, - "epoch": 1, # TODO: wrong saved epoch, should be 0 + "epoch": 0, "global_step": 1, "lr_schedulers": ANY, "optimizer_states": ANY, @@ -615,8 +617,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): "state_dict": ANY, "loops": ANY, } - # TODO: wrong saved epoch, should be 0 - saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2} + saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 1} expected = [ dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), @@ -648,8 +649,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="on_epoch_start"), dict(name="Callback.on_train_epoch_start", args=(trainer, model)), dict(name="on_train_epoch_start"), - # TODO: wrong current epoch after reload - *model._train_batch(trainer, model, train_batches, current_epoch=1), + *model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1), dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 8e1006ce73147..35e17e513d858 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -232,8 +232,7 @@ def test_correct_step_and_epoch(tmpdir): assert trainer.global_step == 0 trainer.fit(model) - # TODO(@carmocca): should not need `-1` - assert trainer.current_epoch == first_max_epochs - 1 + assert trainer.current_epoch == first_max_epochs assert trainer.global_step == first_max_epochs * train_batches # save checkpoint after loop ends, training end called, epoch count increased @@ -260,8 +259,7 @@ def on_pretrain_routine_end(self) -> None: assert self.trainer.global_step == first_max_epochs * train_batches + 1 trainer.fit(TestModel(), ckpt_path=ckpt_path) - # TODO(@carmocca): should not need `-1` - assert trainer.current_epoch == max_epochs - 1 + assert trainer.current_epoch == max_epochs # TODO(@carmocca): should not need `+1` assert trainer.global_step == max_epochs * train_batches + 1 @@ -286,8 +284,7 @@ def on_train_epoch_end(self, *_): trainer.fit(TestModel()) trainer.fit_loop.max_epochs = 4 trainer.fit(TestModel()) - # TODO(@carmocca): 1 should not be duplicated - assert epochs == [0, 1, 1, 2, 3] + assert epochs == [0, 1, 2, 3] def test_try_resume_from_non_existing_checkpoint(tmpdir): @@ -582,8 +579,8 @@ def test_dp_resume(tmpdir): trainer = Trainer(**trainer_options) trainer.fit(model, datamodule=dm) - # track epoch before saving. Increment since we finished the current epoch, don't want to rerun - real_global_epoch = trainer.current_epoch + 1 + # track epoch before saving + real_global_epoch = trainer.current_epoch # correct result and ok accuracy assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 0547fed659fe5..0ba21dea404fa 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -153,7 +153,6 @@ def test_hpc_max_ckpt_version(tmpdir): ) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_loops_restore(tmpdir): """Test that required loop state_dict is loaded correctly by checkpoint connector.""" model = BoringModel() diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 7699567860797..49651cf6ef0c0 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -488,8 +488,8 @@ def get_metrics(self, trainer: Trainer, model: LightningModule): def on_train_end(self, trainer: Trainer, model: LightningModule): metrics = self.get_metrics(trainer, model) - assert metrics["foo"] == self.trainer.current_epoch - assert metrics["foo_2"] == self.trainer.current_epoch + assert metrics["foo"] == self.trainer.current_epoch - 1 + assert metrics["foo_2"] == self.trainer.current_epoch - 1 model.callback_on_train_end_called = True progress_bar = TestProgressBar() diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 74abfda11b241..99071ce3d8f8a 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -416,6 +416,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch): } if complete_epoch: + trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps) with pytest.warns( RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict" ): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a22dced3aa9db..27621e96f37c1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -809,8 +809,9 @@ def gen(self): ) trainer.fit(model, train_dataloaders=train_dataloader) assert trainer.global_step == 2 * yield_at_all - # we expect the second epoch to be skipped - assert trainer.current_epoch == 1 + # even though the generator might not yield any data, the fit_loop still advances so the + # current epoch gets increased + assert trainer.current_epoch == 2 class DistribSamplerCallback(Callback): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0e2f488b4506e..6fede8a612f21 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -486,7 +486,7 @@ def test_trainer_max_steps_and_epochs(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.global_step == num_train_samples * trainer.max_epochs - assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs" + assert trainer.current_epoch == trainer.max_epochs, "Model did not stop at max_epochs" # if max_steps is positive and max_epochs is negative, use max_steps trainer_kwargs["max_epochs"] = -1 @@ -610,7 +610,7 @@ def training_step(self, batch, batch_idx): with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): trainer.fit(model) - message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue" + message = f"minimum epochs ({min_epochs}) has not been met. Training will continue" num_messages = sum(1 for record in caplog.records if message in record.message) assert num_messages == min_epochs - 2 assert model.training_step_invoked == min_epochs * 2 @@ -843,7 +843,7 @@ def training_epoch_end(self, *args, **kwargs): assert not torch.all(torch.eq(before_state_dict[key], after_state_dict[key])) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.current_epoch == 0 + assert trainer.current_epoch == 1 assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`" assert model.training_epoch_end_invoked, "did not run `training_epoch_end` with `fast_dev_run=True`" @@ -880,7 +880,7 @@ def validation_epoch_end(self, *args, **kwargs): # check that limit_val_batches=0 turns off validation assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.current_epoch == 1 + assert trainer.current_epoch == 2 assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`" assert not model.validation_epoch_end_invoked, "`validation_epoch_end` should not run when `limit_val_batches=0`" @@ -891,7 +891,7 @@ def validation_epoch_end(self, *args, **kwargs): trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.current_epoch == 0 + assert trainer.current_epoch == 1 assert model.validation_step_invoked, "did not run `validation_step` with `fast_dev_run=True`" assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`" @@ -1571,7 +1571,7 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): @pytest.mark.parametrize( ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], - [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], + [(0.2, 0, 0, 0, False), (0.5, 10, 2, 5, True)], ) def test_disabled_training_for_insufficient_limit_train_batches( tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index ff4db0051d6f4..5faefbc9f71b6 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -256,7 +256,6 @@ def __next__(self): @pytest.mark.parametrize( "num_workers", [0, pytest.param(1, marks=RunIf(slow=True)), pytest.param(2, marks=RunIf(slow=True))] ) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_fast_forward_sampler_over_iterable_dataset(num_workers): """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being used to capture workers states.""" From 9228953e4330528756f08e0e7578c1747bb6497b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Feb 2022 21:28:19 +0100 Subject: [PATCH 2/7] Remove code from #11552 --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8311bcac655d5..5d6bf477fe360 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -490,11 +490,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = is_last_batch - - # while restarting with no fault-tolerant, batch_progress.current.ready is -1 - if batch_idx == -1: - return False - if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float("inf"): From 17ebb0c8ae5686c5f709041a387a9695bcb11126 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:57:18 +0100 Subject: [PATCH 3/7] Apply #11556 --- .../loops/epoch/training_epoch_loop.py | 11 ++++ pytorch_lightning/loops/fit_loop.py | 25 ++------ tests/models/test_restore.py | 59 ++++++++++++------- 3 files changed, 52 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 5d6bf477fe360..4723525c40dfc 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import math from collections import defaultdict from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union @@ -129,6 +130,16 @@ def reset(self) -> None: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() self.batch_loop.optimizer_loop.optim_progress.reset_on_restart() + + trainer = self.trainer + if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"): + expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches) + if self.global_step % expected_steps != 0: + rank_zero_warn( + "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable" + "results if further training is done. Consider using an end-of-epoch checkpoint or enabling" + "fault-tolerant training." + ) else: self.batch_progress.reset_on_run() self.scheduler_progress.reset_on_run() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 6f6a9dcb61fb8..b43c8f522902d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math import os from functools import partial -from typing import Optional, Type +from typing import Optional +from typing import Type import pytorch_lightning as pl from pytorch_lightning.accelerators import GPUAccelerator @@ -26,7 +26,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( AbstractDataFetcher, @@ -35,7 +34,8 @@ InterBatchParallelDataFetcher, ) from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -205,23 +205,6 @@ def on_run_start(self) -> None: # type: ignore[override] data_fetcher_cls = _select_data_fetcher(self.trainer) self._data_fetcher = data_fetcher_cls() - ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled - if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")): - self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches( - self.trainer.current_epoch - ) - expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches) - - # global_step is incremented during checkpointing (#11555) - if (self.trainer.global_step - 1) % expected_steps != 0: - rank_zero_warn( - "You're resuming from a checkpoint that ended mid-epoch." - " Training will start from the beginning of the next epoch." - " This can cause unreliable results if further training is done," - " consider using an end of epoch checkpoint or use fault-tolerant training" - " to restart as if training did not stop." - ) - self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 35e17e513d858..7b90247044f94 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -33,7 +33,7 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel -from tests.helpers.utils import no_warning_call +from tests.loops.test_loops import CustomException class ModelTrainerPropertyParity(Callback): @@ -774,44 +774,59 @@ def test_model_pickle(tmpdir): cloudpickle.dumps(model) -@pytest.mark.parametrize("stop_batch_idx", [4, 7]) -def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx): - """Test that a warning is raised if training is restarted from mid-epoch.""" +class ExceptionModel(BoringModel): + def __init__(self, stop_batch_idx): + super().__init__() + self.stop_batch_idx = stop_batch_idx - class CustomModel(BoringModel): - def __init__(self, stop_batch_idx): - super().__init__() - self.stop_batch_idx = stop_batch_idx + def training_step(self, batch, batch_idx): + if batch_idx == self.stop_batch_idx: + raise CustomException() + return super().training_step(batch, batch_idx) - def training_step(self, batch, batch_idx): - if (batch_idx + 1) == self.stop_batch_idx: - self.trainer.should_stop = True - return super().training_step(batch, batch_idx) +class ShouldStopModel(ExceptionModel): + def training_step(self, batch, batch_idx): + if batch_idx == self.stop_batch_idx: + # setting should_stop is treated differently to raising an exception. + # checking both tests that this warning is raised in the correct loop + self.trainer.should_stop = True + return super().training_step(batch, batch_idx) - limit_train_batches = 7 + +@pytest.mark.parametrize("stop_in_the_middle", (True, False)) +@pytest.mark.parametrize("model_cls", (ExceptionModel, ShouldStopModel)) +def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls): + """Test that a warning is raised if training is restarted from mid-epoch.""" + limit_train_batches = 8 trainer_kwargs = { "default_root_dir": tmpdir, "limit_train_batches": limit_train_batches, + "limit_val_batches": 0, "enable_progress_bar": False, "enable_model_summary": False, } trainer = Trainer(max_epochs=1, **trainer_kwargs) - model = CustomModel(stop_batch_idx) - trainer.fit(model) + model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1) + + if stop_in_the_middle: + with pytest.raises(CustomException): + trainer.fit(model) + else: + trainer.fit(model) ckpt_path = str(tmpdir / "resume.ckpt") trainer.save_checkpoint(ckpt_path) - trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs) + trainer = Trainer(max_epochs=2, **trainer_kwargs) + model.stop_batch_idx = -1 - warning_raised = limit_train_batches != stop_batch_idx - context_manager = pytest.warns if warning_raised else no_warning_call - with context_manager(UserWarning, match="resuming from a checkpoint that ended mid-epoch"): + context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call + with context_manager(UserWarning, match="resuming from a checkpoint that ended"): trainer.fit(model, ckpt_path=ckpt_path) - if warning_raised: + if stop_in_the_middle: with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs) - with no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"): + trainer = Trainer(max_epochs=2, **trainer_kwargs) + with tutils.no_warning_call(UserWarning, match="resuming from a checkpoint that ended"): trainer.fit(model, ckpt_path=ckpt_path) From efdc403b801a526df2fc95765af4866c396e1bee Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 7 Feb 2022 19:11:00 +0100 Subject: [PATCH 4/7] Undo `fit_loop.done` change --- .../loops/epoch/training_epoch_loop.py | 11 ---------- pytorch_lightning/loops/fit_loop.py | 20 +++++++++++++------ tests/trainer/test_trainer.py | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 4723525c40dfc..af3839793bc46 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import math from collections import defaultdict from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union @@ -37,9 +36,6 @@ _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] -log = logging.getLogger(__name__) - - class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): """Runs over all batches in a dataloader (one epoch). @@ -104,13 +100,6 @@ def _is_validation_done(self) -> bool: @property def done(self) -> bool: """Evaluates when to leave the loop.""" - if self.trainer.should_stop and self.min_steps: - self.trainer.should_stop = self.global_step >= self.min_steps - if not self.trainer.should_stop: - log.info( - f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met." - " Training will continue..." - ) return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop def connect( # type: ignore[override] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index b43c8f522902d..3f0657786bcab 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -173,14 +173,22 @@ def done(self) -> bool: stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) - if self.trainer.should_stop and self.min_epochs: - self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs - if not self.trainer.should_stop: + should_stop = False + if self.trainer.should_stop: + # early stopping + met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True + met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + if met_min_epochs and met_min_steps: + should_stop = True + else: log.info( - f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met." - " Training will continue..." + "Trainer was signaled to stop but required minimum epochs" + f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" + " not been met. Training will continue..." ) - return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0 + self.trainer.should_stop = should_stop + + return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0 @property def skip(self) -> bool: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6fede8a612f21..587ff0b7b9f72 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -610,7 +610,7 @@ def training_step(self, batch, batch_idx): with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): trainer.fit(model) - message = f"minimum epochs ({min_epochs}) has not been met. Training will continue" + message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue" num_messages = sum(1 for record in caplog.records if message in record.message) assert num_messages == min_epochs - 2 assert model.training_step_invoked == min_epochs * 2 From b2db296b270021f49fe536fd81f02e8b50a50673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 8 Feb 2022 00:56:02 +0100 Subject: [PATCH 5/7] Update pytorch_lightning/loops/epoch/training_epoch_loop.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index af3839793bc46..167108b8f4631 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -126,8 +126,8 @@ def reset(self) -> None: if self.global_step % expected_steps != 0: rank_zero_warn( "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable" - "results if further training is done. Consider using an end-of-epoch checkpoint or enabling" - "fault-tolerant training." + " results if further training is done. Consider using an end-of-epoch checkpoint or enabling" + " fault-tolerant training." ) else: self.batch_progress.reset_on_run() From 1be025d4742b982930378d57bdfac47041ab31a0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 16:31:22 +0100 Subject: [PATCH 6/7] Keep `FitLoop.done` check --- pytorch_lightning/loops/fit_loop.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 3f0657786bcab..05457bcebeb50 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,8 +14,7 @@ import logging import os from functools import partial -from typing import Optional -from typing import Type +from typing import Optional, Type import pytorch_lightning as pl from pytorch_lightning.accelerators import GPUAccelerator @@ -34,8 +33,7 @@ InterBatchParallelDataFetcher, ) from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -195,7 +193,7 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called # until `on_run_start`, we use `limit_train_batches` instead - return self.trainer.limit_train_batches == 0 + return self.done or self.trainer.limit_train_batches == 0 def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override] """Connects a training epoch loop to this fit loop.""" From 7afb814485af47d4941405a1261e759ec387f5dc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 16:41:40 +0100 Subject: [PATCH 7/7] Comments requested by Thomas --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 1 + pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 ++- pytorch_lightning/loops/fit_loop.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index b8fc9b399044e..f4b32a5bf0559 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -221,6 +221,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args): trainer.fit_loop._skip_backward = False def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + # the trainer increases the current epoch before this hook is called if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: # BatchNorm epoch update. Reset state trainer.accumulate_grad_batches = self._accumulate_grad_batches diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 167108b8f4631..f60bfd401688f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -127,7 +127,8 @@ def reset(self) -> None: rank_zero_warn( "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable" " results if further training is done. Consider using an end-of-epoch checkpoint or enabling" - " fault-tolerant training." + " fault-tolerant training:" + " https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html" ) else: self.batch_progress.reset_on_run() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 05457bcebeb50..8cbe4c167a29d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -169,6 +169,8 @@ def done(self) -> bool: """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) + # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved. + # we use it here because the checkpoint data won't have `completed` increased yet stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) should_stop = False