Skip to content

Commit a628d18

Browse files
authored
Fix val_progress_bar total with num_sanity_val_steps (#3751)
* Fix val_progress_bar total with num_sanity_val_steps * chlog * Fix val_progress_bar total with num_sanity_val_steps * move test * replaced with sanity flag and suggestions
1 parent 4da240e commit a628d18

File tree

4 files changed

+41
-6
lines changed

4 files changed

+41
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -93,9 +93,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9393

9494
- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
9595

96-
- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
96+
- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
9797

98-
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
98+
- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751))
99+
100+
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
99101

100102
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
101103

pytorch_lightning/callbacks/progress.py

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -340,8 +340,9 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
340340

341341
def on_validation_start(self, trainer, pl_module):
342342
super().on_validation_start(trainer, pl_module)
343-
self.val_progress_bar = self.init_validation_tqdm()
344-
self.val_progress_bar.total = convert_inf(self.total_val_batches)
343+
if not trainer.running_sanity_check:
344+
self.val_progress_bar = self.init_validation_tqdm()
345+
self.val_progress_bar.total = convert_inf(self.total_val_batches)
345346

346347
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
347348
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)

tests/callbacks/test_progress_bar.py

Lines changed: 34 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -193,3 +193,37 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx
193193

194194
trainer.test(model)
195195
assert progress_bar.test_batches_seen == progress_bar.total_test_batches
196+
197+
198+
@pytest.mark.parametrize(['limit_val_batches', 'expected'], [
199+
pytest.param(0, 0),
200+
pytest.param(5, 7),
201+
])
202+
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected):
203+
"""
204+
Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
205+
"""
206+
class CurrentProgressBar(ProgressBar):
207+
def __init__(self):
208+
super().__init__()
209+
self.val_progress_bar_total = 0
210+
211+
def on_validation_epoch_end(self, trainer, pl_module):
212+
self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total
213+
214+
model = EvalModelTemplate()
215+
progress_bar = CurrentProgressBar()
216+
217+
trainer = Trainer(
218+
default_root_dir=tmpdir,
219+
max_epochs=1,
220+
num_sanity_val_steps=2,
221+
limit_train_batches=0,
222+
limit_val_batches=limit_val_batches,
223+
callbacks=[progress_bar],
224+
logger=False,
225+
checkpoint_callback=False,
226+
early_stop_callback=False,
227+
)
228+
trainer.fit(model)
229+
assert trainer.progress_bar_callback.val_progress_bar_total == expected

tests/trainer/test_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -957,7 +957,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
957957
max_steps=1,
958958
)
959959
assert trainer.num_sanity_val_steps == num_sanity_val_steps
960-
val_dataloaders = model.val_dataloader__multiple_mixed_length()
961960

962961

963962
@pytest.mark.parametrize(['limit_val_batches'], [
@@ -981,7 +980,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
981980
max_steps=1,
982981
)
983982
assert trainer.num_sanity_val_steps == float('inf')
984-
val_dataloaders = model.val_dataloader__multiple()
985983

986984

987985
@pytest.mark.parametrize("trainer_kwargs,expected", [

0 commit comments

Comments (0)