From ed1df60e18ab4b7c6fbfc8cfd1dadef61b80bc4b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 01:55:30 +0200 Subject: [PATCH 1/7] test --- .../progress/test_rich_progress_bar.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 57be12c503811..f13b03cfa9054 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -419,6 +419,7 @@ def test_rich_progress_bar_padding(): assert len(progress_bar.validation_description) == len(train_description) +@RunIf(rich=True) def test_rich_progress_bar_can_be_pickled(): bar = RichProgressBar() trainer = Trainer( @@ -441,3 +442,38 @@ def test_rich_progress_bar_can_be_pickled(): pickle.dumps(bar) trainer.predict(model) pickle.dumps(bar) + + +@RunIf(rich=True) +def test_rich_progress_bar_disabled(tmpdir): + bar = RichProgressBar() + bar.disable() + assert bar.is_disabled + + # bar.progress = Mock() + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + limit_predict_batches=2, + max_epochs=1, + enable_model_summary=False, + enable_checkpointing=False, + callbacks=[bar], + ) + + with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.CustomProgress") as mocked: + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + + mocked.assert_not_called() + assert bar.main_progress_bar_id is None + assert bar.val_sanity_progress_bar_id is None + assert bar.val_progress_bar_id is None + assert bar.test_progress_bar_id is None + assert bar.predict_progress_bar_id is None From 18aeb283e33de344782aa1676edc244136d92244 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 01:56:17 +0200 Subject: [PATCH 2/7] fix disabled --- .../callbacks/progress/rich_progress.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 30f616035b59f..f79fa9abcc9d4 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -255,9 +255,9 @@ def __init__( self._console: Optional[Console] = None self._console_kwargs = console_kwargs or {} self._enabled: bool = True - self.progress: Optional[Progress] = None - self.val_sanity_progress_bar_id: Optional["TaskID"] = None + self.progress: Optional[CustomProgress] = None self.main_progress_bar_id: Optional["TaskID"] + self.val_sanity_progress_bar_id: Optional["TaskID"] = None self.val_progress_bar_id: Optional["TaskID"] self.test_progress_bar_id: Optional["TaskID"] self.predict_progress_bar_id: Optional["TaskID"] @@ -336,6 +336,8 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod self.refresh() def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.is_disabled: + return total_batches = self.total_batches_current_epoch train_description = self._get_train_description(trainer.current_epoch) @@ -355,9 +357,11 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx) or self.progress is None: + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return + assert self.progress is not None + if trainer.sanity_checking: if self.val_sanity_progress_bar_id is not None: self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) @@ -397,7 +401,7 @@ def _should_update(self, current: int, total: Union[int, float]) -> bool: return current % self.refresh_rate == 0 or current == total def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.val_progress_bar_id is not None and trainer.state.fn == "fit": + if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit": assert self.progress is not None self.progress.update(self.val_progress_bar_id, advance=0, visible=False) self.refresh() @@ -416,7 +420,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx): + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return if self.test_progress_bar_id is not None: @@ -428,7 +432,7 @@ def on_test_batch_start( def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx): + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return if self.predict_progress_bar_id is not None: @@ -458,8 +462,9 @@ def on_validation_batch_end( batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return if trainer.sanity_checking: - assert self.val_sanity_progress_bar_id is not None self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar @@ -477,6 +482,8 @@ def on_test_batch_end( batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return assert self.test_progress_bar_id is not None self._update(self.test_progress_bar_id, self.test_batch_idx) self.refresh() @@ -490,6 +497,8 @@ def on_predict_batch_end( batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return assert self.predict_progress_bar_id is not None self._update(self.predict_progress_bar_id, self.predict_batch_idx) self.refresh() From 07c7eb77f060684d33b4a72db437c08ee7ad476f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 01:57:02 +0200 Subject: [PATCH 3/7] move properties to top --- .../callbacks/progress/rich_progress.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index f79fa9abcc9d4..3e5cc81a46882 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -278,6 +278,30 @@ def is_enabled(self) -> bool: @property def is_disabled(self) -> bool: return not self.is_enabled + + @property + def val_progress_bar(self) -> Task: + assert self.progress is not None + assert self.val_progress_bar_id is not None + return self.progress.tasks[self.val_progress_bar_id] + + @property + def val_sanity_check_bar(self) -> Task: + assert self.progress is not None + assert self.val_sanity_progress_bar_id is not None + return self.progress.tasks[self.val_sanity_progress_bar_id] + + @property + def main_progress_bar(self) -> Task: + assert self.progress is not None + assert self.main_progress_bar_id is not None + return self.progress.tasks[self.main_progress_bar_id] + + @property + def test_progress_bar(self) -> Task: + assert self.progress is not None + assert self.test_progress_bar_id is not None + return self.progress.tasks[self.test_progress_bar_id] def _update_for_light_colab_theme(self) -> None: if _detect_light_colab_theme(): @@ -536,30 +560,6 @@ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: self._stop_progress() - @property - def val_progress_bar(self) -> Task: - assert self.progress is not None - assert self.val_progress_bar_id is not None - return self.progress.tasks[self.val_progress_bar_id] - - @property - def val_sanity_check_bar(self) -> Task: - assert self.progress is not None - assert self.val_sanity_progress_bar_id is not None - return self.progress.tasks[self.val_sanity_progress_bar_id] - - @property - def main_progress_bar(self) -> Task: - assert self.progress is not None - assert self.main_progress_bar_id is not None - return self.progress.tasks[self.main_progress_bar_id] - - @property - def test_progress_bar(self) -> Task: - assert self.progress is not None - assert self.test_progress_bar_id is not None - return self.progress.tasks[self.test_progress_bar_id] - def configure_columns(self, trainer: "pl.Trainer") -> list: return [ TextColumn("[progress.description]{task.description}"), From a6b437dc824e69ae4a7dd036f08505525ec54b50 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 02:04:14 +0200 Subject: [PATCH 4/7] test --- src/pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- .../tests_pytorch/callbacks/progress/test_rich_progress_bar.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 3e5cc81a46882..54d412f69ee03 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -278,7 +278,7 @@ def is_enabled(self) -> bool: @property def is_disabled(self) -> bool: return not self.is_enabled - + @property def val_progress_bar(self) -> Task: assert self.progress is not None diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index f13b03cfa9054..ce5b36f72a755 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -446,12 +446,11 @@ def test_rich_progress_bar_can_be_pickled(): @RunIf(rich=True) def test_rich_progress_bar_disabled(tmpdir): + """Test that in a disabled bar there are no updates and no internal progress objects.""" bar = RichProgressBar() bar.disable() assert bar.is_disabled - # bar.progress = Mock() - model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, From cfc2dc81aa46d6aa16436aa9bf8f612f178c493f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 00:04:59 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 3e5cc81a46882..54d412f69ee03 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -278,7 +278,7 @@ def is_enabled(self) -> bool: @property def is_disabled(self) -> bool: return not self.is_enabled - + @property def val_progress_bar(self) -> Task: assert self.progress is not None From da5d6c12d9650a796d593117fd6e974f4475a9fb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 02:05:50 +0200 Subject: [PATCH 6/7] changelog --- src/pytorch_lightning/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 159cb97097120..575d09188edd3 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed a pickling error when using `RichProgressBar` together with checkpointing ([#15319](https://github.com/Lightning-AI/lightning/pull/15319)) +- Fixed the `RichProgressBar` crashing when used with distributed strategies ([#15376](https://github.com/Lightning-AI/lightning/pull/15376)) ## [1.8.0] - 2022-MM-DD From ae3e28d5a07b4d44931e0fcddba7da23adcead93 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 12:29:28 +0200 Subject: [PATCH 7/7] order --- .../callbacks/progress/rich_progress.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index d21dadd4586fd..1704a8f43effa 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -280,10 +280,10 @@ def is_disabled(self) -> bool: return not self.is_enabled @property - def val_progress_bar(self) -> Task: + def main_progress_bar(self) -> Task: assert self.progress is not None - assert self.val_progress_bar_id is not None - return self.progress.tasks[self.val_progress_bar_id] + assert self.main_progress_bar_id is not None + return self.progress.tasks[self.main_progress_bar_id] @property def val_sanity_check_bar(self) -> Task: @@ -292,10 +292,10 @@ def val_sanity_check_bar(self) -> Task: return self.progress.tasks[self.val_sanity_progress_bar_id] @property - def main_progress_bar(self) -> Task: + def val_progress_bar(self) -> Task: assert self.progress is not None - assert self.main_progress_bar_id is not None - return self.progress.tasks[self.main_progress_bar_id] + assert self.val_progress_bar_id is not None + return self.progress.tasks[self.val_progress_bar_id] @property def test_progress_bar(self) -> Task: