From 2f41fde732a8e5b684ed5590fbfcd4798ae1bfe6 Mon Sep 17 00:00:00 2001 From: Jinyoung Lim Date: Tue, 28 Sep 2021 22:43:16 -0700 Subject: [PATCH 01/10] Added a warning and related unit tests for non-early-stopping condition but when max steps or epochs are met. --- pytorch_lightning/loops/fit_loop.py | 9 +++++++++ tests/loops/test_training_loop.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 9a4f7c510f303..8c9c2a77f649b 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -20,6 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException log = logging.getLogger(__name__) @@ -159,6 +160,14 @@ def done(self) -> bool: f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" " not been met. Training will continue..." ) + else: + if stop_steps or stop_epochs: + rank_zero_warn( + f"Trainer not signaled to stop but met maximum number of steps ({self.max_steps}) or" + f" epochs ({self.max_epochs}). If this was on purpose, ignore this warning... Otherwise," + " please increase maximum number of epochs or steps." + ) + self.trainer.should_stop = should_stop return stop_steps or should_stop or stop_epochs diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index ae36e56495f93..6650965f70604 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -14,9 +14,9 @@ import pytest import torch - from pytorch_lightning import seed_everything, Trainer from tests.helpers import BoringModel +from tests.helpers.utils import no_warning_call def test_outputs_format(tmpdir): @@ -126,7 +126,32 @@ def validation_step(self, *args): assert trainer.global_step == 5 assert model.validation_called_at == (0, 4) - +@pytest.mark.parametrize(["max_epochs", "current_epoch"],[(1, 0), (1, 1), (1, 2)]) +def test_warning_no_early_stoppoing_and_max_epochs(tmpdir, max_epochs, current_epoch): + """Test that training stops early with max epoch being reached.""" + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs) + trainer.fit_loop.current_epoch = current_epoch + if max_epochs > current_epoch: + with no_warning_call(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): + trainer.fit(model) + else: + with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): + trainer.fit(model) + +@pytest.mark.parametrize(["max_steps", "global_steps"],[(1, 0), (1, 1), (1, 2)]) +def test_warning_no_early_stoppoing_and_max_steps(tmpdir, max_steps, global_steps): + """Test that training stops early with max steps being reached.""" + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=max_steps) + trainer.fit_loop.global_step = global_steps + if max_steps > global_steps: + with no_warning_call(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): + trainer.fit(model) + else: + with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): + trainer.fit(model) + def test_warning_valid_train_step_end(tmpdir): class ValidTrainStepEndModel(BoringModel): def training_step(self, batch, batch_idx): From 79cd9a7d78545f1a97691399fe7fcbd7728d5491 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Sep 2021 05:50:54 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_training_loop.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 6650965f70604..b03cd75eccbc2 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -14,6 +14,7 @@ import pytest import torch + from pytorch_lightning import seed_everything, Trainer from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -126,7 +127,8 @@ def validation_step(self, *args): assert trainer.global_step == 5 assert model.validation_called_at == (0, 4) -@pytest.mark.parametrize(["max_epochs", "current_epoch"],[(1, 0), (1, 1), (1, 2)]) + +@pytest.mark.parametrize(["max_epochs", "current_epoch"], [(1, 0), (1, 1), (1, 2)]) def test_warning_no_early_stoppoing_and_max_epochs(tmpdir, max_epochs, current_epoch): """Test that training stops early with max epoch being reached.""" model = BoringModel() @@ -139,7 +141,8 @@ def test_warning_no_early_stoppoing_and_max_epochs(tmpdir, max_epochs, current_e with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): trainer.fit(model) -@pytest.mark.parametrize(["max_steps", "global_steps"],[(1, 0), (1, 1), (1, 2)]) + +@pytest.mark.parametrize(["max_steps", "global_steps"], [(1, 0), (1, 1), (1, 2)]) def test_warning_no_early_stoppoing_and_max_steps(tmpdir, max_steps, global_steps): """Test that training stops early with max steps being reached.""" model = BoringModel() @@ -151,7 +154,8 @@ def test_warning_no_early_stoppoing_and_max_steps(tmpdir, max_steps, global_step else: with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): trainer.fit(model) - + + def test_warning_valid_train_step_end(tmpdir): class ValidTrainStepEndModel(BoringModel): def training_step(self, batch, batch_idx): From ee6d00a79252e928a1dc4d5c47e148391dc10701 Mon Sep 17 00:00:00 2001 From: Jinyoung Lim Date: Tue, 28 Sep 2021 23:01:22 -0700 Subject: [PATCH 03/10] Updated the CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2462ffbdbb773..a2e0843452037 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -145,6 +145,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) +- Added a warning to notify `FitLoop` stopping when early stopping conditions are not met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). From d99e7c682ebc5d27dc92414d747095810d54bc52 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 12 Oct 2021 11:02:37 +0100 Subject: [PATCH 04/10] typo --- CHANGELOG.md | 7 ++----- tests/loops/test_training_loop.py | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3492fb75d6b6b..c3d02cf3af529 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -157,19 +157,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) -- Added a warning to notify `FitLoop` stopping when early stopping conditions are not met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) - - - Added `enable_progress_bar` to Trainer constructor ([#9664](https://github.com/PyTorchLightning/pytorch-lightning/pull/9664)) - Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166)) -- Added a warning to notify `FitLoop` stopping when early stopping conditions are not met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) +- Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121)) -- Added support for `torch.use_deterministic_algorithms` ([#9121](https://github.com/PyTorchLightning/pytorch-lightning/pull/9121)) +- Added a warning to notify `FitLoop` stopping when early stopping conditions are not met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) ### Changed diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index fdbe03db4d9b2..924ceb5f3212d 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -129,7 +129,7 @@ def validation_step(self, *args): @pytest.mark.parametrize(["max_epochs", "current_epoch"], [(1, 0), (1, 1), (1, 2)]) -def test_warning_no_early_stoppoing_and_max_epochs(tmpdir, max_epochs, current_epoch): +def test_warning_no_early_stopping_and_max_epochs(tmpdir, max_epochs, current_epoch): """Test that training stops early with max epoch being reached.""" model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs) @@ -143,7 +143,7 @@ def test_warning_no_early_stoppoing_and_max_epochs(tmpdir, max_epochs, current_e @pytest.mark.parametrize(["max_steps", "global_steps"], [(1, 0), (1, 1), (1, 2)]) -def test_warning_no_early_stoppoing_and_max_steps(tmpdir, max_steps, global_steps): +def test_warning_no_early_stopping_and_max_steps(tmpdir, max_steps, global_steps): """Test that training stops early with max steps being reached.""" model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, max_steps=max_steps) From d1560e10e5036865dfac39e3924150679e1ceee2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 21 Jul 2022 01:26:28 +0200 Subject: [PATCH 05/10] Update and add tests --- src/pytorch_lightning/loops/fit_loop.py | 29 +++++--- .../tests_pytorch/loops/test_training_loop.py | 72 ++++++++++++------- 2 files changed, 64 insertions(+), 37 deletions(-) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 8b54579a6bbfb..6ff1380f136f0 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -33,7 +33,7 @@ InterBatchParallelDataFetcher, ) from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_warn, rank_zero_info, rank_zero_debug from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -150,31 +150,40 @@ def _results(self) -> _ResultCollection: @property def done(self) -> bool: """Evaluates when to leave the loop.""" + if self.trainer.num_training_batches == 0: + rank_zero_info(f"`Trainer.fit` stopped: No training batches.") + return True + # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps) + if stop_steps: + rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.") + return True + # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved. # we use it here because the checkpoint data won't have `completed` increased yet stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) if stop_epochs: # in case they are not equal, override so `trainer.current_epoch` has the expected value self.epoch_progress.current.completed = self.epoch_progress.current.processed + rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.") + return True - should_stop = False if self.trainer.should_stop: # early stopping met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: - should_stop = True + self.trainer.should_stop = True + rank_zero_debug(f"`Trainer.fit` stopped: `should_stop` was set.") + return True else: - log.info( - "Trainer was signaled to stop but required minimum epochs" - f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has" - " not been met. Training will continue..." + rank_zero_info( + f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or" + f" `min_steps={self.min_steps!r}` has not been met. Training will continue..." ) - self.trainer.should_stop = should_stop - - return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0 + self.trainer.should_stop = False + return False @property def skip(self) -> bool: diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index ed779ba96fcfa..c77eec39a9e8a 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from unittest.mock import Mock + import pytest import torch from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from tests_pytorch.helpers.utils import no_warning_call +from pytorch_lightning.loops import FitLoop def test_outputs_format(tmpdir): @@ -137,32 +140,47 @@ def validation_step(self, *args): assert model.validation_called_at == (0, 5) -@pytest.mark.parametrize(["max_epochs", "current_epoch"], [(1, 0), (1, 1), (1, 2)]) -def test_warning_no_early_stopping_and_max_epochs(tmpdir, max_epochs, current_epoch): - """Test that training stops early with max epoch being reached.""" - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs) - trainer.fit_loop.current_epoch = current_epoch - if max_epochs > current_epoch: - with no_warning_call(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): - trainer.fit(model) - else: - with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): - trainer.fit(model) - - -@pytest.mark.parametrize(["max_steps", "global_steps"], [(1, 0), (1, 1), (1, 2)]) -def test_warning_no_early_stopping_and_max_steps(tmpdir, max_steps, global_steps): - """Test that training stops early with max steps being reached.""" - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_steps=max_steps) - trainer.fit_loop.global_step = global_steps - if max_steps > global_steps: - with no_warning_call(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): - trainer.fit(model) - else: - with pytest.warns(UserWarning, match=r"Trainer not signaled to stop but met maximum number of steps "): - trainer.fit(model) +def test_fit_loop_done_log_messages(caplog): + fit_loop = FitLoop() + trainer = Mock(spec=Trainer) + fit_loop.trainer = trainer + + trainer.should_stop = False + trainer.num_training_batches = 5 + assert not fit_loop.done + assert not caplog.messages + + trainer.num_training_batches = 0 + assert fit_loop.done + assert 'No training batches' in caplog.text + caplog.clear() + trainer.num_training_batches = 5 + + epoch_loop = Mock() + epoch_loop.global_step = 10 + fit_loop.connect(epoch_loop=epoch_loop) + fit_loop.max_steps = 10 + assert fit_loop.done + assert 'max_steps=10` reached' in caplog.text + caplog.clear() + fit_loop.max_steps = 20 + + fit_loop.epoch_progress.current.processed = 3 + fit_loop.max_epochs = 3 + trainer.should_stop = True + assert fit_loop.done + assert 'max_epochs=3` reached' in caplog.text + caplog.clear() + fit_loop.max_epochs = 5 + + fit_loop.epoch_loop.min_steps = 0 + with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"): + assert fit_loop.done + assert 'should_stop` was set' in caplog.text + + fit_loop.epoch_loop.min_steps = 100 + assert not fit_loop.done + assert 'was signaled to stop but' in caplog.text def test_warning_valid_train_step_end(tmpdir): From 5f01ecc908b452fe3fc67783aa2db24b79e623f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 21 Jul 2022 01:28:31 +0200 Subject: [PATCH 06/10] CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 09617cee59d6a..096caed5aa33b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330)) +- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) + + - Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224) @@ -141,9 +144,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604)) -- Added a warning to notify `FitLoop` stopping when early stopping conditions are not met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749)) - - - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646)) From 457e9db691d7053cae35bf5ed953336f1e48c502 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Jul 2022 23:30:12 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/loops/fit_loop.py | 2 +- tests/tests_pytorch/loops/test_training_loop.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 6ff1380f136f0..ff9a32ed53ae4 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -33,7 +33,7 @@ InterBatchParallelDataFetcher, ) from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_warn, rank_zero_info, rank_zero_debug +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index c77eec39a9e8a..a9da6dcf2be6d 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -152,7 +152,7 @@ def test_fit_loop_done_log_messages(caplog): trainer.num_training_batches = 0 assert fit_loop.done - assert 'No training batches' in caplog.text + assert "No training batches" in caplog.text caplog.clear() trainer.num_training_batches = 5 @@ -161,7 +161,7 @@ def test_fit_loop_done_log_messages(caplog): fit_loop.connect(epoch_loop=epoch_loop) fit_loop.max_steps = 10 assert fit_loop.done - assert 'max_steps=10` reached' in caplog.text + assert "max_steps=10` reached" in caplog.text caplog.clear() fit_loop.max_steps = 20 @@ -169,18 +169,18 @@ def test_fit_loop_done_log_messages(caplog): fit_loop.max_epochs = 3 trainer.should_stop = True assert fit_loop.done - assert 'max_epochs=3` reached' in caplog.text + assert "max_epochs=3` reached" in caplog.text caplog.clear() fit_loop.max_epochs = 5 fit_loop.epoch_loop.min_steps = 0 with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"): assert fit_loop.done - assert 'should_stop` was set' in caplog.text + assert "should_stop` was set" in caplog.text fit_loop.epoch_loop.min_steps = 100 assert not fit_loop.done - assert 'was signaled to stop but' in caplog.text + assert "was signaled to stop but" in caplog.text def test_warning_valid_train_step_end(tmpdir): From b22adf161f84e446a7b93ae0b3ad72bfdaad80cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 21 Jul 2022 15:55:52 +0200 Subject: [PATCH 08/10] Suggestion --- src/pytorch_lightning/loops/fit_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index ff9a32ed53ae4..43a4ebc6ef88c 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -175,7 +175,7 @@ def done(self) -> bool: met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: self.trainer.should_stop = True - rank_zero_debug(f"`Trainer.fit` stopped: `should_stop` was set.") + rank_zero_debug(f"`Trainer.fit` stopped: `trainer.should_stop` was set.") return True else: rank_zero_info( From ea30f201ed5691d65ca58be93483ccaf5147937f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 21 Jul 2022 15:57:20 +0200 Subject: [PATCH 09/10] Update test --- tests/tests_pytorch/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index c46c0168db558..18a51ff399b67 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -615,7 +615,7 @@ def training_step(self, batch, batch_idx): with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"): trainer.fit(model) - message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue" + message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue" num_messages = sum(1 for record in caplog.records if message in record.message) assert num_messages == min_epochs - 2 assert model.training_step_invoked == min_epochs * 2 From 522355ef4761f9b92d613908d13de53cd87b145c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 21 Jul 2022 17:08:16 +0200 Subject: [PATCH 10/10] pre-commit --- src/pytorch_lightning/loops/fit_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 43a4ebc6ef88c..f4f7735f4b66e 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -151,7 +151,7 @@ def _results(self) -> _ResultCollection: def done(self) -> bool: """Evaluates when to leave the loop.""" if self.trainer.num_training_batches == 0: - rank_zero_info(f"`Trainer.fit` stopped: No training batches.") + rank_zero_info("`Trainer.fit` stopped: No training batches.") return True # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop @@ -175,7 +175,7 @@ def done(self) -> bool: met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: self.trainer.should_stop = True - rank_zero_debug(f"`Trainer.fit` stopped: `trainer.should_stop` was set.") + rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.") return True else: rank_zero_info(