Skip to content

Commit a1b0152

Browse files
committed
update
1 parent c379bed commit a1b0152

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import dataclass
1616
from datetime import timedelta
1717
from typing import Any, Dict, Optional, Union
18+
import copy
1819

1920
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
2021
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -211,7 +212,7 @@ class RichProgressBar(ProgressBarBase):
211212
Set it to ``0`` to disable the display.
212213
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
213214
theme: Contains styles used to stylize the progress bar.
214-
console_kwargs: Args for constructing a `Console`
215+
console_kwargs: Args for constructing a `Console` or a `Console` object
215216
216217
Raises:
217218
ModuleNotFoundError:
@@ -228,7 +229,7 @@ def __init__(
228229
refresh_rate: int = 1,
229230
leave: bool = False,
230231
theme: RichProgressBarTheme = RichProgressBarTheme(),
231-
console_kwargs: Optional[Dict[str, Any], Console] = None,
232+
console_kwargs: Optional[Union[Dict[str, Any], Console]] = None,
232233
) -> None:
233234
if not _RICH_AVAILABLE:
234235
raise MisconfigurationException(
@@ -284,7 +285,10 @@ def predict_description(self) -> str:
284285
def _init_progress(self, trainer):
285286
if self.is_enabled and (self.progress is None or self._progress_stopped):
286287
self._reset_progress_bar_ids()
287-
self._console = Console(**self._console_kwargs)
288+
if isinstance(self._console_kwargs, Console):
289+
self._console = copy.deepcopy(self._console_kwargs)
290+
else:
291+
self._console = Console(**self._console_kwargs)
288292
self._console.clear_live()
289293
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
290294
self.progress = CustomProgress(

0 commit comments

Comments
 (0)