From 70cd0bf032331cc8fccc6a3beffe37cb882667c8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 1 Oct 2020 00:08:08 +0530 Subject: [PATCH 1/5] Fix val_progress_bar total with num_sanity_val_steps --- pytorch_lightning/callbacks/progress.py | 6 +++-- tests/trainer/test_trainer.py | 35 +++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 9bffc9883a932..a3f7a41634420 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -314,6 +314,7 @@ def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self.main_progress_bar.close() self.val_progress_bar.close() + self.val_progress_bar = None def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) @@ -340,8 +341,9 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.total = convert_inf(self.total_val_batches) + if self.val_progress_bar is None: + self.val_progress_bar = self.init_validation_tqdm() + self.val_progress_bar.total = convert_inf(self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d27a701cfae47..9f073fa3610f3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -956,7 +956,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): max_steps=1, ) assert trainer.num_sanity_val_steps == num_sanity_val_steps - val_dataloaders = model.val_dataloader__multiple_mixed_length() @pytest.mark.parametrize(['limit_val_batches'], [ @@ -980,7 +979,39 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): max_steps=1, ) assert trainer.num_sanity_val_steps == float('inf') - val_dataloaders = model.val_dataloader__multiple() + + +@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ + pytest.param(0, 0), + pytest.param(5, 7), +]) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): + """ + Test val_progress_bar total with "num_sanity_val_steps" Trainer argument. + """ + class CustomCallback(Callback): + def __init__(self): + self.val_progress_bar_total = 0 + + def on_validation_epoch_end(self, trainer, pl_module): + self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total + + model = EvalModelTemplate() + cb = CustomCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + num_sanity_val_steps=2, + limit_train_batches=0, + limit_val_batches=limit_val_batches, + callbacks=[cb], + logger=False, + checkpoint_callback=False, + early_stop_callback=False, + ) + trainer.fit(model) + assert cb.val_progress_bar_total == expected @pytest.mark.parametrize("trainer_kwargs,expected", [ From 05b4d935da2d18dce1a1cdac61d7cffa9ec100c2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 1 Oct 2020 00:12:06 +0530 Subject: [PATCH 2/5] chlog --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c76175858f42e..6da61b6098c6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,9 +85,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335)) -- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) +- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) -- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) +- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751)) + +- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764)) From 9648bfd7c137580278c2c0239a9f4bc0c6605691 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 1 Oct 2020 01:21:51 +0530 Subject: [PATCH 3/5] Fix val_progress_bar total with num_sanity_val_steps --- pytorch_lightning/callbacks/progress.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a3f7a41634420..43229047b76cc 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -220,6 +220,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self._refresh_rate = refresh_rate self._process_position = process_position self._enabled = True + self._sanity_check_completed = True self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None @@ -305,6 +306,7 @@ def init_test_tqdm(self) -> tqdm: return bar def on_sanity_check_start(self, trainer, pl_module): + self._sanity_check_completed = False super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches)) @@ -314,7 +316,7 @@ def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self.main_progress_bar.close() self.val_progress_bar.close() - self.val_progress_bar = None + self._sanity_check_completed = True def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) @@ -341,7 +343,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - if self.val_progress_bar is None: + if self._sanity_check_completed is True: self.val_progress_bar = self.init_validation_tqdm() self.val_progress_bar.total = convert_inf(self.total_val_batches) From a735292d5c05b3e1790833cf2bf0194e53c22420 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 2 Oct 2020 02:26:47 +0530 Subject: [PATCH 4/5] move test --- tests/callbacks/test_progress_bar.py | 35 +++++++++++++++++++++++++++- tests/trainer/test_trainer.py | 33 -------------------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 713bdf3c3c2c4..7ce4bec24bc79 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -1,7 +1,7 @@ import pytest from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.callbacks import Callback, ProgressBarBase, ProgressBar, ModelCheckpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -193,3 +193,36 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx trainer.test(model) assert progress_bar.test_batches_seen == progress_bar.total_test_batches + + +@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ + pytest.param(0, 0), + pytest.param(5, 7), +]) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): + """ + Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. + """ + class CustomCallback(Callback): + def __init__(self): + self.val_progress_bar_total = 0 + + def on_validation_epoch_end(self, trainer, pl_module): + self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total + + model = EvalModelTemplate() + cb = CustomCallback() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + num_sanity_val_steps=2, + limit_train_batches=0, + limit_val_batches=limit_val_batches, + callbacks=[cb], + logger=False, + checkpoint_callback=False, + early_stop_callback=False, + ) + trainer.fit(model) + assert cb.val_progress_bar_total == expected diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9f073fa3610f3..cca5a71b6e053 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -981,39 +981,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): assert trainer.num_sanity_val_steps == float('inf') -@pytest.mark.parametrize(['limit_val_batches', 'expected'], [ - pytest.param(0, 0), - pytest.param(5, 7), -]) -def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): - """ - Test val_progress_bar total with "num_sanity_val_steps" Trainer argument. - """ - class CustomCallback(Callback): - def __init__(self): - self.val_progress_bar_total = 0 - - def on_validation_epoch_end(self, trainer, pl_module): - self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total - - model = EvalModelTemplate() - cb = CustomCallback() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - num_sanity_val_steps=2, - limit_train_batches=0, - limit_val_batches=limit_val_batches, - callbacks=[cb], - logger=False, - checkpoint_callback=False, - early_stop_callback=False, - ) - trainer.fit(model) - assert cb.val_progress_bar_total == expected - - @pytest.mark.parametrize("trainer_kwargs,expected", [ pytest.param( dict(distributed_backend=None, gpus=None), From 90d3e8dbb5ce70ebd2aa514fb12c32e543681687 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 2 Oct 2020 16:58:49 +0530 Subject: [PATCH 5/5] replaced with sanity flag and suggestions --- pytorch_lightning/callbacks/progress.py | 5 +---- tests/callbacks/test_progress_bar.py | 11 ++++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 43229047b76cc..3db81fe322faf 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -220,7 +220,6 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self._refresh_rate = refresh_rate self._process_position = process_position self._enabled = True - self._sanity_check_completed = True self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None @@ -306,7 +305,6 @@ def init_test_tqdm(self) -> tqdm: return bar def on_sanity_check_start(self, trainer, pl_module): - self._sanity_check_completed = False super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches)) @@ -316,7 +314,6 @@ def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self.main_progress_bar.close() self.val_progress_bar.close() - self._sanity_check_completed = True def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) @@ -343,7 +340,7 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - if self._sanity_check_completed is True: + if not trainer.running_sanity_check: self.val_progress_bar = self.init_validation_tqdm() self.val_progress_bar.total = convert_inf(self.total_val_batches) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 7ce4bec24bc79..91eecdcf37b19 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -1,7 +1,7 @@ import pytest from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import Callback, ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -203,15 +203,16 @@ def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected): """ Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. """ - class CustomCallback(Callback): + class CurrentProgressBar(ProgressBar): def __init__(self): + super().__init__() self.val_progress_bar_total = 0 def on_validation_epoch_end(self, trainer, pl_module): self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total model = EvalModelTemplate() - cb = CustomCallback() + progress_bar = CurrentProgressBar() trainer = Trainer( default_root_dir=tmpdir, @@ -219,10 +220,10 @@ def on_validation_epoch_end(self, trainer, pl_module): num_sanity_val_steps=2, limit_train_batches=0, limit_val_batches=limit_val_batches, - callbacks=[cb], + callbacks=[progress_bar], logger=False, checkpoint_callback=False, early_stop_callback=False, ) trainer.fit(model) - assert cb.val_progress_bar_total == expected + assert trainer.progress_bar_callback.val_progress_bar_total == expected