Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class RichProgressBar(ProgressBarBase):
trainer = Trainer(callbacks=RichProgressBar())

Args:
refresh_rate: the number of updates per second, must be strictly positive
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
theme: Contains styles used to stylize the progress bar.

Raises:
Expand All @@ -135,15 +135,15 @@ class RichProgressBar(ProgressBarBase):

def __init__(
self,
refresh_rate: float = 1.0,
refresh_rate_per_second: int = 10,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
if not _RICH_AVAILABLE:
raise ImportError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
)
super().__init__()
self._refresh_rate: float = refresh_rate
self._refresh_rate_per_second: int = refresh_rate_per_second
self._enabled: bool = True
self._total_val_batches: int = 0
self.progress: Progress = None
Expand All @@ -156,12 +156,17 @@ def __init__(
self.theme = theme

@property
def refresh_rate(self) -> int:
return self._refresh_rate
def refresh_rate_per_second(self) -> float:
"""Refresh rate for Rich Progress.

Returns: Refresh rate for Progress Bar.
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
"""
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1

@property
def is_enabled(self) -> bool:
return self._enabled and self.refresh_rate > 0
return self._enabled and self._refresh_rate_per_second > 0

@property
def is_disabled(self) -> bool:
Expand Down Expand Up @@ -189,7 +194,7 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

def setup(self, trainer, pl_module, stage):
def setup(self, trainer, pl_module, stage: Optional[str] = None):
self.progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
Expand All @@ -198,8 +203,10 @@ def setup(self, trainer, pl_module, stage):
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module, stage),
console=self.console,
refresh_per_second=self.refresh_rate,
).__enter__()
refresh_per_second=self.refresh_rate_per_second,
disable=self.is_disabled,
)
self.progress.start()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
Expand Down Expand Up @@ -259,31 +266,23 @@ def on_predict_epoch_start(self, trainer, pl_module):

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
self.progress.update(self.main_progress_bar_id, advance=1.0)
self.progress.update(self.main_progress_bar_id, advance=1.0)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self.progress.update(self.val_sanity_progress_bar_id, advance=1.0)
elif self.val_progress_bar_id and self._should_update(
self.val_batch_idx, self.total_train_batches + self.total_val_batches
):
elif self.val_progress_bar_id:
self.progress.update(self.main_progress_bar_id, advance=1.0)
self.progress.update(self.val_progress_bar_id, advance=1.0)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.test_batch_idx, self.total_test_batches):
self.progress.update(self.test_progress_bar_id, advance=1.0)
self.progress.update(self.test_progress_bar_id, advance=1.0)

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
self.progress.update(self.predict_progress_bar_id, advance=1.0)

def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
self.progress.update(self.predict_progress_bar_id, advance=1.0)

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"
Expand All @@ -296,8 +295,8 @@ def _get_train_description(self, current_epoch: int) -> str:
train_description += " "
return train_description

def teardown(self, trainer, pl_module, stage):
self.progress.__exit__(None, None, None)
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
self.progress.stop()

def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
if isinstance(exception, KeyboardInterrupt):
Expand Down
10 changes: 10 additions & 0 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def test_rich_progress_bar_callback():
assert isinstance(trainer.progress_bar_callback, RichProgressBar)


@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
progress_bar = RichProgressBar(refresh_rate_per_second=1)
assert progress_bar.is_enabled
assert not progress_bar.is_disabled
progress_bar = RichProgressBar(refresh_rate_per_second=0)
assert not progress_bar.is_enabled
assert progress_bar.is_disabled


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar(progress_update, tmpdir):
Expand Down