Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
e079f2f
By default enable rich progress bar if rich is available
Sep 22, 2021
c2374a1
Improvements for rich progress
Sep 22, 2021
276ea83
Add CHANGELOG.md
Sep 22, 2021
1a0d7d8
Merge branch 'master' into feat/enable_rich_default
Sep 22, 2021
254d361
Merge branch 'master' into feat/enable_rich_default
Sep 23, 2021
3205972
Update pytorch_lightning/trainer/connectors/callback_connector.py
Sep 23, 2021
20dc085
Improvements for rich progress
Sep 23, 2021
00ee0d2
Merge remote-tracking branch 'origin/feat/enable_rich_default' into f…
Sep 23, 2021
96ce411
Update tests/callbacks/test_rich_progress_bar.py
Sep 23, 2021
50d56ae
Add check
Sep 23, 2021
ae8bf33
Improvements for rich progress
Sep 23, 2021
2082873
Handle DDP
Sep 23, 2021
6a81e26
Merge branch 'master' into feat/enable_rich_default
Sep 23, 2021
41786be
Update pytorch_lightning/trainer/connectors/callback_connector.py
Sep 23, 2021
cffc91d
Add comments
Sep 23, 2021
79f0c8f
Cleanups, fix test for example sizes
Sep 24, 2021
7735fc6
Add comment
Sep 24, 2021
b7403ae
Small logic fixes
Sep 24, 2021
57ee691
Fix name
Sep 24, 2021
9f0115c
Improvements for rich progress
Sep 27, 2021
90cbe1e
Fix tests
Sep 27, 2021
9210d41
Merge branch 'master' into feat/enable_rich_default
Sep 27, 2021
56b9000
Set variable correctly
Sep 27, 2021
41ade74
Cleanup terminate
Sep 27, 2021
df327db
Improvements for rich progress
Sep 27, 2021
e0b1365
Make sure to clean up process
Sep 27, 2021
24ea978
Improvements for rich progress
Sep 28, 2021
a29401f
Improvements for rich progress
Sep 28, 2021
9ce2795
Fixes
Sep 28, 2021
5ce97f2
Merge branch 'master' into feat/enable_rich_default
Sep 28, 2021
baad9ac
Merge branch 'master' into feat/enable_rich_default
Sep 29, 2021
377ae9c
Merge branch 'master' into feat/enable_rich_default
Sep 29, 2021
0a4c60a
Update message
Sep 29, 2021
24f1b3c
Merge branch 'master' into feat/enable_rich_default
Sep 29, 2021
e931689
Fix pulse bar
Sep 30, 2021
80725a2
Merge branch 'master' into feat/enable_rich_default
Sep 30, 2021
b8bfef8
Merge branch 'master' into feat/enable_rich_default
Oct 5, 2021
94b70c2
Attempt to fix tests
Oct 5, 2021
0377894
temp
Oct 5, 2021
129b7d4
Revert "temp"
Oct 5, 2021
adcd780
Remove enable
Oct 5, 2021
1b1ed94
delay console init
awaelchli Oct 5, 2021
5d8cb30
isenabled
Oct 5, 2021
fe68e78
Clear out from state
Oct 5, 2021
3d7827e
Fix var name
Oct 5, 2021
0c7559e
Turn off progress bar for everything
Oct 6, 2021
1c7a8af
Remove checks, enable rich object
Oct 6, 2021
ed41233
Revert "Remove checks, enable rich object"
Oct 6, 2021
c969f61
Remove console
Oct 6, 2021
008b9f0
Clear update
Oct 6, 2021
4be11db
Remove more
Oct 6, 2021
3207a7c
Attempt
Oct 6, 2021
184df4b
verify non-spawn support
awaelchli Oct 12, 2021
ece1f8d
Merge branch 'master' into feat/enable_rich_default
awaelchli Oct 12, 2021
1ca117f
Revert "Remove more"
Oct 12, 2021
f176fc6
Revert "Clear update"
Oct 12, 2021
13fbf4b
Revert "Remove console"
Oct 12, 2021
cc74594
Various reverts from debugging
Oct 12, 2021
ef41642
Remove bool
Oct 12, 2021
a6ecb5f
Address code review
Oct 13, 2021
71adf49
Drop metrics to see if code passes
Oct 13, 2021
740cee3
Merge branch 'master' into feat/enable_rich_default
Oct 18, 2021
da93fa0
re-introduce code
Oct 18, 2021
a737b5a
Drop case
Oct 20, 2021
3e6579f
Clean up code by keeping reference to progress bar in metrics column
Oct 20, 2021
46808e8
Update pytorch_lightning/trainer/connectors/callback_connector.py
Oct 20, 2021
b560615
Add prints
Oct 20, 2021
91245f5
skip test
kaushikb11 Oct 21, 2021
ebc0d17
Update tests/callbacks/test_rich_progress_bar.py
kaushikb11 Oct 22, 2021
dd10585
move to special tests
tchaton Oct 22, 2021
f4211d4
Attempt a new fix
Oct 25, 2021
432d68b
Merge branch 'master' into feat/enable_rich_default
Oct 25, 2021
761df3f
debug
Oct 25, 2021
72f3d06
Cleanup
Oct 25, 2021
ba16ff9
Cleanup
Oct 25, 2021
1c18df1
Update tests/models/test_grad_norm.py
Oct 25, 2021
785d54f
Remove
Oct 25, 2021
d5d0d02
Merge branch 'master' into feat/enable_rich_default
Oct 25, 2021
ec090c0
add note
Oct 25, 2021
f05ad1a
Apply suggestions from code review
Borda Oct 25, 2021
377424f
Include ddp spawn guard again
Oct 25, 2021
17fafc5
Merge branch 'master' into feat/enable_rich_default
awaelchli Oct 26, 2021
9eb1ed8
See if turning of auto refresh fixes test
Oct 27, 2021
463bde2
Drop check
Oct 27, 2021
8f10fa9
Merge branch 'master' into feat/enable_rich_default
Nov 1, 2021
21c6a83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
75fff92
merge fixes
Nov 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 74 additions & 45 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def render(self, task: "Task") -> ProgressBar:
total=max(0, task.total),
completed=max(0, task.completed),
width=None if self.bar_width is None else max(1, self.bar_width),
pulse=not task.started or math.isfinite(task.remaining),
pulse=not task.started or not math.isfinite(task.remaining),
animation_time=task.get_time(),
style=self.style,
complete_style=self.complete_style,
Expand Down Expand Up @@ -129,13 +129,19 @@ def render(self, task) -> RenderableType:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""

def __init__(self, trainer, pl_module):
def __init__(self, trainer):
self._trainer = trainer
self._pl_module = pl_module
self._tasks = {}
self._current_task_id = 0
self.metrics = {}
super().__init__()

def update(self, metrics):
# called when metrics are ready to be rendered.
# this is due to preventing render from causing deadlock issues by requesting metrics
# in separate thread.
self.metrics = metrics

def render(self, task) -> Text:
from pytorch_lightning.trainer.states import TrainerFn

Expand All @@ -149,14 +155,8 @@ def render(self, task) -> Text:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
# TODO(@daniellepintz): make this code cleaner
progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None)
if progress_bar_callback:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics

for k, v in metrics.items():

for k, v in self.metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")

Expand Down Expand Up @@ -194,7 +194,11 @@ class RichProgressBar(ProgressBarBase):
trainer = Trainer(callbacks=RichProgressBar())

Args:
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display. By default, the :class:`~pytorch_lightning.trainer.trainer.Trainer`
uses this implementation of the progress bar and sets the refresh rate to the value provided to the
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
theme: Contains styles used to stylize the progress bar.

Raises:
Expand All @@ -204,35 +208,30 @@ class RichProgressBar(ProgressBarBase):

def __init__(
self,
refresh_rate_per_second: int = 10,
refresh_rate: int = 1,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
if not _RICH_AVAILABLE:
raise ModuleNotFoundError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
)
super().__init__()
self._refresh_rate_per_second: int = refresh_rate_per_second
self._refresh_rate: int = refresh_rate
self._enabled: bool = True
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
self._reset_progress_bar_ids()
self._metric_component = None
self._progress_stopped: bool = False
self.theme = theme
self._console: Console = Console()

@property
def refresh_rate_per_second(self) -> float:
"""Refresh rate for Rich Progress.

Returns: Refresh rate for Progress Bar.
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
"""
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
def refresh_rate(self) -> float:
return self._refresh_rate

@property
def is_enabled(self) -> bool:
return self._enabled and self._refresh_rate_per_second > 0
return self._enabled and self.refresh_rate > 0

@property
def is_disabled(self) -> bool:
Expand Down Expand Up @@ -260,10 +259,12 @@ def test_description(self) -> str:
def predict_description(self) -> str:
return "Predicting"

def _init_progress(self, trainer, pl_module):
if self.progress is None or self._progress_stopped:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer)
self.progress = CustomProgress(
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
Expand All @@ -274,51 +275,56 @@ def _init_progress(self, trainer, pl_module):
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module),
refresh_per_second=self.refresh_rate_per_second,
self._metric_component,
auto_refresh=False,
disable=self.is_disabled,
console=self._console,
)
self.progress.start()
# progress has started
self._progress_stopped = False

def refresh(self):
if self.progress:
self.progress.refresh()

def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_predict_start(self, trainer, pl_module):
super().on_predict_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)

def __getstate__(self):
# can't pickle the rich progress objects
state = self.__dict__.copy()
state["progress"] = None
state["_console"] = None
state["progress"] = None
return state

def __setstate__(self, state):
self.__dict__ = state
# reset console reference after loading progress
self._console = Console()
state["_console"] = Console()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()

def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self._update(self.val_sanity_progress_bar_id, visible=False)
self.refresh()

def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
Expand All @@ -334,7 +340,9 @@ def on_train_epoch_start(self, trainer, pl_module):
train_description = self._get_train_description(trainer.current_epoch)
if self.main_progress_bar_id is None:
self.main_progress_bar_id = self._add_task(total_batches, train_description)
self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description)
if self.progress is not None:
self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description)
self.refresh()

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
Expand All @@ -345,51 +353,67 @@ def on_validation_epoch_start(self, trainer, pl_module):
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()

def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible
)

def _update(self, progress_bar_id: int, visible: bool = True) -> None:
if self.progress is not None:
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=1.0, visible=visible)
self.refresh()

def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def on_validation_epoch_end(self, trainer, pl_module):
super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, visible=False)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)

def on_validation_end(self, trainer, pl_module) -> None:
super().on_validation_end(trainer, pl_module)
self._update_metrics(trainer, pl_module)

def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()

def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update_metrics(trainer, pl_module)
self.refresh()

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
self._update(self.main_progress_bar_id)
self._update(self.val_progress_bar_id)
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self.refresh()

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.test_progress_bar_id)
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
self.refresh()

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.predict_progress_bar_id)
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
self.refresh()

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"
Expand All @@ -414,6 +438,11 @@ def _reset_progress_bar_ids(self):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None

def _update_metrics(self, trainer, pl_module) -> None:
metrics = self.get_metrics(trainer, pl_module)
if self._metric_component:
self._metric_component.update(metrics)

def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
self._stop_progress()

Expand Down
38 changes: 25 additions & 13 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities import _RICH_AVAILABLE, ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn


class CallbackConnector:
Expand Down Expand Up @@ -216,26 +216,38 @@ def _configure_swa_callbacks(self):
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks

def configure_progress_bar(self, refresh_rate=None, process_position=0):
if os.getenv("COLAB_GPU") and refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value
refresh_rate = 20
refresh_rate = 1 if refresh_rate is None else refresh_rate

# if progress bar callback already exists return it
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
if len(progress_bars) > 1:
raise MisconfigurationException(
"You added multiple progress bar callbacks to the Trainer, but currently only one"
" progress bar is supported."
)
if len(progress_bars) == 1:
progress_bar_callback = progress_bars[0]
elif refresh_rate > 0:
return progress_bars[0]
# check if progress bar has been turned off (i.e refresh_rate == 0)
if refresh_rate == 0:
return
# if Rich is available and refresh_rate is None return Rich ProgressBar
if _RICH_AVAILABLE:
if refresh_rate is None:
progress_bar_callback = RichProgressBar()
self.trainer.callbacks.append(progress_bar_callback)
return progress_bar_callback
rank_zero_warn(
"`RichProgressBar` does not support setting the refresh rate via the Trainer."
" If you'd like to change the refresh rate and continue using the `RichProgressBar`,"
" please pass `callbacks=RichProgressBar(refresh_rate=X)`."
" Setting to the `TQDM ProgressBar`."
)
# else return new TQDMProgressBar
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be indented in?

if os.getenv("COLAB_GPU") and refresh_rate is None:
# smaller refresh rate on colab causes crashes for TQDM, choose a higher value
refresh_rate = 20
refresh_rate = 1 if refresh_rate is None else refresh_rate
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
else:
progress_bar_callback = None

return progress_bar_callback
return progress_bar_callback

def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
if max_time is None:
Expand Down
Loading