Skip to content

Commit fde326d

Browse files
quancskaushikb11tchaton
authored
make RichProgressBar more flexible with Rich.Console (#10875)
Co-authored-by: Kaushik B <[email protected]> Co-authored-by: thomas chaton <[email protected]>
1 parent 5eecdca commit fde326d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4343
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
4444

4545

46+
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
47+
48+
4649
### Changed
4750

4851
- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
111114
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
112115

113116

117+
114118
### Deprecated
115119

116120
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import math
1515
from dataclasses import dataclass
1616
from datetime import timedelta
17-
from typing import Any, Optional, Union
17+
from typing import Any, Dict, Optional, Union
1818

1919
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
2020
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -211,6 +211,7 @@ class RichProgressBar(ProgressBarBase):
211211
Set it to ``0`` to disable the display.
212212
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
213213
theme: Contains styles used to stylize the progress bar.
214+
console_kwargs: Args for constructing a `Console`
214215
215216
Raises:
216217
ModuleNotFoundError:
@@ -227,6 +228,7 @@ def __init__(
227228
refresh_rate: int = 1,
228229
leave: bool = False,
229230
theme: RichProgressBarTheme = RichProgressBarTheme(),
231+
console_kwargs: Optional[Dict[str, Any]] = None,
230232
) -> None:
231233
if not _RICH_AVAILABLE:
232234
raise MisconfigurationException(
@@ -236,6 +238,7 @@ def __init__(
236238
super().__init__()
237239
self._refresh_rate: int = refresh_rate
238240
self._leave: bool = leave
241+
self._console_kwargs = console_kwargs or {}
239242
self._enabled: bool = True
240243
self.progress: Optional[Progress] = None
241244
self.val_sanity_progress_bar_id: Optional[int] = None
@@ -281,7 +284,7 @@ def predict_description(self) -> str:
281284
def _init_progress(self, trainer):
282285
if self.is_enabled and (self.progress is None or self._progress_stopped):
283286
self._reset_progress_bar_ids()
284-
self._console: Console = Console()
287+
self._console = Console(**self._console_kwargs)
285288
self._console.clear_live()
286289
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
287290
self.progress = CustomProgress(
@@ -324,7 +327,7 @@ def __getstate__(self):
324327

325328
def __setstate__(self, state):
326329
self.__dict__ = state
327-
state["_console"] = Console()
330+
self._console = Console(**self._console_kwargs)
328331

329332
def on_sanity_check_start(self, trainer, pl_module):
330333
super().on_sanity_check_start(trainer, pl_module)

0 commit comments

Comments
 (0)