Skip to content

Commit 5c9dbc3

Browse files
sailordiaryrohitgr7carmocca
authored andcommitted
Fix validation progress counter with check_val_every_n_epoch > 1 (Lightning-AI#5952)
Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 8c01064 commit 5c9dbc3

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

pytorch_lightning/callbacks/progress.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,10 @@ def total_val_batches(self) -> int:
146146
validation dataloader is of infinite size.
147147
"""
148148
total_val_batches = 0
149-
if not self.trainer.disable_validation:
150-
is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0
149+
if self.trainer.enable_validation:
150+
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
151151
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
152+
152153
return total_val_batches
153154

154155
@property
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
from pytorch_lightning.trainer import Trainer
17+
from pytorch_lightning.trainer.states import TrainerState
18+
from tests.helpers import BoringModel
19+
20+
21+
@pytest.mark.parametrize(
22+
'max_epochs,expected_val_loop_calls,expected_val_batches', [
23+
(1, 0, [0]),
24+
(4, 2, [0, 2, 0, 2]),
25+
(5, 2, [0, 2, 0, 2, 0]),
26+
]
27+
)
28+
def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls, expected_val_batches):
29+
30+
class TestModel(BoringModel):
31+
val_epoch_calls = 0
32+
val_batches = []
33+
34+
def on_train_epoch_end(self, *args, **kwargs):
35+
self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches)
36+
37+
def on_validation_epoch_start(self) -> None:
38+
self.val_epoch_calls += 1
39+
40+
model = TestModel()
41+
trainer = Trainer(
42+
default_root_dir=tmpdir,
43+
max_epochs=max_epochs,
44+
num_sanity_val_steps=0,
45+
limit_val_batches=2,
46+
check_val_every_n_epoch=2,
47+
logger=False,
48+
)
49+
trainer.fit(model)
50+
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
51+
52+
assert model.val_epoch_calls == expected_val_loop_calls
53+
assert model.val_batches == expected_val_batches

0 commit comments

Comments
 (0)