Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))


- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


Expand Down
59 changes: 36 additions & 23 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
trainer = Trainer(callbacks=RichProgressBar())

Args:
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.

Expand All @@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):

def __init__(
self,
refresh_rate_per_second: int = 10,
refresh_rate: int = 1,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
Expand All @@ -231,7 +232,7 @@ def __init__(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
)
super().__init__()
self._refresh_rate_per_second: int = refresh_rate_per_second
self._refresh_rate: int = refresh_rate
self._leave: bool = leave
self._enabled: bool = True
self.progress: Optional[Progress] = None
Expand All @@ -242,17 +243,12 @@ def __init__(
self.theme = theme

@property
def refresh_rate_per_second(self) -> float:
"""Refresh rate for Rich Progress.

Returns: Refresh rate for Progress Bar.
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
"""
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
def refresh_rate(self) -> float:
return self._refresh_rate

@property
def is_enabled(self) -> bool:
return self._enabled and self._refresh_rate_per_second > 0
return self._enabled and self.refresh_rate > 0

@property
def is_disabled(self) -> bool:
Expand Down Expand Up @@ -289,14 +285,18 @@ def _init_progress(self, trainer):
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
auto_refresh=False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the fix I proposed in #9647 to @SeanNaren to prevent threading issues in the render function.
Did you check that this might enable #9647 to pass the CI now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talked to Sean regarding it, it didn't work!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kaushikb11, I'm working on #13937 and debugging it has led me to this PR. Do you remember if there is any particular reason why we changed auto_refresh from True to False? Was it for #10362?

disable=self.is_disabled,
console=self._console,
)
self.progress.start()
# progress has started
self._progress_stopped = False

def refresh(self) -> None:
if self.progress:
self.progress.refresh()

def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer)
Expand Down Expand Up @@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()

def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self._update(self.val_sanity_progress_bar_id, visible=False)
self.refresh()

def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
Expand All @@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module):
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
Expand All @@ -364,52 +367,62 @@ def on_validation_epoch_start(self, trainer, pl_module):
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, visible: bool = True) -> None:
if self.progress is not None:
self.progress.update(progress_bar_id, advance=1.0, visible=visible)
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
self.refresh()

def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, visible=False)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()

def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update_metrics(trainer, pl_module)
self.refresh()

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
self._update(self.main_progress_bar_id)
self._update(self.val_progress_bar_id)
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.test_progress_bar_id)
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.predict_progress_bar_id)
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"
Expand Down
27 changes: 24 additions & 3 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test_rich_progress_bar_callback():


@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
progress_bar = RichProgressBar(refresh_rate_per_second=1)
def test_rich_progress_bar_refresh_rate_enabled():
progress_bar = RichProgressBar(refresh_rate=1)
assert progress_bar.is_enabled
assert not progress_bar.is_disabled
progress_bar = RichProgressBar(refresh_rate_per_second=0)
progress_bar = RichProgressBar(refresh_rate=0)
assert not progress_bar.is_enabled
assert progress_bar.is_disabled

Expand Down Expand Up @@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
)
trainer.fit(model)
assert mock_progress_reset.call_count == reset_call_count


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):

model = BoringModel()

trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)

trainer.fit(model)

assert progress_update.call_count == expected_call_count