From 8618d7478cbc4a7373b30d39b74830c87eb401a2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 02:24:10 +0200 Subject: [PATCH 1/7] rich as the default bar --- .../trainer/connectors/callback_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index ba1b5b8bf7be3..8b5dfac3f774f 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -33,7 +33,8 @@ from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0, \ + _RICH_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_info _log = logging.getLogger(__name__) @@ -174,7 +175,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: ) if enable_progress_bar: - progress_bar_callback = TQDMProgressBar() + progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: From f1a8dc37dff26706abd911cb0eb525fe3a91605b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 03:39:03 +0200 Subject: [PATCH 2/7] fix assertion errors --- .../callbacks/progress/rich_progress.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 
5677f95a1f9d6..df5df211a9dee 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -354,7 +354,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx) or self.progress is None: + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return if trainer.sanity_checking: @@ -396,7 +396,7 @@ def _should_update(self, current: int, total: Union[int, float]) -> bool: return current % self.refresh_rate == 0 or current == total def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.val_progress_bar_id is not None and trainer.state.fn == "fit": + if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit": assert self.progress is not None self.progress.update(self.val_progress_bar_id, advance=0, visible=False) self.refresh() @@ -415,7 +415,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_test_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx): + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return if self.test_progress_bar_id is not None: @@ -427,7 +427,7 @@ def on_test_batch_start( def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - if not self.has_dataloader_changed(dataloader_idx): + if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return if self.predict_progress_bar_id is not None: @@ -457,8 +457,9 @@ def on_validation_batch_end( 
batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return if trainer.sanity_checking: - assert self.val_sanity_progress_bar_id is not None self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar @@ -476,6 +477,8 @@ def on_test_batch_end( batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return assert self.test_progress_bar_id is not None self._update(self.test_progress_bar_id, self.test_batch_idx) self.refresh() @@ -489,6 +492,8 @@ def on_predict_batch_end( batch_idx: int, dataloader_idx: int, ) -> None: + if self.is_disabled: + return assert self.predict_progress_bar_id is not None self._update(self.predict_progress_bar_id, self.predict_batch_idx) self.refresh() From 8f1ce5cf8bb8fb9db4434c3a6153efcbf9e2d4d7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 03:39:14 +0200 Subject: [PATCH 3/7] fixes --- .../trainer/connectors/test_callback_connector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index b7ecab6998658..558d37d757066 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -26,12 +26,13 @@ ModelCheckpoint, ModelSummary, ProgressBarBase, - TQDMProgressBar, + TQDMProgressBar, RichProgressBar, ) from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 +from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0, \ 
+ _RICH_AVAILABLE def test_checkpoint_callbacks_are_last(tmpdir): @@ -180,7 +181,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): return trainer early_stopping = EarlyStopping(monitor="foo") - progress_bar = TQDMProgressBar() + progress_bar = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar() lr_monitor = LearningRateMonitor() grad_accumulation = GradientAccumulationScheduler({1: 1}) From ece84d07a90fbb1ed77b96edb4904100ca5f4e03 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 03:42:58 +0200 Subject: [PATCH 4/7] test fixes --- .../trainer/connectors/callback_connector.py | 7 ++++-- .../progress/test_tqdm_progress_bar.py | 23 ++++++++++++------- .../connectors/test_callback_connector.py | 10 +++++--- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 8b5dfac3f774f..709492fb77a49 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -33,8 +33,11 @@ from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0, \ - _RICH_AVAILABLE +from pytorch_lightning.utilities.imports import ( + _PYTHON_GREATER_EQUAL_3_8_0, + _PYTHON_GREATER_EQUAL_3_10_0, + _RICH_AVAILABLE, +) from pytorch_lightning.utilities.rank_zero import rank_zero_info _log = logging.getLogger(__name__) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d71d28d4e2712..a231e1f863adb 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++
b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -141,7 +141,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): # check the sanity dataloaders num_sanity_val_steps = 4 trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=0, + num_sanity_val_steps=num_sanity_val_steps, + callbacks=TQDMProgressBar(), ) pbar = trainer.progress_bar_callback with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): @@ -155,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] # fit - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=TQDMProgressBar()) pbar = trainer.progress_bar_callback with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.fit(model) @@ -206,7 +210,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def test_tqdm_progress_bar_fast_dev_run(tmpdir): model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TQDMProgressBar()) trainer.fit(model) @@ -326,16 +330,13 @@ def on_validation_epoch_end(self, *args): def test_tqdm_progress_bar_default_value(tmpdir): """Test that a value of None defaults to refresh rate 1.""" - trainer = Trainer(default_root_dir=tmpdir) + trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar()) assert trainer.progress_bar_callback.refresh_rate == 1 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) def test_tqdm_progress_bar_value_on_colab(tmpdir): """Test that Trainer will override the default in Google COLAB.""" - trainer = Trainer(default_root_dir=tmpdir) - assert 
trainer.progress_bar_callback.refresh_rate == 20 - trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar()) assert trainer.progress_bar_callback.refresh_rate == 20 @@ -411,7 +412,12 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + logger=False, + enable_checkpointing=False, + callbacks=TQDMProgressBar(), ) trainer.fit(TestModel()) @@ -614,6 +620,7 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval): strategy="ddp", enable_progress_bar=True, enable_model_summary=False, + callbacks=TQDMProgressBar(), ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 558d37d757066..0b44939dec2de 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -26,13 +26,17 @@ ModelCheckpoint, ModelSummary, ProgressBarBase, - TQDMProgressBar, RichProgressBar, + RichProgressBar, + TQDMProgressBar, ) from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector -from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0, \ - _RICH_AVAILABLE +from pytorch_lightning.utilities.imports import ( + _PYTHON_GREATER_EQUAL_3_8_0, + _PYTHON_GREATER_EQUAL_3_10_0, + _RICH_AVAILABLE, +) def test_checkpoint_callbacks_are_last(tmpdir): From 825266b38fec46a73ec5385c3fc3dcee9f43acf8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 04:08:42 +0200 
Subject: [PATCH 5/7] mypy --- src/pytorch_lightning/callbacks/progress/rich_progress.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index df5df211a9dee..a1f12453adcbe 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -357,6 +357,8 @@ def on_validation_batch_start( if self.is_disabled or not self.has_dataloader_changed(dataloader_idx): return + assert self.progress is not None + if trainer.sanity_checking: if self.val_sanity_progress_bar_id is not None: self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) From 8e5ca37d7ef5885bc9362e054225045743ebc271 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 28 Oct 2022 00:54:41 +0200 Subject: [PATCH 6/7] reset the sanity bar --- src/pytorch_lightning/callbacks/progress/rich_progress.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index e91726d4a9141..aa0a1464b5204 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -520,6 +520,7 @@ def _stop_progress(self) -> None: def _reset_progress_bar_ids(self) -> None: self.main_progress_bar_id = None self.val_progress_bar_id = None + self.val_sanity_progress_bar_id = None self.test_progress_bar_id = None self.predict_progress_bar_id = None From 3bdf5564efbeae27e2090bbe95a9e7e041b0a75a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 29 Oct 2022 02:54:10 +0200 Subject: [PATCH 7/7] reset --- src/pytorch_lightning/callbacks/progress/rich_progress.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index 
100b2b9dbc92f..1704a8f43effa 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -547,7 +547,6 @@ def _reset_progress_bar_ids(self) -> None: self.main_progress_bar_id = None self.val_sanity_progress_bar_id = None self.val_progress_bar_id = None - self.val_sanity_progress_bar_id = None self.test_progress_bar_id = None self.predict_progress_bar_id = None