diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index ba1b5b8bf7be3..709492fb77a49 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -33,7 +33,11 @@
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from pytorch_lightning.utilities.imports import (
+    _PYTHON_GREATER_EQUAL_3_8_0,
+    _PYTHON_GREATER_EQUAL_3_10_0,
+    _RICH_AVAILABLE,
+)
 from pytorch_lightning.utilities.rank_zero import rank_zero_info

 _log = logging.getLogger(__name__)
@@ -174,7 +178,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
                 )

         if enable_progress_bar:
-            progress_bar_callback = TQDMProgressBar()
+            progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
             self.trainer.callbacks.append(progress_bar_callback)

     def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index 105f380be58f5..cff217e282fb9 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -141,7 +141,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     # check the sanity dataloaders
     num_sanity_val_steps = 4
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0,
+        num_sanity_val_steps=num_sanity_val_steps,
+        callbacks=TQDMProgressBar(),
     )
     pbar = trainer.progress_bar_callback
     with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
@@ -155,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]

     # fit
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=TQDMProgressBar())
     pbar = trainer.progress_bar_callback
     with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
         trainer.fit(model)
@@ -206,7 +210,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):

 def test_tqdm_progress_bar_fast_dev_run(tmpdir):
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TQDMProgressBar())
     trainer.fit(model)
@@ -326,16 +330,13 @@ def on_validation_epoch_end(self, *args):

 def test_tqdm_progress_bar_default_value(tmpdir):
     """Test that a value of None defaults to refresh rate 1."""
-    trainer = Trainer(default_root_dir=tmpdir)
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
     assert trainer.progress_bar_callback.refresh_rate == 1


 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
 def test_tqdm_progress_bar_value_on_colab(tmpdir):
     """Test that Trainer will override the default in Google COLAB."""
-    trainer = Trainer(default_root_dir=tmpdir)
-    assert trainer.progress_bar_callback.refresh_rate == 20
-
     trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
     assert trainer.progress_bar_callback.refresh_rate == 20
@@ -411,7 +412,12 @@ def training_step(self, batch, batch_idx):
             return super().training_step(batch, batch_idx)

     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        logger=False,
+        enable_checkpointing=False,
+        callbacks=TQDMProgressBar(),
     )
     trainer.fit(TestModel())
@@ -624,6 +630,7 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
         strategy="ddp",
         enable_progress_bar=True,
         enable_model_summary=False,
+        callbacks=TQDMProgressBar(),
     )
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
index b7ecab6998658..0b44939dec2de 100644
--- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
@@ -26,12 +26,17 @@
     ModelCheckpoint,
     ModelSummary,
     ProgressBarBase,
+    RichProgressBar,
     TQDMProgressBar,
 )
 from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from pytorch_lightning.utilities.imports import (
+    _PYTHON_GREATER_EQUAL_3_8_0,
+    _PYTHON_GREATER_EQUAL_3_10_0,
+    _RICH_AVAILABLE,
+)


 def test_checkpoint_callbacks_are_last(tmpdir):
@@ -180,7 +185,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         return trainer

     early_stopping = EarlyStopping(monitor="foo")
-    progress_bar = TQDMProgressBar()
+    progress_bar = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})
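
Illustrative usage sketch (not part of the diff): under this patch, a bare Trainer() selects the Rich progress bar whenever the `rich` package is importable and falls back to tqdm otherwise, which is why the tests above now pin callbacks=TQDMProgressBar() explicitly. The snippet assumes a pytorch_lightning installation with this change applied.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar
    from pytorch_lightning.utilities.imports import _RICH_AVAILABLE

    # With the patched callback connector, the default progress bar depends on
    # whether `rich` is installed in the environment.
    trainer = Trainer(enable_progress_bar=True)
    expected_cls = RichProgressBar if _RICH_AVAILABLE else TQDMProgressBar
    assert isinstance(trainer.progress_bar_callback, expected_cls)

    # Passing a progress bar callback explicitly still overrides the default,
    # which is how the updated tests keep exercising the tqdm implementation.
    trainer = Trainer(callbacks=TQDMProgressBar())
    assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)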