Commit e451fa2

Fix main progress bar counter when val_check_interval=int and check_val_every_n_epoch=None (#12832)
1 parent cd20699 commit e451fa2

File tree

11 files changed: +85 -32 lines changed

docs/source-pytorch/common/trainer.rst

Lines changed: 7 additions & 3 deletions

@@ -1479,7 +1479,8 @@ How often within one training epoch to check the validation set.
 Can specify as float or int.

 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
-- pass an ``int`` to check after a fixed number of training batches.
+- pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
+  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training.

 .. testcode::

@@ -1489,10 +1490,13 @@ Can specify as float or int.
     # check validation set 4 times during a training epoch
     trainer = Trainer(val_check_interval=0.25)

-    # check validation set every 1000 training batches
+    # check validation set every 1000 training batches in the current epoch
+    trainer = Trainer(val_check_interval=1000)
+
+    # check validation set every 1000 training batches across complete epochs or during iteration-based training
     # use this when using iterableDataset and your dataset has no length
     # (ie: production cases with streaming data)
-    trainer = Trainer(val_check_interval=1000)
+    trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)


 .. code-block:: python
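For a concrete sense of the distinction the updated docs draw: with the default ``check_val_every_n_epoch=1`` the integer counter resets every epoch, while with ``check_val_every_n_epoch=None`` it runs across epoch boundaries. A minimal sketch of the two schedules, using a hypothetical 800-batch epoch (plain arithmetic, not Lightning internals):

# Hypothetical numbers for illustration: 800 training batches per epoch.
batches_per_epoch = 800
val_check_interval = 1000

# Default (check_val_every_n_epoch=1): the counter resets every epoch, and
# 1000 > 800 could never fire -- which is why Lightning only permits an int
# larger than the epoch length together with check_val_every_n_epoch=None.
in_epoch = [b for b in range(1, batches_per_epoch + 1) if b % val_check_interval == 0]
assert in_epoch == []

# check_val_every_n_epoch=None: the counter runs across epochs, so validation
# triggers at global batches 1000 and 2000 within the first three epochs.
across = [b for b in range(1, 3 * batches_per_epoch + 1) if b % val_check_interval == 0]
assert across == [1000, 2000]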

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 1 deletion

@@ -144,7 +144,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


--
+- Updated `val_check_interval` (int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


 ### Deprecated

@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer.predict(return_predictions=False)` to track prediction's batch_indices ([#13629](https://github.com/Lightning-AI/lightning/pull/13629))


+- Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))
+
+
 ## [1.6.5] - 2022-07-13

 ### Fixed

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 21 additions & 0 deletions

@@ -172,6 +172,27 @@ def total_val_batches(self) -> Union[int, float]:
         assert self._trainer is not None
         return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

+    @property
+    def total_batches_current_epoch(self) -> Union[int, float]:
+        total_train_batches = self.total_train_batches
+        total_val_batches = self.total_val_batches
+        assert self._trainer is not None
+
+        if total_train_batches != float("inf") and total_val_batches != float("inf"):
+            # val can be checked multiple times per epoch
+            val_check_batch = self.trainer.val_check_batch
+            if self.trainer.check_val_every_n_epoch is None:
+                train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
+                val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
+                    train_batches_processed // val_check_batch
+                )
+            else:
+                val_checks_per_epoch = total_train_batches // val_check_batch
+
+            total_val_batches = total_val_batches * val_checks_per_epoch
+
+        return total_train_batches + total_val_batches
+
     def has_dataloader_changed(self, dataloader_idx: int) -> bool:
         old_dataloader_idx = self._current_eval_dataloader_idx
         self._current_eval_dataloader_idx = dataloader_idx
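The interval arithmetic in the new property is worth unpacking: when `check_val_every_n_epoch=None`, the number of validation runs that land inside the upcoming epoch is the count of multiples of `val_check_batch` in the half-open interval `(batches_processed, batches_processed + total_train_batches]`, which the two floor divisions compute. A standalone sketch of just that formula (the function name is local to this example):

def val_checks_in_next_epoch(batches_processed: int, total_train_batches: int, val_check_batch: int) -> int:
    # count multiples of val_check_batch in (batches_processed, batches_processed + total_train_batches]
    return (batches_processed + total_train_batches) // val_check_batch - batches_processed // val_check_batch


# 5 training batches per epoch, validation every 3 global batches:
assert val_checks_in_next_epoch(0, 5, 3) == 1   # epoch 0 covers global batches 1..5   -> check at 3
assert val_checks_in_next_epoch(5, 5, 3) == 2   # epoch 1 covers global batches 6..10  -> checks at 6 and 9
assert val_checks_in_next_epoch(10, 5, 3) == 2  # epoch 2 covers global batches 11..15 -> checks at 12 and 15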

src/pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 3 additions & 9 deletions

@@ -324,16 +324,9 @@ def on_sanity_check_end(self, trainer, pl_module):
         self.refresh()

     def on_train_epoch_start(self, trainer, pl_module):
-        total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
-        if total_train_batches != float("inf"):
-            # val can be checked multiple times per epoch
-            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
-            total_val_batches = total_val_batches * val_checks_per_epoch
-
-        total_batches = total_train_batches + total_val_batches
-
+        total_batches = self.total_batches_current_epoch
         train_description = self._get_train_description(trainer.current_epoch)
+
         if self.main_progress_bar_id is not None and self._leave:
             self._stop_progress()
             self._init_progress(trainer)

@@ -343,6 +336,7 @@ def on_train_epoch_start(self, trainer, pl_module):
         self.progress.reset(
             self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
         )
+
         self.refresh()

     def on_validation_batch_start(

src/pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 1 addition & 7 deletions

@@ -252,13 +252,7 @@ def on_train_start(self, *_: Any) -> None:
         self.main_progress_bar = self.init_train_tqdm()

     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
-        total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
-        if total_train_batches != float("inf") and total_val_batches != float("inf"):
-            # val can be checked multiple times per epoch
-            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
-            total_val_batches = total_val_batches * val_checks_per_epoch
-        total_batches = total_train_batches + total_val_batches
+        total_batches = self.total_batches_current_epoch
         self.main_progress_bar.reset(convert_inf(total_batches))
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
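With both built-in bars now delegating to the shared base-class property, the per-epoch total is computed in one place, and any custom progress bar can reuse it the same way. A minimal sketch of such a callback (the subclass and its behaviour are hypothetical, for illustration only):

from pytorch_lightning.callbacks import ProgressBarBase


class PrintingProgressBar(ProgressBarBase):
    """Hypothetical minimal bar: prints the per-epoch total instead of drawing a bar."""

    def __init__(self) -> None:
        super().__init__()
        self._enabled = True

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        if self._enabled:
            # the same total that the tqdm/rich bars now display, val checks included
            print(f"Epoch {trainer.current_epoch}: {self.total_batches_current_epoch} batches")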

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 6 additions & 7 deletions

@@ -163,7 +163,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[ov
         Raises:
             StopIteration: When the epoch is canceled by the user returning -1
         """
-        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
+        if self.restarting and self._should_check_val_fx():
             # skip training and run validation in `on_advance_end`
             return
         # we are going to train first so the val loop does not need to restart

@@ -235,7 +235,7 @@ def on_advance_end(self) -> None:
         # -----------------------------------------
         # VALIDATE IF NEEDED
         # -----------------------------------------
-        should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
+        should_check_val = self._should_check_val_fx()
         if should_check_val:
             self.trainer.validating = True
             self._run_validation()

@@ -496,13 +496,14 @@ def _should_check_val_epoch(self) -> bool:
             or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         )

-    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
+    def _should_check_val_fx(self) -> bool:
         """Decide if we should run validation."""
         if not self._should_check_val_epoch():
             return False

         # val_check_batch is inf for iterable datasets with no length defined
         is_infinite_dataset = self.trainer.val_check_batch == float("inf")
+        is_last_batch = self.batch_progress.is_last_batch
         if is_last_batch and is_infinite_dataset:
             return True

@@ -512,13 +513,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
-            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+            is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
             # if `check_val_every_n_epoch` is `None`, run a validation loop every n training batches
             # else condition it based on the batch_idx of the current epoch
-            current_iteration = (
-                self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
-            )
+            current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
             is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

         return is_val_check_batch
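The substantive change here is swapping `_batches_that_stepped` for `total_batch_idx`, so the check is driven by the total train batches processed (as the changelog entry puts it) rather than by a stepped-batch counter that can diverge from it, keeping the validation schedule in line with the totals the progress bar computes. A condensed sketch of the decision in the common case (no infinite dataset, no `limit_train_batches` special-casing; not the full method):

def should_check_val(total_batch_idx: int, batch_idx: int, val_check_batch: int, check_val_every_n_epoch) -> bool:
    # across-epoch counter when check_val_every_n_epoch is None, per-epoch counter otherwise
    current_iteration = total_batch_idx if check_val_every_n_epoch is None else batch_idx
    return (current_iteration + 1) % val_check_batch == 0


# val_check_batch=3, check_val_every_n_epoch=None: validate after global batches 3, 6, 9, ...
assert [t for t in range(10) if should_check_val(t, t % 5, 3, None)] == [2, 5, 8]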

src/pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion

@@ -394,7 +394,8 @@ def __init__(
             val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
                 after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
                 batches. An ``int`` value can only be higher than the number of training batches when
-                ``check_val_every_n_epoch=None``.
+                ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+                across epochs or during iteration-based training.
                 Default: ``1.0``.

             enable_model_summary: Whether to enable model summarization by default.
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.trainer.trainer import Trainer
+
+
+def test_main_progress_bar_with_val_check_interval_int():
+    """Test the main progress bar count when val_check_interval=int and check_val_every_n_epoch=None."""
+    train_batches = 5
+    trainer = Trainer(
+        limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
+    )
+    model = BoringModel()
+    trainer.progress_bar_callback.setup(trainer, model)
+    trainer.strategy.connect(model)
+    trainer._data_connector.attach_data(model)
+    trainer.reset_train_dataloader()
+    trainer.reset_val_dataloader()
+    expected = [15, 25, 25, 15]
+
+    for count in expected:
+        assert trainer.progress_bar_callback.total_batches_current_epoch == count
+        trainer.fit_loop.epoch_loop.batch_progress.total.ready += train_batches
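The `expected` totals follow directly from the interval arithmetic in `total_batches_current_epoch`: each epoch contributes its 5 training batches plus 10 validation batches for every check that falls inside it. Checking the numbers by hand (plain arithmetic mirroring the property's formula):

train, val, interval = 5, 10, 3
totals = []
for epoch in range(4):
    processed = epoch * train
    # validation checks landing in this epoch: multiples of `interval`
    # in (processed, processed + train]
    checks = (processed + train) // interval - processed // interval
    totals.append(train + checks * val)
assert totals == [15, 25, 25, 15]  # matches the test's expected list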
