Skip to content

Commit 45200fc

Browse files
author
Sean Naren
authored
Improvements for rich progress bar (#9559)
1 parent 3aba9d1 commit 45200fc

File tree

3 files changed

+119
-34
lines changed

3 files changed

+119
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
110110
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))
111111

112112

113-
- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
113+
- Added Rich Progress Bar:
114+
* Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
115+
* Improvements for rich progress bar ([#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
114116

115117

116118
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,52 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from dataclasses import dataclass
1415
from datetime import timedelta
15-
from typing import Dict, Optional
16+
from typing import Optional, Union
1617

1718
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
1819
from pytorch_lightning.utilities import _RICH_AVAILABLE
1920

21+
Style = None
2022
if _RICH_AVAILABLE:
2123
from rich.console import Console, RenderableType
22-
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
24+
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
25+
from rich.style import Style
2326
from rich.text import Text
2427

2528
class CustomTimeColumn(ProgressColumn):
2629

2730
# Only refresh twice a second to prevent jitter
2831
max_refresh = 0.5
2932

33+
def __init__(self, style: Union[str, Style]) -> None:
34+
self.style = style
35+
super().__init__()
36+
3037
def render(self, task) -> Text:
3138
elapsed = task.finished_time if task.finished else task.elapsed
3239
remaining = task.time_remaining
3340
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
3441
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
35-
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
42+
return Text(f"{elapsed_delta} {remaining_delta}", style=self.style)
3643

3744
class BatchesProcessedColumn(ProgressColumn):
45+
def __init__(self, style: Union[str, Style]):
46+
self.style = style
47+
super().__init__()
48+
3849
def render(self, task) -> RenderableType:
39-
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
50+
return Text(f"{int(task.completed)}/{task.total}", style=self.style)
4051

4152
class ProcessingSpeedColumn(ProgressColumn):
53+
def __init__(self, style: Union[str, Style]):
54+
self.style = style
55+
super().__init__()
56+
4257
def render(self, task) -> RenderableType:
4358
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
44-
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
59+
return Text(f"{task_speed}it/s", style=self.style)
4560

4661
class MetricsTextColumn(ProgressColumn):
4762
"""A column containing text."""
@@ -71,19 +86,26 @@ def render(self, task) -> Text:
7186
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
7287
else:
7388
metrics = self._trainer.progress_bar_metrics
89+
7490
for k, v in metrics.items():
7591
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
7692
text = Text.from_markup(_text, style=None, justify="left")
7793
return text
7894

7995

80-
STYLES: Dict[str, str] = {
81-
"train": "red",
82-
"sanity_check": "yellow",
83-
"validate": "yellow",
84-
"test": "yellow",
85-
"predict": "yellow",
86-
}
96+
@dataclass
97+
class RichProgressBarTheme:
98+
"""Styles to associate to different base components.
99+
100+
https://rich.readthedocs.io/en/stable/style.html
101+
"""
102+
103+
text_color: str = "white"
104+
progress_bar_complete: Union[str, Style] = "#6206E0"
105+
progress_bar_finished: Union[str, Style] = "#6206E0"
106+
batch_process: str = "white"
107+
time: str = "grey54"
108+
processing_speed: str = "grey70"
87109

88110

89111
class RichProgressBar(ProgressBarBase):
@@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase):
104126
105127
Args:
106128
refresh_rate: the number of updates per second, must be strictly positive
129+
theme: Contains styles used to stylize the progress bar.
107130
108131
Raises:
109132
ImportError:
110133
If required `rich` package is not installed on the device.
111134
"""
112135

113-
def __init__(self, refresh_rate: float = 1.0):
136+
def __init__(
137+
self,
138+
refresh_rate: float = 1.0,
139+
theme: RichProgressBarTheme = RichProgressBarTheme(),
140+
) -> None:
114141
if not _RICH_AVAILABLE:
115142
raise ImportError(
116143
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
@@ -126,6 +153,7 @@ def __init__(self, refresh_rate: float = 1.0):
126153
self.test_progress_bar_id: Optional[int] = None
127154
self.predict_progress_bar_id: Optional[int] = None
128155
self.console = Console(record=True)
156+
self.theme = theme
129157

130158
@property
131159
def refresh_rate(self) -> int:
@@ -147,39 +175,36 @@ def enable(self) -> None:
147175

148176
@property
149177
def sanity_check_description(self) -> str:
150-
return "[Validation Sanity Check]"
178+
return "Validation Sanity Check"
151179

152180
@property
153181
def validation_description(self) -> str:
154-
return "[Validation]"
182+
return "Validation"
155183

156184
@property
157185
def test_description(self) -> str:
158-
return "[Testing]"
186+
return "Testing"
159187

160188
@property
161189
def predict_description(self) -> str:
162-
return "[Predicting]"
190+
return "Predicting"
163191

164192
def setup(self, trainer, pl_module, stage):
165193
self.progress = Progress(
166-
SpinnerColumn(),
167194
TextColumn("[progress.description]{task.description}"),
168-
BarColumn(),
169-
BatchesProcessedColumn(),
170-
"[",
171-
CustomTimeColumn(),
172-
ProcessingSpeedColumn(),
195+
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
196+
BatchesProcessedColumn(style=self.theme.batch_process),
197+
CustomTimeColumn(style=self.theme.time),
198+
ProcessingSpeedColumn(style=self.theme.processing_speed),
173199
MetricsTextColumn(trainer, pl_module, stage),
174-
"]",
175200
console=self.console,
176201
refresh_per_second=self.refresh_rate,
177202
).__enter__()
178203

179204
def on_sanity_check_start(self, trainer, pl_module):
180205
super().on_sanity_check_start(trainer, pl_module)
181206
self.val_sanity_progress_bar_id = self.progress.add_task(
182-
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
207+
f"[{self.theme.text_color}]{self.sanity_check_description}",
183208
total=trainer.num_sanity_val_steps,
184209
)
185210

@@ -201,15 +226,15 @@ def on_train_epoch_start(self, trainer, pl_module):
201226
train_description = self._get_train_description(trainer.current_epoch)
202227

203228
self.main_progress_bar_id = self.progress.add_task(
204-
f"[{STYLES['train']}]{train_description}",
229+
f"[{self.theme.text_color}]{train_description}",
205230
total=total_batches,
206231
)
207232

208233
def on_validation_epoch_start(self, trainer, pl_module):
209234
super().on_validation_epoch_start(trainer, pl_module)
210235
if self._total_val_batches > 0:
211236
self.val_progress_bar_id = self.progress.add_task(
212-
f"[{STYLES['validate']}]{self.validation_description}",
237+
f"[{self.theme.text_color}]{self.validation_description}",
213238
total=self._total_val_batches,
214239
)
215240

@@ -221,14 +246,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
221246
def on_test_epoch_start(self, trainer, pl_module):
222247
super().on_train_epoch_start(trainer, pl_module)
223248
self.test_progress_bar_id = self.progress.add_task(
224-
f"[{STYLES['test']}]{self.test_description}",
249+
f"[{self.theme.text_color}]{self.test_description}",
225250
total=self.total_test_batches,
226251
)
227252

228253
def on_predict_epoch_start(self, trainer, pl_module):
229254
super().on_predict_epoch_start(trainer, pl_module)
230255
self.predict_progress_bar_id = self.progress.add_task(
231-
f"[{STYLES['predict']}]{self.predict_description}",
256+
f"[{self.theme.text_color}]{self.predict_description}",
232257
total=self.total_predict_batches,
233258
)
234259

@@ -261,7 +286,7 @@ def _should_update(self, current, total) -> bool:
261286
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
262287

263288
def _get_train_description(self, current_epoch: int) -> str:
264-
train_description = f"[Epoch {current_epoch}]"
289+
train_description = f"Epoch {current_epoch}"
265290
if len(self.validation_description) > len(train_description):
266291
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
267292
# and "Validation" Bar description
@@ -273,3 +298,7 @@ def _get_train_description(self, current_epoch: int) -> str:
273298

274299
def teardown(self, trainer, pl_module, stage):
275300
self.progress.__exit__(None, None, None)
301+
302+
def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
303+
if isinstance(exception, KeyboardInterrupt):
304+
self.progress.stop()

tests/callbacks/test_rich_progress_bar.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from unittest import mock
15+
from unittest.mock import DEFAULT
1516

1617
import pytest
1718

1819
from pytorch_lightning import Trainer
1920
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
21+
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
2022
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
2123
from tests.helpers.boring_model import BoringModel
2224
from tests.helpers.runif import RunIf
2325

2426

2527
@RunIf(rich=True)
2628
def test_rich_progress_bar_callback():
27-
2829
trainer = Trainer(callbacks=RichProgressBar())
2930

3031
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
@@ -36,7 +37,6 @@ def test_rich_progress_bar_callback():
3637
@RunIf(rich=True)
3738
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
3839
def test_rich_progress_bar(progress_update, tmpdir):
39-
4040
model = BoringModel()
4141

4242
trainer = Trainer(
@@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir):
5858

5959

6060
def test_rich_progress_bar_import_error():
61-
6261
if not _RICH_AVAILABLE:
6362
with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."):
6463
Trainer(callbacks=RichProgressBar())
64+
65+
66+
@RunIf(rich=True)
67+
def test_rich_progress_bar_custom_theme(tmpdir):
68+
"""Test to ensure that custom theme styles are used."""
69+
with mock.patch.multiple(
70+
"pytorch_lightning.callbacks.progress.rich_progress",
71+
BarColumn=DEFAULT,
72+
BatchesProcessedColumn=DEFAULT,
73+
CustomTimeColumn=DEFAULT,
74+
ProcessingSpeedColumn=DEFAULT,
75+
) as mocks:
76+
77+
theme = RichProgressBarTheme()
78+
79+
progress_bar = RichProgressBar(theme=theme)
80+
progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None)
81+
82+
assert progress_bar.theme == theme
83+
args, kwargs = mocks["BarColumn"].call_args
84+
assert kwargs["complete_style"] == theme.progress_bar_complete
85+
assert kwargs["finished_style"] == theme.progress_bar_finished
86+
87+
args, kwargs = mocks["BatchesProcessedColumn"].call_args
88+
assert kwargs["style"] == theme.batch_process
89+
90+
args, kwargs = mocks["CustomTimeColumn"].call_args
91+
assert kwargs["style"] == theme.time
92+
93+
args, kwargs = mocks["ProcessingSpeedColumn"].call_args
94+
assert kwargs["style"] == theme.processing_speed
95+
96+
97+
@RunIf(rich=True)
98+
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
99+
"""Test to ensure that when the user keyboard interrupts, we close the progress bar."""
100+
101+
class TestModel(BoringModel):
102+
def on_train_start(self) -> None:
103+
raise KeyboardInterrupt
104+
105+
model = TestModel()
106+
107+
with mock.patch(
108+
"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
109+
) as mock_progress_stop:
110+
progress_bar = RichProgressBar()
111+
trainer = Trainer(
112+
default_root_dir=tmpdir,
113+
fast_dev_run=True,
114+
callbacks=progress_bar,
115+
)
116+
117+
trainer.fit(model)
118+
mock_progress_stop.assert_called_once()

0 commit comments

Comments
 (0)