From 4800d462e26ae57b51093a322e4cf338674026d3 Mon Sep 17 00:00:00 2001 From: Yuan-Hang Zhang Date: Sat, 13 Feb 2021 11:07:27 +0800 Subject: [PATCH 1/4] Fix validation counter #5039 --- pytorch_lightning/callbacks/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 74e57e2b5642e..024b85e282dda 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -150,7 +150,7 @@ def total_val_batches(self) -> int: """ total_val_batches = 0 if not self.trainer.disable_validation: - is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0 + is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 return total_val_batches From d4a6e83f3e546d0ad35a566ab59f23bd1df99fb2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 13 Mar 2021 01:03:58 +0530 Subject: [PATCH 2/4] add test --- .pre-commit-config.yaml | 4 +- pytorch_lightning/callbacks/progress.py | 3 +- .../flags/test_check_val_every_n_epoch.py | 56 +++++++++++++++++++ 3 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 tests/trainer/flags/test_check_val_every_n_epoch.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21c52539a890d..ef5fdc8101045 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v3.4.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -35,6 +35,6 @@ repos: args: [--parallel, --in-place] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 + rev: v0.812 hooks: - id: mypy diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 024b85e282dda..489ef081bcc7c 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -149,9 +149,10 @@ def total_val_batches(self) -> int: validation dataloader is of infinite size. """ total_val_batches = 0 - if not self.trainer.disable_validation: + if self.trainer.enable_validation: is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 + return total_val_batches @property diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py new file mode 100644 index 0000000000000..e5d18e31c8837 --- /dev/null +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel + + +@pytest.mark.parametrize( + 'max_epochs,expected_val_loop_calls,expected_val_batches', [ + (1, 0, [0]), + (4, 2, [0, 2, 0, 2]), + (5, 2, [0, 2, 0, 2, 0]), + ] +) +def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls, expected_val_batches): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.val_epoch_calls = 0 + self.val_batches = [] + + def on_train_epoch_end(self, *args, **kwargs): + self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches) + + def on_validation_epoch_start(self) -> None: + self.val_epoch_calls += 1 + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=max_epochs, + num_sanity_val_steps=0, + limit_val_batches=2, + check_val_every_n_epoch=2, + logger=False, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert model.val_epoch_calls == expected_val_loop_calls + assert model.val_batches == expected_val_batches From 13936518e08430e9b56d7997106572a5cf8009c5 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Sat, 13 Mar 2021 01:10:19 +0530 Subject: [PATCH 3/4] Apply suggestions from code review --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ef5fdc8101045..21c52539a890d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.4.0 + rev: v2.3.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -35,6 +35,6 @@ repos: args: [--parallel, --in-place] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.812 + rev: v0.790 hooks: - id: mypy From e1b09c58a49d292035bb53d1c08e0a1bf8d3d52d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 14 Mar 2021 02:16:52 +0100 Subject: [PATCH 4/4] Update tests/trainer/flags/test_check_val_every_n_epoch.py --- tests/trainer/flags/test_check_val_every_n_epoch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py index e5d18e31c8837..f7f1403ecdbfd 100644 --- a/tests/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -28,11 +28,8 @@ def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls, expected_val_batches): class TestModel(BoringModel): - - def __init__(self): - super().__init__() - self.val_epoch_calls = 0 - self.val_batches = [] + val_epoch_calls = 0 + val_batches = [] def on_train_epoch_end(self, *args, **kwargs): self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches)