diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 159cb97097120..ca8d24d9f5d21 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a pickling error when using `RichProgressBar` together with checkpointing ([#15319](https://github.com/Lightning-AI/lightning/pull/15319)) +- Fixed an issue with `RichProgressBar` not resetting the internal state for the sanity check progress ([#15377](https://github.com/Lightning-AI/lightning/pull/15377)) + + ## [1.8.0] - 2022-MM-DD diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 30f616035b59f..b866874ed9cb4 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -512,6 +512,7 @@ def _stop_progress(self) -> None: def _reset_progress_bar_ids(self) -> None: self.main_progress_bar_id = None + self.val_sanity_progress_bar_id = None self.val_progress_bar_id = None self.test_progress_bar_id = None self.predict_progress_bar_id = None diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 57be12c503811..b9bf650e0523c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -441,3 +441,37 @@ def test_rich_progress_bar_can_be_pickled(): pickle.dumps(bar) trainer.predict(model) pickle.dumps(bar) + + +@RunIf(rich=True) +def test_rich_progress_bar_reset_bars(): + """Test that the progress bar resets all internal bars when a new trainer stage begins.""" + bar = RichProgressBar() + assert bar.is_enabled + assert bar.progress is None + assert bar._progress_stopped is False + + def _set_fake_bar_ids(): + bar.main_progress_bar_id = 0 + bar.val_sanity_progress_bar_id = 1 + bar.val_progress_bar_id = 2 + bar.test_progress_bar_id = 3 + bar.predict_progress_bar_id = 4 + + for stage in ("train", "sanity_check", "validation", "test", "predict"): + hook_name = f"on_{stage}_start" + hook = getattr(bar, hook_name) + + _set_fake_bar_ids() # pretend that bars are initialized from a previous run + hook(Mock(), Mock()) + bar.teardown(Mock(), Mock(), Mock()) + + # assert all bars are reset + assert bar.main_progress_bar_id is None + assert bar.val_sanity_progress_bar_id is None + assert bar.val_progress_bar_id is None + assert bar.test_progress_bar_id is None + assert bar.predict_progress_bar_id is None + + # the progress object remains in case we need it for the next stage + assert bar.progress is not None