@@ -33,7 +33,11 @@
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from pytorch_lightning.utilities.imports import (
+    _PYTHON_GREATER_EQUAL_3_8_0,
+    _PYTHON_GREATER_EQUAL_3_10_0,
+    _RICH_AVAILABLE,
+)
 from pytorch_lightning.utilities.rank_zero import rank_zero_info

 _log = logging.getLogger(__name__)
@@ -174,7 +178,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
             )

         if enable_progress_bar:
-            progress_bar_callback = TQDMProgressBar()
+            progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
Contributor:

I'm hesitant about this change for two reasons:

  1. The Rich progress bar is slower than tqdm (there's an issue about this somewhere, cc @akihironitta).
  2. I believe CI installs rich, which means tqdm would now be untested. Since it's the most widely used progress bar, I find that risky.

             self.trainer.callbacks.append(progress_bar_callback)

     def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
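One way to keep the tqdm path covered despite CI installing rich (the reviewer's second point) would be a test that monkeypatches the _RICH_AVAILABLE flag imported into the connector module. A minimal sketch, not part of this PR, with a hypothetical test name:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
import pytorch_lightning.trainer.connectors.callback_connector as callback_connector


def test_tqdm_fallback_without_rich(tmpdir, monkeypatch):
    # Hypothetical test: pretend rich is unavailable so the TQDMProgressBar
    # fallback branch of _configure_progress_bar stays exercised even when
    # CI installs rich.
    monkeypatch.setattr(callback_connector, "_RICH_AVAILABLE", False)
    trainer = Trainer(default_root_dir=tmpdir)
    assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)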
@@ -141,7 +141,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     # check the sanity dataloaders
     num_sanity_val_steps = 4
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0,
+        num_sanity_val_steps=num_sanity_val_steps,
+        callbacks=TQDMProgressBar(),
     )
     pbar = trainer.progress_bar_callback
     with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
@@ -155,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]

     # fit
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=TQDMProgressBar())
     pbar = trainer.progress_bar_callback
     with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
         trainer.fit(model)
@@ -206,7 +210,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
 def test_tqdm_progress_bar_fast_dev_run(tmpdir):
     model = BoringModel()

-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TQDMProgressBar())

     trainer.fit(model)

@@ -326,16 +330,13 @@ def on_validation_epoch_end(self, *args):

 def test_tqdm_progress_bar_default_value(tmpdir):
     """Test that a value of None defaults to refresh rate 1."""
-    trainer = Trainer(default_root_dir=tmpdir)
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
     assert trainer.progress_bar_callback.refresh_rate == 1


 @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
 def test_tqdm_progress_bar_value_on_colab(tmpdir):
     """Test that Trainer will override the default in Google COLAB."""
-    trainer = Trainer(default_root_dir=tmpdir)
-    assert trainer.progress_bar_callback.refresh_rate == 20
-
     trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
     assert trainer.progress_bar_callback.refresh_rate == 20

@@ -411,7 +412,12 @@ def training_step(self, batch, batch_idx):
             return super().training_step(batch, batch_idx)

     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        logger=False,
+        enable_checkpointing=False,
+        callbacks=TQDMProgressBar(),
     )
     trainer.fit(TestModel())

@@ -624,6 +630,7 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
         strategy="ddp",
         enable_progress_bar=True,
         enable_model_summary=False,
+        callbacks=TQDMProgressBar(),
     )
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

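The explicit callbacks=TQDMProgressBar() arguments above exist because the default progress bar type now depends on whether rich is importable. A minimal sketch of the resulting behaviour, assuming rich is installed in the environment:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar

# With no explicit progress bar callback, the new default resolves to
# RichProgressBar when rich is importable, otherwise TQDMProgressBar.
trainer = Trainer()
assert isinstance(trainer.progress_bar_callback, RichProgressBar)

# Passing the callback explicitly keeps the choice deterministic, which is
# why the tqdm tests above now construct the Trainer with callbacks=TQDMProgressBar().
trainer = Trainer(callbacks=TQDMProgressBar())
assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)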
@@ -26,12 +26,17 @@
     ModelCheckpoint,
     ModelSummary,
     ProgressBarBase,
+    RichProgressBar,
     TQDMProgressBar,
 )
 from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from pytorch_lightning.utilities.imports import (
+    _PYTHON_GREATER_EQUAL_3_8_0,
+    _PYTHON_GREATER_EQUAL_3_10_0,
+    _RICH_AVAILABLE,
+)


 def test_checkpoint_callbacks_are_last(tmpdir):
@@ -180,7 +185,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         return trainer

     early_stopping = EarlyStopping(monitor="foo")
-    progress_bar = TQDMProgressBar()
+    progress_bar = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})
