7 changes: 6 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -18,7 +18,8 @@
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.utilities import _RICH_AVAILABLE

__all__ = [
'Callback',
@@ -32,3 +33,7 @@
'ProgressBar',
'ProgressBarBase',
]


if _RICH_AVAILABLE:
__all__ += ["RichProgressBar"]
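For context, a minimal usage sketch of the conditional export — not part of this diff, and assuming `rich` is installed (otherwise the symbol is not exported):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

# the callback replaces the default tqdm-based ProgressBar
trainer = Trainer(callbacks=[RichProgressBar()])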
150 changes: 150 additions & 0 deletions pytorch_lightning/callbacks/progress.py
@@ -22,6 +22,8 @@
import importlib
import sys

import torch

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
@@ -30,6 +32,13 @@
from tqdm import tqdm

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import _RICH_AVAILABLE

if _RICH_AVAILABLE:
from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table
from rich.text import Text


class ProgressBarBase(Callback):
@@ -387,6 +396,147 @@ def _update_bar(self, bar):
bar.update(delta)


# fallback when rich is not installed; replaced by the real class below
RichProgressBar = None


if _RICH_AVAILABLE:

class MetricsTextColumn(TextColumn):
"""A column containing text."""

def __init__(self, trainer):
self._trainer = trainer
super().__init__("")

def render(self, task) -> Text:
_text = ''
if "red" in f'{task.description}':
for k, v in self._trainer.progress_bar_dict.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
if self.markup:
text = Text.from_markup(_text, style=self.style, justify=self.justify)
else:
text = Text(_text, style=self.style, justify=self.justify)
if self.highlighter:
self.highlighter.highlight(text)
return text

class RichProgressBar(ProgressBarBase):

def __init__(self, refresh_rate: int = 1, process_position: int = 0):
super().__init__()
self._refresh_rate = refresh_rate
self._process_position = process_position
self._enabled = True
self.main_progress_bar = None
self.val_progress_bar = None
self.test_progress_bar = None
self.console = Console(record=True)
self.tasks = {}

def __getstate__(self):
# can't pickle the rich progress objects
state = self.__dict__.copy()
state['main_progress_bar'] = None
state['val_progress_bar'] = None
state['test_progress_bar'] = None
return state

@property
def refresh_rate(self) -> int:
return self._refresh_rate

@property
def process_position(self) -> int:
return self._process_position

@property
def is_enabled(self) -> bool:
return self._enabled and self.refresh_rate > 0

@property
def is_disabled(self) -> bool:
return not self.is_enabled

def disable(self) -> None:
self._enabled = False

def enable(self) -> None:
self._enabled = True

def on_epoch_start(self, trainer, pl_module):
super().on_epoch_start(trainer, pl_module)
self.progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
MetricsTextColumn(trainer),
console=self.console,
transient=True,
)
self.progress.start()
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
self.total_batches = total_train_batches + total_val_batches
self.tasks["train"] = self.progress.add_task(
f"[red][Epoch {trainer.current_epoch}]",
total=self.total_batches,
)
if total_val_batches > 0:
self.tasks["val"] = self.progress.add_task(
f"[green][Epoch {trainer.current_epoch}]",
total=total_val_batches,
)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
if getattr(self, "progress", None) is not None:
self.progress.update(self.tasks["train"], advance=1.)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_train_batches + self.total_val_batches):
if getattr(self, "progress", None) is not None:
self.progress.update(self.tasks["train"], advance=1.)
self.progress.update(self.tasks["val"], advance=1.)

def on_train_epoch_end(self, trainer, pl_module, *_):
super().on_train_epoch_end(trainer, pl_module)
self.progress.stop()
epoch_pbar_metrics = self.trainer.logger_connector.cached_results.get_epoch_pbar_metrics()
table = Table(show_header=True, header_style="bold magenta")
width = max((len(k) for k in epoch_pbar_metrics), default=10) + 5
table.add_column(f"Metrics Epoch {trainer.current_epoch} ", style="dim", width=width)
table.add_column("Value")
for k, v in epoch_pbar_metrics.items():
v = round(v.item(), 4) if isinstance(v, torch.Tensor) else v
table.add_row(
k, str(v)
)
self.console.log(table)

def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)



def convert_inf(x):
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
Expand Down
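As a standalone sketch of the column layout RichProgressBar composes in on_epoch_start above (assuming `rich` is installed; the {task.*} fields come from rich itself):

from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn

progress = Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeElapsedColumn(),
)
with progress:
    # one task per bar; advance mirrors what on_train_batch_end does
    task = progress.add_task("[red][Epoch 0]", total=100)
    for _ in range(100):
        progress.update(task, advance=1)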
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -41,6 +41,7 @@
_GROUP_AVAILABLE,
_FAIRSCALE_PIPE_AVAILABLE,
_BOLTS_AVAILABLE,
_RICH_AVAILABLE,
_module_available,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -41,6 +41,7 @@ def _module_available(module_path: str) -> bool:
return False


_RICH_AVAILABLE = _module_available("rich")
_APEX_AVAILABLE = _module_available("apex.amp")
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
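For reference, `_RICH_AVAILABLE` follows the same probe-without-importing pattern as the neighboring flags; a rough sketch of the behavior (illustrative names, not the library's exact implementation):

import importlib.util

def module_available(module_path: str) -> bool:
    # probe for the module without importing it (cheap and side-effect free)
    try:
        return importlib.util.find_spec(module_path) is not None
    except ModuleNotFoundError:
        # a parent package in the dotted path is missing
        return False

RICH_AVAILABLE = module_available("rich")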
51 changes: 47 additions & 4 deletions tests/callbacks/test_progress_bar.py
@@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from time import sleep
from unittest import mock
from unittest.mock import call, Mock

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.utilities import _RICH_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate

if _RICH_AVAILABLE:
from pytorch_lightning.callbacks import RichProgressBar


@pytest.mark.parametrize('callbacks,refresh_rate', [
@@ -328,3 +333,41 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
)
trainer.test(model)
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])


@pytest.mark.skipif(not _RICH_AVAILABLE, reason="test requires rich to be installed")
def test_rich_progress_bar(tmpdir):
"""Test different ways the progress bar can be turned on."""

class LoggingModel(BoringModel):

def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
sleep(0.1)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
sleep(0.1)
return {"x": loss}

def test_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
sleep(0.1)
return {"y": loss}

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=RichProgressBar(),
max_epochs=2,
limit_train_batches=10,
limit_val_batches=5,
)

trainer.fit(LoggingModel())
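The test only verifies that fit() runs to completion. Because the callback builds its Console with record=True, a hypothetical follow-up assertion (not part of this diff) could inspect the rendered output via rich's export_text():

# hypothetical extension, reusing LoggingModel from the test above
bar = RichProgressBar()
trainer = Trainer(
    default_root_dir=tmpdir,
    callbacks=bar,
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=2,
)
trainer.fit(LoggingModel())
assert "Epoch" in bar.console.export_text()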