diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 97f80cc7e4c7e..fc5f8b9d11ece 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import math
 from collections import defaultdict
 from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union
@@ -33,6 +34,8 @@
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
 
+log = logging.getLogger(__name__)
+
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
 
 
@@ -100,6 +103,13 @@ def _is_validation_done(self) -> bool:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.should_stop and self.min_steps:
+            self.trainer.should_stop = self.global_step >= self.min_steps
+            if not self.trainer.should_stop:
+                log.info(
+                    f"Trainer was signaled to stop but required minimum steps ({self.min_steps}) has not been met."
+                    " Training will continue..."
+                )
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
 
     def connect(  # type: ignore[override]
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 8cbe4c167a29d..a369d863f33f2 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -167,28 +167,19 @@ def _results(self) -> _ResultCollection:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
-        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
 
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
+        if self.trainer.should_stop and self.min_epochs:
+            self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs
+            if not self.trainer.should_stop:
                 log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                    f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
+                    " Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
-
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0
 
     @property
     def skip(self) -> bool:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 587ff0b7b9f72..6fede8a612f21 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -610,7 +610,7 @@ def training_step(self, batch, batch_idx):
     with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
         trainer.fit(model)
 
-    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
+    message = f"minimum epochs ({min_epochs}) has not been met. Training will continue"
     num_messages = sum(1 for record in caplog.records if message in record.message)
     assert num_messages == min_epochs - 2
    assert model.training_step_invoked == min_epochs * 2
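
Not part of the diff: below is a minimal sketch of the behavior the patched `FitLoop.done` enforces. The model and dataset names (`StopEagerModel`, `RandomDataset`) are made up for illustration; the model sets `trainer.should_stop` on every step, and with `min_epochs=3` the patched loop resets that flag (logging "Training will continue...") until three epochs have been processed, instead of stopping after the first epoch.

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Eight random 4-dim vectors, just enough to drive a couple of batches."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class StopEagerModel(pl.LightningModule):
    """Signals a stop on every step; the loops should ignore it until min_epochs is met."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        # request a stop immediately; with this patch, FitLoop.done resets the
        # flag until `min_epochs` epochs have been processed
        self.trainer.should_stop = True
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(min_epochs=3, max_epochs=10, logger=False, enable_checkpointing=False)
trainer.fit(StopEagerModel(), DataLoader(RandomDataset(), batch_size=4))
# expected: training ran for min_epochs (3) despite the stop signal raised in epoch 0
print(trainer.current_epoch)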