Skip to content

Commit fd2e778

Browse files
author
Sean Naren
authored
Improvements for rich progress (#9579)
1 parent 3f7872d commit fd2e778

File tree

2 files changed

+32
-23
lines changed

2 files changed

+32
-23
lines changed

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class RichProgressBar(ProgressBarBase):
125125
trainer = Trainer(callbacks=RichProgressBar())
126126
127127
Args:
128-
refresh_rate: the number of updates per second, must be strictly positive
128+
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
129129
theme: Contains styles used to stylize the progress bar.
130130
131131
Raises:
@@ -135,15 +135,15 @@ class RichProgressBar(ProgressBarBase):
135135

136136
def __init__(
137137
self,
138-
refresh_rate: float = 1.0,
138+
refresh_rate_per_second: int = 10,
139139
theme: RichProgressBarTheme = RichProgressBarTheme(),
140140
) -> None:
141141
if not _RICH_AVAILABLE:
142142
raise ImportError(
143143
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
144144
)
145145
super().__init__()
146-
self._refresh_rate: float = refresh_rate
146+
self._refresh_rate_per_second: int = refresh_rate_per_second
147147
self._enabled: bool = True
148148
self._total_val_batches: int = 0
149149
self.progress: Progress = None
@@ -156,12 +156,17 @@ def __init__(
156156
self.theme = theme
157157

158158
@property
159-
def refresh_rate(self) -> int:
160-
return self._refresh_rate
159+
def refresh_rate_per_second(self) -> float:
160+
"""Refresh rate for Rich Progress.
161+
162+
Returns: Refresh rate for Progress Bar.
163+
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
164+
"""
165+
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
161166

162167
@property
163168
def is_enabled(self) -> bool:
164-
return self._enabled and self.refresh_rate > 0
169+
return self._enabled and self._refresh_rate_per_second > 0
165170

166171
@property
167172
def is_disabled(self) -> bool:
@@ -189,7 +194,7 @@ def test_description(self) -> str:
189194
def predict_description(self) -> str:
190195
return "Predicting"
191196

192-
def setup(self, trainer, pl_module, stage):
197+
def setup(self, trainer, pl_module, stage: Optional[str] = None):
193198
self.progress = Progress(
194199
TextColumn("[progress.description]{task.description}"),
195200
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
@@ -198,8 +203,10 @@ def setup(self, trainer, pl_module, stage):
198203
ProcessingSpeedColumn(style=self.theme.processing_speed),
199204
MetricsTextColumn(trainer, pl_module, stage),
200205
console=self.console,
201-
refresh_per_second=self.refresh_rate,
202-
).__enter__()
206+
refresh_per_second=self.refresh_rate_per_second,
207+
disable=self.is_disabled,
208+
)
209+
self.progress.start()
203210

204211
def on_sanity_check_start(self, trainer, pl_module):
205212
super().on_sanity_check_start(trainer, pl_module)
@@ -259,31 +266,23 @@ def on_predict_epoch_start(self, trainer, pl_module):
259266

260267
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
261268
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
262-
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
263-
self.progress.update(self.main_progress_bar_id, advance=1.0)
269+
self.progress.update(self.main_progress_bar_id, advance=1.0)
264270

265271
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
266272
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
267273
if trainer.sanity_checking:
268274
self.progress.update(self.val_sanity_progress_bar_id, advance=1.0)
269-
elif self.val_progress_bar_id and self._should_update(
270-
self.val_batch_idx, self.total_train_batches + self.total_val_batches
271-
):
275+
elif self.val_progress_bar_id:
272276
self.progress.update(self.main_progress_bar_id, advance=1.0)
273277
self.progress.update(self.val_progress_bar_id, advance=1.0)
274278

275279
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
276280
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
277-
if self._should_update(self.test_batch_idx, self.total_test_batches):
278-
self.progress.update(self.test_progress_bar_id, advance=1.0)
281+
self.progress.update(self.test_progress_bar_id, advance=1.0)
279282

280283
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
281284
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
282-
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
283-
self.progress.update(self.predict_progress_bar_id, advance=1.0)
284-
285-
def _should_update(self, current, total) -> bool:
286-
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
285+
self.progress.update(self.predict_progress_bar_id, advance=1.0)
287286

288287
def _get_train_description(self, current_epoch: int) -> str:
289288
train_description = f"Epoch {current_epoch}"
@@ -296,8 +295,8 @@ def _get_train_description(self, current_epoch: int) -> str:
296295
train_description += " "
297296
return train_description
298297

299-
def teardown(self, trainer, pl_module, stage):
300-
self.progress.__exit__(None, None, None)
298+
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
299+
self.progress.stop()
301300

302301
def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
303302
if isinstance(exception, KeyboardInterrupt):

tests/callbacks/test_rich_progress_bar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def test_rich_progress_bar_callback():
3434
assert isinstance(trainer.progress_bar_callback, RichProgressBar)
3535

3636

37+
@RunIf(rich=True)
38+
def test_rich_progress_bar_refresh_rate():
39+
progress_bar = RichProgressBar(refresh_rate_per_second=1)
40+
assert progress_bar.is_enabled
41+
assert not progress_bar.is_disabled
42+
progress_bar = RichProgressBar(refresh_rate_per_second=0)
43+
assert not progress_bar.is_enabled
44+
assert progress_bar.is_disabled
45+
46+
3747
@RunIf(rich=True)
3848
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
3949
def test_rich_progress_bar(progress_update, tmpdir):

0 commit comments

Comments
 (0)