diff --git a/CHANGELOG.md b/CHANGELOG.md index c76175858f42e..6da61b6098c6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,9 +85,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335)) -- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) +- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) -- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) +- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751)) + +- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764)) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 9bffc9883a932..3db81fe322faf 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -340,8 +340,9 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.total = convert_inf(self.total_val_batches) + if not trainer.running_sanity_check: + self.val_progress_bar = self.init_validation_tqdm() + self.val_progress_bar.total = convert_inf(self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 713bdf3c3c2c4..91eecdcf37b19 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -193,3 +193,37 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx trainer.test(model) assert progress_bar.test_batches_seen == progress_bar.total_test_batches + + +@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ + pytest.param(0, 0), + pytest.param(5, 7), +]) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): + """ + Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. + """ + class CurrentProgressBar(ProgressBar): + def __init__(self): + super().__init__() + self.val_progress_bar_total = 0 + + def on_validation_epoch_end(self, trainer, pl_module): + self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total + + model = EvalModelTemplate() + progress_bar = CurrentProgressBar() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + num_sanity_val_steps=2, + limit_train_batches=0, + limit_val_batches=limit_val_batches, + callbacks=[progress_bar], + logger=False, + checkpoint_callback=False, + early_stop_callback=False, + ) + trainer.fit(model) + assert trainer.progress_bar_callback.val_progress_bar_total == expected diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d27a701cfae47..cca5a71b6e053 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -956,7 +956,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): max_steps=1, ) assert trainer.num_sanity_val_steps == num_sanity_val_steps - val_dataloaders = model.val_dataloader__multiple_mixed_length() @pytest.mark.parametrize(['limit_val_batches'], [ @@ -980,7 +979,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): max_steps=1, ) assert trainer.num_sanity_val_steps == float('inf') - val_dataloaders = model.val_dataloader__multiple() @pytest.mark.parametrize("trainer_kwargs,expected", [