Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a pickling error when using `RichProgressBar` together with checkpointing ([#15319](https://github.com/Lightning-AI/lightning/pull/15319))


- Fixed an issue with `RichProgressBar` not resetting the internal state for the sanity check progress ([#15377](https://github.com/Lightning-AI/lightning/pull/15377))


## [1.8.0] - 2022-MM-DD


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ def _stop_progress(self) -> None:

def _reset_progress_bar_ids(self) -> None:
self.main_progress_bar_id = None
self.val_sanity_progress_bar_id = None
self.val_progress_bar_id = None
self.test_progress_bar_id = None
self.predict_progress_bar_id = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,37 @@ def test_rich_progress_bar_can_be_pickled():
pickle.dumps(bar)
trainer.predict(model)
pickle.dumps(bar)


@RunIf(rich=True)
def test_rich_progress_bar_reset_bars():
"""Test that the progress bar resets all internal bars when a new trainer stage begins."""
bar = RichProgressBar()
assert bar.is_enabled
assert bar.progress is None
assert bar._progress_stopped is False

def _set_fake_bar_ids():
bar.main_progress_bar_id = 0
bar.val_sanity_progress_bar_id = 1
bar.val_progress_bar_id = 2
bar.test_progress_bar_id = 3
bar.predict_progress_bar_id = 4

for stage in ("train", "sanity_check", "validation", "test", "predict"):
hook_name = f"on_{stage}_start"
hook = getattr(bar, hook_name)

_set_fake_bar_ids() # pretend that bars are initialized from a previous run
hook(Mock(), Mock())
bar.teardown(Mock(), Mock(), Mock())

# assert all bars are reset
assert bar.main_progress_bar_id is None
assert bar.val_sanity_progress_bar_id is None
assert bar.val_progress_bar_id is None
assert bar.test_progress_bar_id is None
assert bar.predict_progress_bar_id is None

# the progress object remains in case we need it for the next stage
assert bar.progress is not None