From fefa23dc585aaa20f07232aac41c1a28b46e6750 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 15 Aug 2021 16:50:28 +0530 Subject: [PATCH 01/22] Add Rich Progress Bar --- pytorch_lightning/callbacks/__init__.py | 3 +- .../callbacks/progress/__init__.py | 23 ++ pytorch_lightning/callbacks/progress/base.py | 181 ++++++++++ .../callbacks/progress/progress.py | 340 ++++++++++++++++++ .../callbacks/progress/rich_progress.py | 117 ++++++ pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/imports.py | 1 + 7 files changed, 665 insertions(+), 1 deletion(-) create mode 100644 pytorch_lightning/callbacks/progress/__init__.py create mode 100644 pytorch_lightning/callbacks/progress/base.py create mode 100644 pytorch_lightning/callbacks/progress/progress.py create mode 100644 pytorch_lightning/callbacks/progress/rich_progress.py diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 23601d34b8e20..d2c405b5c2d10 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -20,7 +20,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter -from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging @@ -45,4 +45,5 @@ "QuantizationAwareTraining", "StochasticWeightAveraging", "Timer", + "RichProgressBar", ] diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py new file mode 100644 index 0000000000000..2807f009a4ed8 --- /dev/null +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Progress Bars +============= + +Use or override one of the progress bar callbacks. + +""" +from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py new file mode 100644 index 0000000000000..db1de97a2291f --- /dev/null +++ b/pytorch_lightning/callbacks/progress/base.py @@ -0,0 +1,181 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.callbacks import Callback + + +class ProgressBarBase(Callback): + r""" + The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback` + that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. + You should implement your highly custom progress bars with this as the base class. + + Example:: + + class LitProgressBar(ProgressBarBase): + + def __init__(self): + super().__init__() # don't forget this :) + self.enable = True + + def disable(self): + self.enable = False + + def on_train_batch_end(self, trainer, pl_module, outputs): + super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :) + percent = (self.train_batch_idx / self.total_train_batches) * 100 + sys.stdout.flush() + sys.stdout.write(f'{percent:.01f} percent complete \r') + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + """ + + def __init__(self): + + self._trainer = None + self._train_batch_idx = 0 + self._val_batch_idx = 0 + self._test_batch_idx = 0 + self._predict_batch_idx = 0 + + @property + def trainer(self): + return self._trainer + + @property + def train_batch_idx(self) -> int: + """ + The current batch index being processed during training. + Use this to update your progress bar. + """ + return self._train_batch_idx + + @property + def val_batch_idx(self) -> int: + """ + The current batch index being processed during validation. + Use this to update your progress bar. + """ + return self._val_batch_idx + + @property + def test_batch_idx(self) -> int: + """ + The current batch index being processed during testing. + Use this to update your progress bar. + """ + return self._test_batch_idx + + @property + def predict_batch_idx(self) -> int: + """ + The current batch index being processed during predicting. + Use this to update your progress bar. + """ + return self._predict_batch_idx + + @property + def total_train_batches(self) -> int: + """ + The total number of training batches during training, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + training dataloader is of infinite size. + """ + return self.trainer.num_training_batches + + @property + def total_val_batches(self) -> int: + """ + The total number of validation batches during validation, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + validation dataloader is of infinite size. + """ + total_val_batches = 0 + if self.trainer.enable_validation: + is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 + + return total_val_batches + + @property + def total_test_batches(self) -> int: + """ + The total number of testing batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + test dataloader is of infinite size. 
+ """ + return sum(self.trainer.num_test_batches) + + @property + def total_predict_batches(self) -> int: + """ + The total number of predicting batches during testing, which may change from epoch to epoch. + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the + predict dataloader is of infinite size. + """ + return sum(self.trainer.num_predict_batches) + + def disable(self): + """ + You should provide a way to disable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the + output on processes that have a rank different from 0, e.g., in multi-node training. + """ + raise NotImplementedError + + def enable(self): + """ + You should provide a way to enable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training + routines like the :ref:`learning rate finder ` + to temporarily enable and disable the main progress bar. + """ + raise NotImplementedError + + def print(self, *args, **kwargs): + """ + You should provide a way to print without breaking the progress bar. + """ + print(*args, **kwargs) + + def on_init_end(self, trainer): + self._trainer = trainer + + def on_train_start(self, trainer, pl_module): + self._train_batch_idx = trainer.fit_loop.batch_idx + + def on_train_epoch_start(self, trainer, pl_module): + self._train_batch_idx = 0 + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._train_batch_idx += 1 + + def on_validation_start(self, trainer, pl_module): + self._val_batch_idx = 0 + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._val_batch_idx += 1 + + def on_test_start(self, trainer, pl_module): + self._test_batch_idx = 0 + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._test_batch_idx += 1 + + def on_predict_epoch_start(self, trainer, pl_module): + self._predict_batch_idx = 0 + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self._predict_batch_idx += 1 diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py new file mode 100644 index 0000000000000..e9834cee81d7a --- /dev/null +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -0,0 +1,340 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib +import io +import math +import os +import sys +from typing import Optional, Union + +# check if ipywidgets is installed before importing tqdm.auto +# to ensure it won't fail and a progress bar is displayed +if importlib.util.find_spec("ipywidgets") is not None: + from tqdm.auto import tqdm as _tqdm +else: + from tqdm import tqdm as _tqdm + +from pytorch_lightning.callbacks.progress.base import ProgressBarBase + +_PAD_SIZE = 5 + + +class tqdm(_tqdm): + """ + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering + """ + + @staticmethod + def format_num(n) -> str: + """Add additional padding to the formatted numbers""" + should_be_padded = isinstance(n, (float, str)) + if not isinstance(n, str): + n = _tqdm.format_num(n) + if should_be_padded and "e" not in n: + if "." not in n and len(n) < _PAD_SIZE: + try: + _ = float(n) + except ValueError: + return n + n += "." + n += "0" * (_PAD_SIZE - len(n)) + return n + + +class ProgressBar(ProgressBarBase): + r""" + This is the default progress bar used by Lightning. It prints to `stdout` using the + :mod:`tqdm` package and shows up to four different bars: + + - **sanity check progress:** the progress during the sanity check run + - **main progress:** shows training + validation progress combined. It also accounts for + multiple validation runs during training when + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. + - **validation progress:** only visible during validation; + shows total progress over all validation datasets. + - **test progress:** only active when testing; shows total progress over all test datasets. + + For infinite datasets, the progress bar never ends. + + If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override + specific methods of the callback class and pass your custom implementation to the + :class:`~pytorch_lightning.trainer.trainer.Trainer`: + + Example:: + + class LitProgressBar(ProgressBar): + + def init_validation_tqdm(self): + bar = super().init_validation_tqdm() + bar.set_description('running validation ...') + return bar + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + Args: + refresh_rate: + Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the + :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress + bar and sets the refresh rate to the value provided to the + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. + process_position: + Set this to a value greater than ``0`` to offset the progress bars by this many lines. + This is useful when you have progress bars defined elsewhere and want to show all of them + together. This corresponds to + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. 
+ + """ + + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + super().__init__() + self._refresh_rate = refresh_rate + self._process_position = process_position + self._enabled = True + self.main_progress_bar = None + self.val_progress_bar = None + self.test_progress_bar = None + self.predict_progress_bar = None + + def __getstate__(self): + # can't pickle the tqdm objects + state = self.__dict__.copy() + state["main_progress_bar"] = None + state["val_progress_bar"] = None + state["test_progress_bar"] = None + state["predict_progress_bar"] = None + return state + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def process_position(self) -> int: + return self._process_position + + @property + def is_enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def is_disabled(self) -> bool: + return not self.is_enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def init_sanity_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for the validation sanity run.""" + bar = tqdm( + desc="Validation sanity check", + position=(2 * self.process_position), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_train_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for training.""" + bar = tqdm( + desc="Training", + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_predict_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for predicting.""" + bar = tqdm( + desc="Predicting", + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_validation_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for validation.""" + # The main progress bar doesn't exist in `trainer.validate()` + has_main_bar = self.main_progress_bar is not None + bar = tqdm( + desc="Validating", + position=(2 * self.process_position + has_main_bar), + disable=self.is_disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_test_tqdm(self) -> tqdm: + """Override this to customize the tqdm bar for testing.""" + bar = tqdm( + desc="Testing", + position=(2 * self.process_position), + disable=self.is_disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def on_sanity_check_start(self, trainer, pl_module): + super().on_sanity_check_start(trainer, pl_module) + self.val_progress_bar = self.init_sanity_tqdm() + self.main_progress_bar = tqdm(disable=True) # dummy progress bar + + def on_sanity_check_end(self, trainer, pl_module): + super().on_sanity_check_end(trainer, pl_module) + self.main_progress_bar.close() + self.val_progress_bar.close() + + def on_train_start(self, trainer, pl_module): + super().on_train_start(trainer, pl_module) + self.main_progress_bar = self.init_train_tqdm() + + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) + total_train_batches = self.total_train_batches + total_val_batches = self.total_val_batches + if total_train_batches != float("inf") and total_val_batches != float("inf"): + # val can be checked multiple times per epoch + 
val_checks_per_epoch = total_train_batches // trainer.val_check_batch + total_val_batches = total_val_batches * val_checks_per_epoch + total_batches = total_train_batches + total_val_batches + reset(self.main_progress_bar, total_batches) + self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + total_batches = self.total_train_batches + self.total_val_batches + total_batches = convert_inf(total_batches) + if self._should_update(self.train_batch_idx, total_batches): + self._update_bar(self.main_progress_bar) + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + + def on_validation_start(self, trainer, pl_module): + super().on_validation_start(trainer, pl_module) + if trainer.sanity_checking: + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) + else: + self._update_bar(self.main_progress_bar) # fill up remaining + self.val_progress_bar = self.init_validation_tqdm() + reset(self.val_progress_bar, self.total_val_batches) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): + self._update_bar(self.val_progress_bar) + self._update_bar(self.main_progress_bar) + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.val_progress_bar.close() + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + self.main_progress_bar.close() + + def on_test_start(self, trainer, pl_module): + super().on_test_start(trainer, pl_module) + self.test_progress_bar = self.init_test_tqdm() + self.test_progress_bar.total = convert_inf(self.total_test_batches) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.test_batch_idx, self.total_test_batches): + self._update_bar(self.test_progress_bar) + + def on_test_end(self, trainer, pl_module): + super().on_test_end(trainer, pl_module) + self.test_progress_bar.close() + + def on_predict_epoch_start(self, trainer, pl_module): + super().on_predict_epoch_start(trainer, pl_module) + self.predict_progress_bar = self.init_predict_tqdm() + self.predict_progress_bar.total = convert_inf(self.total_predict_batches) + + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.predict_batch_idx, self.total_predict_batches): + self._update_bar(self.predict_progress_bar) + + def on_predict_end(self, trainer, pl_module): + self.predict_progress_bar.close() + + def print( + self, *args, sep: str = " ", end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False + ): + active_progress_bar = None + + if self.main_progress_bar is not None and not self.main_progress_bar.disable: + active_progress_bar = self.main_progress_bar + elif self.val_progress_bar is not None and not self.val_progress_bar.disable: + 
active_progress_bar = self.val_progress_bar + elif self.test_progress_bar is not None and not self.test_progress_bar.disable: + active_progress_bar = self.test_progress_bar + elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable: + active_progress_bar = self.predict_progress_bar + + if active_progress_bar is not None: + s = sep.join(map(str, args)) + active_progress_bar.write(s, end=end, file=file, nolock=nolock) + + def _should_update(self, current, total) -> bool: + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + + def _update_bar(self, bar: Optional[tqdm]) -> None: + """Updates the bar by the refresh rate without overshooting.""" + if bar is None: + return + if bar.total is not None: + delta = min(self.refresh_rate, bar.total - bar.n) + else: + # infinite / unknown size + delta = self.refresh_rate + if delta > 0: + bar.update(delta) + + +def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: + """The tqdm doesn't support inf/nan values. We have to convert it to None.""" + if x is None or math.isinf(x) or math.isnan(x): + return None + return x + + +def reset(bar: tqdm, total: Optional[int] = None) -> None: + """Resets the tqdm bar to 0 progress with a new total, unless it is disabled.""" + if not bar.disable: + bar.reset(total=convert_inf(total)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py new file mode 100644 index 0000000000000..e7056b60a8bfd --- /dev/null +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -0,0 +1,117 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from pytorch_lightning.callbacks.progress.base import ProgressBarBase +from pytorch_lightning.utilities import _RICH_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _RICH_AVAILABLE: + from rich.console import Console + from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + from rich.text import Text + + +class MetricsTextColumn(TextColumn): + """A column containing text.""" + + def __init__(self, trainer): + self._trainer = trainer + super().__init__("") + + def render(self, task) -> Text: + _text = "" + if "red" in f"{task.description}": + for k, v in self._trainer.progress_bar_dict.items(): + _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " + if self.markup: + text = Text.from_markup(_text, style=self.style, justify=self.justify) + else: + text = Text(_text, style=self.style, justify=self.justify) + if self.highlighter: + self.highlighter.highlight(text) + return text + + +class RichProgressBar(ProgressBarBase): + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + if not _RICH_AVAILABLE: + raise MisconfigurationException("Rich progress bar is not available") + self._refresh_rate = refresh_rate + self._process_position = process_position + self._enabled = True + self.main_progress_bar = None + self.val_progress_bar = None + self.test_progress_bar = None + self.console = Console(record=True) + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def process_position(self) -> int: + return self._process_position + + @property + def is_enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def is_disabled(self) -> bool: + return not self.is_enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def setup(self, trainer, pl_module, stage): + print("hello") + self.progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + MetricsTextColumn(trainer), + console=self.console, + ).__enter__() + + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) + total_train_batches = self.total_train_batches + total_val_batches = self.total_val_batches + if total_train_batches != float("inf"): + # val can be checked multiple times per epoch + val_checks_per_epoch = total_train_batches // trainer.val_check_batch + total_val_batches = total_val_batches * val_checks_per_epoch + + total_batches = total_train_batches + total_val_batches + self.main_progress_bar = self.progress.add_task( + f"[red][Epoch {trainer.current_epoch}]", + total=total_batches, + ) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): + if getattr(self, "progress", None) is not None: + self.progress.update(self.main_progress_bar, advance=1.0) + self.progress.track(trainer.progress_bar_dict) + + def _should_update(self, current, total): + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + + def teardown(self, trainer, pl_module, stage): + self.progress.__exit__(None, None, None) diff --git a/pytorch_lightning/utilities/__init__.py 
b/pytorch_lightning/utilities/__init__.py index a151f1320a77b..97ec4245bf61b 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -43,6 +43,7 @@ _NATIVE_AMP_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE, + _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index f999847160256..103733417159d 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -90,6 +90,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0") _TORCHMETRICS_GREATER_EQUAL_0_3 = _compare_version("torchmetrics", operator.ge, "0.3.0") _XLA_AVAILABLE: bool = _module_available("torch_xla") +_RICH_AVAILABLE = _module_available("rich") from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 From 6be12edfc97400dec6746b307be4f18e3c6faeeb Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 15 Aug 2021 20:33:10 +0530 Subject: [PATCH 02/22] Add CustomTimeColumn --- .../callbacks/progress/rich_progress.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index e7056b60a8bfd..aeb0767902653 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from datetime import timedelta + from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException if _RICH_AVAILABLE: from rich.console import Console - from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn from rich.text import Text @@ -42,6 +44,19 @@ def render(self, task) -> Text: return text +class CustomTimeColumn(ProgressColumn): + + # Only refresh twice a second to prevent jitter + max_refresh = 0.5 + + def render(self, task) -> Text: + elapsed = task.finished_time if task.finished else task.elapsed + remaining = task.time_remaining + elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) + remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining))) + return Text.from_markup(f"[progress.elapsed]{elapsed_delta} / [progress.remaining]{remaining_delta}") + + class RichProgressBar(ProgressBarBase): def __init__(self, refresh_rate: int = 1, process_position: int = 0): if not _RICH_AVAILABLE: @@ -77,13 +92,12 @@ def enable(self) -> None: self._enabled = True def setup(self, trainer, pl_module, stage): - print("hello") self.progress = Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeElapsedColumn(), + CustomTimeColumn(), MetricsTextColumn(trainer), console=self.console, ).__enter__() @@ -97,10 +111,10 @@ def on_train_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch - total_batches = total_train_batches + total_val_batches + # total_batches = total_train_batches + total_val_batches self.main_progress_bar = self.progress.add_task( f"[red][Epoch {trainer.current_epoch}]", - total=total_batches, + total=total_train_batches, ) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): From fa4255662baa1de4f14599417d015d3d37312246 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 15 Aug 2021 21:26:31 +0530 Subject: [PATCH 03/22] Add BatchesProcessedColumn & ProcessingSpeedColumn --- .../callbacks/progress/rich_progress.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index aeb0767902653..4ee9157e8349e 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -54,7 +54,18 @@ def render(self, task) -> Text: remaining = task.time_remaining elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining))) - return Text.from_markup(f"[progress.elapsed]{elapsed_delta} / [progress.remaining]{remaining_delta}") + return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}") + + +class BatchesProcessedColumn(ProgressColumn): + def render(self, task) -> Text: + return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}") + + +class ProcessingSpeedColumn(ProgressColumn): + def render(self, task) -> Text: + task_speed = 
f"{task.speed:>.2f}" if task.speed is not None else "0.00" + return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") class RichProgressBar(ProgressBarBase): @@ -96,9 +107,12 @@ def setup(self, trainer, pl_module, stage): SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + BatchesProcessedColumn(), + "[", CustomTimeColumn(), + ProcessingSpeedColumn(), MetricsTextColumn(trainer), + "]", console=self.console, ).__enter__() From 86e4c514a675859bbf57e7cd1d5b9ca730bf4cb8 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 16 Aug 2021 10:31:37 +0530 Subject: [PATCH 04/22] Add support for Testing Bar --- .../callbacks/progress/rich_progress.py | 86 ++++++++++++------- 1 file changed, 56 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 4ee9157e8349e..369df27080ca3 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -23,27 +23,6 @@ from rich.text import Text -class MetricsTextColumn(TextColumn): - """A column containing text.""" - - def __init__(self, trainer): - self._trainer = trainer - super().__init__("") - - def render(self, task) -> Text: - _text = "" - if "red" in f"{task.description}": - for k, v in self._trainer.progress_bar_dict.items(): - _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " - if self.markup: - text = Text.from_markup(_text, style=self.style, justify=self.justify) - else: - text = Text(_text, style=self.style, justify=self.justify) - if self.highlighter: - self.highlighter.highlight(text) - return text - - class CustomTimeColumn(ProgressColumn): # Only refresh twice a second to prevent jitter @@ -68,12 +47,35 @@ def render(self, task) -> Text: return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") +class MetricsTextColumn(TextColumn): + """A column containing text.""" + + def __init__(self, trainer, stage): + self._trainer = trainer + self._stage = stage + super().__init__("") + + def render(self, task) -> Text: + _text = "" + if self._stage == "test": + return "" + if "red" in f"{task.description}" or "yellow" in f"{task.description}": + for k, v in self._trainer.progress_bar_dict.items(): + _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " + if self.markup: + text = Text.from_markup(_text, style=self.style, justify=self.justify) + else: + text = Text(_text, style=self.style, justify=self.justify) + if self.highlighter: + self.highlighter.highlight(text) + return text + + class RichProgressBar(ProgressBarBase): - def __init__(self, refresh_rate: int = 1, process_position: int = 0): + def __init__(self, refresh_rate: int = 1): if not _RICH_AVAILABLE: raise MisconfigurationException("Rich progress bar is not available") self._refresh_rate = refresh_rate - self._process_position = process_position self._enabled = True self.main_progress_bar = None self.val_progress_bar = None @@ -84,10 +86,6 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): def refresh_rate(self) -> int: return self._refresh_rate - @property - def process_position(self) -> int: - return self._process_position - @property def is_enabled(self) -> bool: return self._enabled and self.refresh_rate > 0 @@ -111,7 +109,7 @@ def setup(self, trainer, pl_module, stage): "[", CustomTimeColumn(), ProcessingSpeedColumn(), - MetricsTextColumn(trainer), + 
MetricsTextColumn(trainer, stage), "]", console=self.console, ).__enter__() @@ -125,10 +123,23 @@ def on_train_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch - # total_batches = total_train_batches + total_val_batches + total_batches = total_train_batches + total_val_batches self.main_progress_bar = self.progress.add_task( f"[red][Epoch {trainer.current_epoch}]", - total=total_train_batches, + total=total_batches, + ) + if total_val_batches > 0: + self.val_progress_bar = self.progress.add_task( + "[yellow][Validation]", + total=total_val_batches, + ) + + def on_test_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) + total_test_batches = self.total_test_batches + self.test_progress_bar = self.progress.add_task( + "[red][Testing]", + total=total_test_batches, ) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -138,6 +149,21 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data self.progress.update(self.main_progress_bar, advance=1.0) self.progress.track(trainer.progress_bar_dict) + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self.val_progress_bar and self._should_update( + self.val_batch_idx, self.total_train_batches + self.total_val_batches + ): + if getattr(self, "progress", None) is not None: + # self.progress.update(self.main_progress_bar, advance=1.) + self.progress.update(self.val_progress_bar, advance=1.0) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.test_batch_idx, self.total_test_batches): + if getattr(self, "progress", None) is not None: + self.progress.update(self.test_progress_bar, advance=1.0) + def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) From c5a5824cda4f5559732ec4f74243422d83ed5b33 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 16 Aug 2021 13:09:07 +0530 Subject: [PATCH 05/22] Add updates for validation bar --- .../callbacks/progress/rich_progress.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 369df27080ca3..1c1f495d13efa 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -21,6 +21,8 @@ from rich.console import Console from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn from rich.text import Text +else: + ProgressColumn, TextColumn = None, None class CustomTimeColumn(ProgressColumn): @@ -77,6 +79,7 @@ def __init__(self, refresh_rate: int = 1): raise MisconfigurationException("Rich progress bar is not available") self._refresh_rate = refresh_rate self._enabled = True + self._total_val_batches = 0 self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None @@ -112,28 +115,37 @@ def setup(self, trainer, pl_module, stage): MetricsTextColumn(trainer, stage), "]", console=self.console, + refresh_per_second=self.refresh_rate, ).__enter__() def 
on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches + self._total_val_batches = self.total_val_batches if total_train_batches != float("inf"): # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch - total_val_batches = total_val_batches * val_checks_per_epoch + self._total_val_batches = self._total_val_batches * val_checks_per_epoch - total_batches = total_train_batches + total_val_batches + total_batches = total_train_batches + self._total_val_batches self.main_progress_bar = self.progress.add_task( f"[red][Epoch {trainer.current_epoch}]", total=total_batches, ) - if total_val_batches > 0: + + def on_validation_epoch_start(self, trainer, pl_module): + super().on_validation_epoch_start(trainer, pl_module) + if self._total_val_batches > 0: self.val_progress_bar = self.progress.add_task( "[yellow][Validation]", - total=total_val_batches, + total=self._total_val_batches, ) + def on_validation_epoch_end(self, trainer, pl_module): + super().on_validation_epoch_end(trainer, pl_module) + if self.val_progress_bar is not None: + self.progress.update(self.val_progress_bar, visible=False) + def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) total_test_batches = self.total_test_batches @@ -147,7 +159,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): if getattr(self, "progress", None) is not None: self.progress.update(self.main_progress_bar, advance=1.0) - self.progress.track(trainer.progress_bar_dict) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -155,7 +166,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.val_batch_idx, self.total_train_batches + self.total_val_batches ): if getattr(self, "progress", None) is not None: - # self.progress.update(self.main_progress_bar, advance=1.) + self.progress.update(self.main_progress_bar, advance=1.0) self.progress.update(self.val_progress_bar, advance=1.0) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): From 93a6e8b65297818c74492cb04f53bec75d1e2f7d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 16 Aug 2021 13:25:52 +0530 Subject: [PATCH 06/22] Class refactor --- pytorch_lightning/callbacks/progress.py | 514 ------------------ .../callbacks/progress/rich_progress.py | 12 +- 2 files changed, 5 insertions(+), 521 deletions(-) delete mode 100644 pytorch_lightning/callbacks/progress.py diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py deleted file mode 100644 index 63c2485321a80..0000000000000 --- a/pytorch_lightning/callbacks/progress.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Progress Bars -============= - -Use or override one of the progress bar callbacks. - -""" -import importlib -import io -import math -import os -import sys - -# check if ipywidgets is installed before importing tqdm.auto -# to ensure it won't fail and a progress bar is displayed -from typing import Optional, Union - -if importlib.util.find_spec("ipywidgets") is not None: - from tqdm.auto import tqdm as _tqdm -else: - from tqdm import tqdm as _tqdm - -from pytorch_lightning.callbacks import Callback - -_PAD_SIZE = 5 - - -class tqdm(_tqdm): - """ - Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering - """ - - @staticmethod - def format_num(n) -> str: - """Add additional padding to the formatted numbers""" - should_be_padded = isinstance(n, (float, str)) - if not isinstance(n, str): - n = _tqdm.format_num(n) - if should_be_padded and "e" not in n: - if "." not in n and len(n) < _PAD_SIZE: - try: - _ = float(n) - except ValueError: - return n - n += "." - n += "0" * (_PAD_SIZE - len(n)) - return n - - -class ProgressBarBase(Callback): - r""" - The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback` - that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. - You should implement your highly custom progress bars with this as the base class. - - Example:: - - class LitProgressBar(ProgressBarBase): - - def __init__(self): - super().__init__() # important :-) - self.enable = True - - def disable(self): - self.enable = False - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important - percent = (self.train_batch_idx / self.total_train_batches) * 100 - sys.stdout.flush() - sys.stdout.write(f'{percent:.01f} percent complete \r') - - bar = LitProgressBar() - trainer = Trainer(callbacks=[bar]) - """ - - def __init__(self): - - self._trainer = None - self._train_batch_idx = 0 - self._val_batch_idx = 0 - self._test_batch_idx = 0 - self._predict_batch_idx = 0 - - @property - def trainer(self): - return self._trainer - - @property - def train_batch_idx(self) -> int: - """ - The current batch index being processed during training. - Use this to update your progress bar. - """ - return self._train_batch_idx - - @property - def val_batch_idx(self) -> int: - """ - The current batch index being processed during validation. - Use this to update your progress bar. - """ - return self._val_batch_idx - - @property - def test_batch_idx(self) -> int: - """ - The current batch index being processed during testing. - Use this to update your progress bar. - """ - return self._test_batch_idx - - @property - def predict_batch_idx(self) -> int: - """ - The current batch index being processed during predicting. - Use this to update your progress bar. 
- """ - return self._predict_batch_idx - - @property - def total_train_batches(self) -> int: - """ - The total number of training batches during training, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - training dataloader is of infinite size. - """ - return self.trainer.num_training_batches - - @property - def total_val_batches(self) -> int: - """ - The total number of validation batches during validation, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - validation dataloader is of infinite size. - """ - total_val_batches = 0 - if self.trainer.enable_validation: - is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 - - return total_val_batches - - @property - def total_test_batches(self) -> int: - """ - The total number of testing batches during testing, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - test dataloader is of infinite size. - """ - return sum(self.trainer.num_test_batches) - - @property - def total_predict_batches(self) -> int: - """ - The total number of predicting batches during testing, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - predict dataloader is of infinite size. - """ - return sum(self.trainer.num_predict_batches) - - def disable(self): - """ - You should provide a way to disable the progress bar. - The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the - output on processes that have a rank different from 0, e.g., in multi-node training. - """ - raise NotImplementedError - - def enable(self): - """ - You should provide a way to enable the progress bar. - The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training - routines like the :ref:`learning rate finder ` - to temporarily enable and disable the main progress bar. - """ - raise NotImplementedError - - def print(self, *args, **kwargs): - """ - You should provide a way to print without breaking the progress bar. - """ - print(*args, **kwargs) - - def on_init_end(self, trainer): - self._trainer = trainer - - def on_train_start(self, trainer, pl_module): - self._train_batch_idx = trainer.fit_loop.batch_idx - - def on_train_epoch_start(self, trainer, pl_module): - self._train_batch_idx = 0 - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._train_batch_idx += 1 - - def on_validation_start(self, trainer, pl_module): - self._val_batch_idx = 0 - - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._val_batch_idx += 1 - - def on_test_start(self, trainer, pl_module): - self._test_batch_idx = 0 - - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._test_batch_idx += 1 - - def on_predict_epoch_start(self, trainer, pl_module): - self._predict_batch_idx = 0 - - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._predict_batch_idx += 1 - - -class ProgressBar(ProgressBarBase): - r""" - This is the default progress bar used by Lightning. 
It prints to `stdout` using the - :mod:`tqdm` package and shows up to four different bars: - - - **sanity check progress:** the progress during the sanity check run - - **main progress:** shows training + validation progress combined. It also accounts for - multiple validation runs during training when - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. - - **validation progress:** only visible during validation; - shows total progress over all validation datasets. - - **test progress:** only active when testing; shows total progress over all test datasets. - - For infinite datasets, the progress bar never ends. - - If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override - specific methods of the callback class and pass your custom implementation to the - :class:`~pytorch_lightning.trainer.trainer.Trainer`: - - Example:: - - class LitProgressBar(ProgressBar): - - def init_validation_tqdm(self): - bar = super().init_validation_tqdm() - bar.set_description('running validation ...') - return bar - - bar = LitProgressBar() - trainer = Trainer(callbacks=[bar]) - - Args: - refresh_rate: - Determines at which rate (in number of batches) the progress bars get updated. - Set it to ``0`` to disable the display. By default, the - :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress - bar and sets the refresh rate to the value provided to the - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the - :class:`~pytorch_lightning.trainer.trainer.Trainer`. - process_position: - Set this to a value greater than ``0`` to offset the progress bars by this many lines. - This is useful when you have progress bars defined elsewhere and want to show all of them - together. This corresponds to - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the - :class:`~pytorch_lightning.trainer.trainer.Trainer`. 
- - """ - - def __init__(self, refresh_rate: int = 1, process_position: int = 0): - super().__init__() - self._refresh_rate = refresh_rate - self._process_position = process_position - self._enabled = True - self.main_progress_bar = None - self.val_progress_bar = None - self.test_progress_bar = None - self.predict_progress_bar = None - - def __getstate__(self): - # can't pickle the tqdm objects - state = self.__dict__.copy() - state["main_progress_bar"] = None - state["val_progress_bar"] = None - state["test_progress_bar"] = None - state["predict_progress_bar"] = None - return state - - @property - def refresh_rate(self) -> int: - return self._refresh_rate - - @property - def process_position(self) -> int: - return self._process_position - - @property - def is_enabled(self) -> bool: - return self._enabled and self.refresh_rate > 0 - - @property - def is_disabled(self) -> bool: - return not self.is_enabled - - def disable(self) -> None: - self._enabled = False - - def enable(self) -> None: - self._enabled = True - - def init_sanity_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for the validation sanity run.""" - bar = tqdm( - desc="Validation sanity check", - position=(2 * self.process_position), - disable=self.is_disabled, - leave=False, - dynamic_ncols=True, - file=sys.stdout, - ) - return bar - - def init_train_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for training.""" - bar = tqdm( - desc="Training", - initial=self.train_batch_idx, - position=(2 * self.process_position), - disable=self.is_disabled, - leave=True, - dynamic_ncols=True, - file=sys.stdout, - smoothing=0, - ) - return bar - - def init_predict_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for predicting.""" - bar = tqdm( - desc="Predicting", - initial=self.train_batch_idx, - position=(2 * self.process_position), - disable=self.is_disabled, - leave=True, - dynamic_ncols=True, - file=sys.stdout, - smoothing=0, - ) - return bar - - def init_validation_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for validation.""" - # The main progress bar doesn't exist in `trainer.validate()` - has_main_bar = self.main_progress_bar is not None - bar = tqdm( - desc="Validating", - position=(2 * self.process_position + has_main_bar), - disable=self.is_disabled, - leave=False, - dynamic_ncols=True, - file=sys.stdout, - ) - return bar - - def init_test_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for testing.""" - bar = tqdm( - desc="Testing", - position=(2 * self.process_position), - disable=self.is_disabled, - leave=True, - dynamic_ncols=True, - file=sys.stdout, - ) - return bar - - def on_sanity_check_start(self, trainer, pl_module): - super().on_sanity_check_start(trainer, pl_module) - self.val_progress_bar = self.init_sanity_tqdm() - self.main_progress_bar = tqdm(disable=True) # dummy progress bar - - def on_sanity_check_end(self, trainer, pl_module): - super().on_sanity_check_end(trainer, pl_module) - self.main_progress_bar.close() - self.val_progress_bar.close() - - def on_train_start(self, trainer, pl_module): - super().on_train_start(trainer, pl_module) - self.main_progress_bar = self.init_train_tqdm() - - def on_train_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) - total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches - if total_train_batches != float("inf") and total_val_batches != float("inf"): - # val can be checked multiple times per epoch - 
val_checks_per_epoch = total_train_batches // trainer.val_check_batch - total_val_batches = total_val_batches * val_checks_per_epoch - total_batches = total_train_batches + total_val_batches - reset(self.main_progress_bar, total_batches) - self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - total_batches = self.total_train_batches + self.total_val_batches - total_batches = convert_inf(total_batches) - if self._should_update(self.train_batch_idx, total_batches): - self._update_bar(self.main_progress_bar) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) - - def on_validation_start(self, trainer, pl_module): - super().on_validation_start(trainer, pl_module) - if trainer.sanity_checking: - reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) - else: - self._update_bar(self.main_progress_bar) # fill up remaining - self.val_progress_bar = self.init_validation_tqdm() - reset(self.val_progress_bar, self.total_val_batches) - - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)): - self._update_bar(self.val_progress_bar) - self._update_bar(self.main_progress_bar) - - def on_validation_end(self, trainer, pl_module): - super().on_validation_end(trainer, pl_module) - if self.main_progress_bar is not None: - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) - self.val_progress_bar.close() - - def on_train_end(self, trainer, pl_module): - super().on_train_end(trainer, pl_module) - self.main_progress_bar.close() - - def on_test_start(self, trainer, pl_module): - super().on_test_start(trainer, pl_module) - self.test_progress_bar = self.init_test_tqdm() - self.test_progress_bar.total = convert_inf(self.total_test_batches) - - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.test_batch_idx, self.total_test_batches): - self._update_bar(self.test_progress_bar) - - def on_test_end(self, trainer, pl_module): - super().on_test_end(trainer, pl_module) - self.test_progress_bar.close() - - def on_predict_epoch_start(self, trainer, pl_module): - super().on_predict_epoch_start(trainer, pl_module) - self.predict_progress_bar = self.init_predict_tqdm() - self.predict_progress_bar.total = convert_inf(self.total_predict_batches) - - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.predict_batch_idx, self.total_predict_batches): - self._update_bar(self.predict_progress_bar) - - def on_predict_end(self, trainer, pl_module): - self.predict_progress_bar.close() - - def print( - self, *args, sep: str = " ", end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False - ): - active_progress_bar = None - - if self.main_progress_bar is not None and not self.main_progress_bar.disable: - active_progress_bar = self.main_progress_bar - elif self.val_progress_bar is not None and not self.val_progress_bar.disable: - 
active_progress_bar = self.val_progress_bar - elif self.test_progress_bar is not None and not self.test_progress_bar.disable: - active_progress_bar = self.test_progress_bar - elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable: - active_progress_bar = self.predict_progress_bar - - if active_progress_bar is not None: - s = sep.join(map(str, args)) - active_progress_bar.write(s, end=end, file=file, nolock=nolock) - - def _should_update(self, current, total) -> bool: - return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - - def _update_bar(self, bar: Optional[tqdm]) -> None: - """Updates the bar by the refresh rate without overshooting.""" - if bar is None: - return - if bar.total is not None: - delta = min(self.refresh_rate, bar.total - bar.n) - else: - # infinite / unknown size - delta = self.refresh_rate - if delta > 0: - bar.update(delta) - - -def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: - """The tqdm doesn't support inf/nan values. We have to convert it to None.""" - if x is None or math.isinf(x) or math.isnan(x): - return None - return x - - -def reset(bar: tqdm, total: Optional[int] = None) -> None: - """Resets the tqdm bar to 0 progress with a new total, unless it is disabled.""" - if not bar.disable: - bar.reset(total=convert_inf(total)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1c1f495d13efa..c41d02f765e56 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -77,6 +77,7 @@ class RichProgressBar(ProgressBarBase): def __init__(self, refresh_rate: int = 1): if not _RICH_AVAILABLE: raise MisconfigurationException("Rich progress bar is not available") + super().__init__() self._refresh_rate = refresh_rate self._enabled = True self._total_val_batches = 0 @@ -157,23 +158,20 @@ def on_test_epoch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): - if getattr(self, "progress", None) is not None: - self.progress.update(self.main_progress_bar, advance=1.0) + self.progress.update(self.main_progress_bar, advance=1.0) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self.val_progress_bar and self._should_update( self.val_batch_idx, self.total_train_batches + self.total_val_batches ): - if getattr(self, "progress", None) is not None: - self.progress.update(self.main_progress_bar, advance=1.0) - self.progress.update(self.val_progress_bar, advance=1.0) + self.progress.update(self.main_progress_bar, advance=1.0) + self.progress.update(self.val_progress_bar, advance=1.0) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.test_batch_idx, self.total_test_batches): - if getattr(self, "progress", None) is not None: - self.progress.update(self.test_progress_bar, advance=1.0) + self.progress.update(self.test_progress_bar, advance=1.0) def _should_update(self, current, total): return 
self.is_enabled and (current % self.refresh_rate == 0 or current == total) From 4794814575d8d2c39f5030acb678610f9b7ad869 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 16 Aug 2021 15:08:25 +0530 Subject: [PATCH 07/22] Update --- .../callbacks/progress/rich_progress.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index c41d02f765e56..89045abf9ff0d 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -22,7 +22,7 @@ from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn from rich.text import Text else: - ProgressColumn, TextColumn = None, None + ProgressColumn, TextColumn, Text = None, None, None class CustomTimeColumn(ProgressColumn): @@ -49,13 +49,13 @@ def render(self, task) -> Text: return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") -class MetricsTextColumn(TextColumn): +class MetricsTextColumn(ProgressColumn): """A column containing text.""" def __init__(self, trainer, stage): self._trainer = trainer self._stage = stage - super().__init__("") + super().__init__() def render(self, task) -> Text: _text = "" @@ -64,12 +64,7 @@ def render(self, task) -> Text: if "red" in f"{task.description}" or "yellow" in f"{task.description}": for k, v in self._trainer.progress_bar_dict.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " - if self.markup: - text = Text.from_markup(_text, style=self.style, justify=self.justify) - else: - text = Text(_text, style=self.style, justify=self.justify) - if self.highlighter: - self.highlighter.highlight(text) + text = Text.from_markup(_text, style=None, justify="left") return text From 1abb6ed6b11fac756d74f738bd97ac51d0fe547f Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 10:48:27 +0530 Subject: [PATCH 08/22] Add support for display per epoch --- pytorch_lightning/callbacks/progress/rich_progress.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 89045abf9ff0d..8e85b2c28b1a3 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -55,13 +55,22 @@ class MetricsTextColumn(ProgressColumn): def __init__(self, trainer, stage): self._trainer = trainer self._stage = stage + self._tasks = {} + self._current_task_id = 0 super().__init__() def render(self, task) -> Text: + if "red" in task.description and task.id not in self._tasks: + self._tasks[task.id] = "None" + if self._renderable_cache: + self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1] + self._current_task_id = task.id + if "red" in task.description and task.id != self._current_task_id: + return self._tasks[task.id] _text = "" if self._stage == "test": return "" - if "red" in f"{task.description}" or "yellow" in f"{task.description}": + if "red" in task.description or "yellow" in task.description: for k, v in self._trainer.progress_bar_dict.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " text = Text.from_markup(_text, style=None, justify="left") From ca339767a367e7b038793ba8f4b3b08c7bb14814 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 15:48:58 +0530 Subject: [PATCH 09/22] Update Sanity & predict bar --- 
.../callbacks/progress/rich_progress.py | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 8e85b2c28b1a3..de33d34e37e06 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -60,6 +60,8 @@ def __init__(self, trainer, stage): super().__init__() def render(self, task) -> Text: + if self._stage == "test" or self._trainer.sanity_checking: + return "" if "red" in task.description and task.id not in self._tasks: self._tasks[task.id] = "None" if self._renderable_cache: @@ -68,8 +70,6 @@ def render(self, task) -> Text: if "red" in task.description and task.id != self._current_task_id: return self._tasks[task.id] _text = "" - if self._stage == "test": - return "" if "red" in task.description or "yellow" in task.description: for k, v in self._trainer.progress_bar_dict.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " @@ -123,6 +123,17 @@ def setup(self, trainer, pl_module, stage): refresh_per_second=self.refresh_rate, ).__enter__() + def on_sanity_check_start(self, trainer, pl_module): + super().on_sanity_check_start(trainer, pl_module) + self.val_sanity_progress_bar = self.progress.add_task( + "[yellow][Validation Sanity Check]", + total=trainer.num_sanity_val_steps, + ) + + def on_sanity_check_end(self, trainer, pl_module): + super().on_sanity_check_end(trainer, pl_module) + self.progress.update(self.val_sanity_progress_bar, visible=False) + def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches @@ -153,10 +164,16 @@ def on_validation_epoch_end(self, trainer, pl_module): def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) - total_test_batches = self.total_test_batches self.test_progress_bar = self.progress.add_task( - "[red][Testing]", - total=total_test_batches, + "[yellow][Testing]", + total=self.total_test_batches, + ) + + def on_predict_epoch_start(self, trainer, pl_module): + super().on_predict_epoch_start(trainer, pl_module) + self.predict_progress_bar = self.progress.add_task( + "[red][Predicting]", + total=self.total_predict_batches, ) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -166,7 +183,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self.val_progress_bar and self._should_update( + if trainer.sanity_checking: + self.progress.update(self.val_sanity_progress_bar, advance=1.0) + elif self.val_progress_bar and self._should_update( self.val_batch_idx, self.total_train_batches + self.total_val_batches ): self.progress.update(self.main_progress_bar, advance=1.0) @@ -177,6 +196,11 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal if self._should_update(self.test_batch_idx, self.total_test_batches): self.progress.update(self.test_progress_bar, advance=1.0) + def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self._should_update(self.predict_batch_idx, 
self.total_predict_batches): + self.progress.update(self.predict_progress_bar, advance=1.0) + def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) From 251cc4626cb24ce1b8d003e08ad8c16495efaf67 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 20:20:11 +0530 Subject: [PATCH 10/22] Add rich for Model Summary --- .../callbacks/progress/rich_progress.py | 4 +- pytorch_lightning/trainer/trainer.py | 5 +- pytorch_lightning/utilities/model_summary.py | 79 ++++++++++++++++--- 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index de33d34e37e06..99d0c0466a9a0 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -80,7 +80,9 @@ def render(self, task) -> Text: class RichProgressBar(ProgressBarBase): def __init__(self, refresh_rate: int = 1): if not _RICH_AVAILABLE: - raise MisconfigurationException("Rich progress bar is not available") + raise MisconfigurationException( + "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." + ) super().__init__() self._refresh_rate = refresh_rate self._enabled = True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 887cdd46a9db2..01e44bfb50439 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator, IPUAccelerator -from pytorch_lightning.callbacks import Callback +from pytorch_lightning.callbacks import Callback, RichProgressBar from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop @@ -1111,8 +1111,9 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: + use_rich = isinstance(self.progress_bar_callback, RichProgressBar) max_depth = ModelSummary.MODES[self.weights_summary] - summarize(ref_model, max_depth=max_depth) + summarize(ref_model, max_depth=max_depth, use_rich=use_rich) # on pretrain routine end self.on_pretrain_routine_end() diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index ff00868315216..5f9815bcd718f 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -25,9 +25,13 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.warnings import WarningCache +if _RICH_AVAILABLE: + from rich.console import Console + from rich.table import Table + log = logging.getLogger(__name__) warning_cache = WarningCache() @@ -299,12 +303,7 @@ def _forward_example_input(self) -> None: model(input_) model.train(mode) # restore mode of module - def __str__(self): - """ - Makes a summary listing with: - - Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model 
Size - """ + def _get_summary_data(self): arrays = [ [" ", list(map(str, range(len(self._layer_summary))))], ["Name", self.layer_names], @@ -314,6 +313,62 @@ def __str__(self): if self._model.example_input_array is not None: arrays.append(["In sizes", self.in_sizes]) arrays.append(["Out sizes", self.out_sizes]) + + return arrays + + def print_rich_summary(self): + + if not _RICH_AVAILABLE: + raise MisconfigurationException( + "`print_rich_summary` requires `rich` to be installed." " Install it by running `pip install rich`." + ) + + arrays = self._get_summary_data() + total_parameters = self.total_parameters + trainable_parameters = self.trainable_parameters + model_size = self.model_size + + console = Console() + + table = Table(title="Model Summary") + + table.add_column(" ") + table.add_column("Name", arrays[1][1], justify="left", style="cyan", no_wrap=True) + table.add_column("Type", arrays[2][1], style="magenta") + table.add_column("Params", arrays[3][1], justify="right", style="green") + + rows = list(zip(*(arr[1] for arr in arrays))) + for row in rows: + table.add_row(*row) + + console.print(table) + + # Formatting + s = "{:<{}}" + + parameters = [] + for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: + parameters.append(s.format(get_human_readable_count(param), 10)) + + grid = Table.grid(expand=True) + grid.add_column() + grid.add_column() + + grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}") + grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") + grid.add_row(f"[bold]Total params[/]: {parameters[2]}") + grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") + + console.print(grid) + + def __str__(self): + """ + Makes a summary listing with: + + Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size + """ + arrays = self._get_summary_data() + total_parameters = self.total_parameters trainable_parameters = self.trainable_parameters model_size = self.model_size @@ -435,7 +490,10 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool: def summarize( - lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None + lightning_module: "pl.LightningModule", + mode: Optional[str] = "top", + max_depth: Optional[int] = None, + use_rich: bool = False, ) -> Optional[ModelSummary]: """ Summarize the LightningModule specified by `lightning_module`. 
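For context, `print_rich_summary` above leans on rich's `Table` API. A minimal standalone sketch of that API follows (illustrative only — the layer rows are made-up placeholders and this snippet is not part of the patch):

from rich.console import Console
from rich.table import Table

console = Console()

# Bordered table mirroring the header layout built in `print_rich_summary`.
table = Table(title="Model Summary")
table.add_column(" ")
table.add_column("Name", justify="left", style="cyan", no_wrap=True)
table.add_column("Type", style="magenta")
table.add_column("Params", justify="right", style="green")
for row in [("0", "layer", "Linear", "66"), ("1", "net", "Sequential", "132")]:
    table.add_row(*row)
console.print(table)

# Totals render as a borderless grid, matching the `Table.grid` usage above.
grid = Table.grid(expand=True)
grid.add_column()
grid.add_row("[bold]Trainable params[/]: 132")
grid.add_row("[bold]Total params[/]: 132")
console.print(grid)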
@@ -467,5 +525,8 @@ def summarize( raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") else: model_summary = ModelSummary(lightning_module, max_depth=max_depth) - log.info("\n" + str(model_summary)) + if use_rich: + model_summary.print_rich_summary() + else: + log.info("\n" + str(model_summary)) return model_summary From 539ba084c694517afe69b85596fdec8809982b55 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 20:36:12 +0530 Subject: [PATCH 11/22] Update model summary --- .../callbacks/progress/rich_progress.py | 100 ++++++++---------- pytorch_lightning/utilities/model_summary.py | 10 +- 2 files changed, 54 insertions(+), 56 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 99d0c0466a9a0..60bb4192e0cb8 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -21,60 +21,54 @@ from rich.console import Console from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn from rich.text import Text -else: - ProgressColumn, TextColumn, Text = None, None, None - -class CustomTimeColumn(ProgressColumn): - - # Only refresh twice a second to prevent jitter - max_refresh = 0.5 - - def render(self, task) -> Text: - elapsed = task.finished_time if task.finished else task.elapsed - remaining = task.time_remaining - elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) - remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining))) - return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}") - - -class BatchesProcessedColumn(ProgressColumn): - def render(self, task) -> Text: - return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}") - - -class ProcessingSpeedColumn(ProgressColumn): - def render(self, task) -> Text: - task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00" - return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") - - -class MetricsTextColumn(ProgressColumn): - """A column containing text.""" - - def __init__(self, trainer, stage): - self._trainer = trainer - self._stage = stage - self._tasks = {} - self._current_task_id = 0 - super().__init__() - - def render(self, task) -> Text: - if self._stage == "test" or self._trainer.sanity_checking: - return "" - if "red" in task.description and task.id not in self._tasks: - self._tasks[task.id] = "None" - if self._renderable_cache: - self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1] - self._current_task_id = task.id - if "red" in task.description and task.id != self._current_task_id: - return self._tasks[task.id] - _text = "" - if "red" in task.description or "yellow" in task.description: - for k, v in self._trainer.progress_bar_dict.items(): - _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " - text = Text.from_markup(_text, style=None, justify="left") - return text + class CustomTimeColumn(ProgressColumn): + + # Only refresh twice a second to prevent jitter + max_refresh = 0.5 + + def render(self, task) -> Text: + elapsed = task.finished_time if task.finished else task.elapsed + remaining = task.time_remaining + elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) + remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining))) + return 
Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}") + + class BatchesProcessedColumn(ProgressColumn): + def render(self, task) -> Text: + return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}") + + class ProcessingSpeedColumn(ProgressColumn): + def render(self, task) -> Text: + task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00" + return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") + + class MetricsTextColumn(ProgressColumn): + """A column containing text.""" + + def __init__(self, trainer, stage): + self._trainer = trainer + self._stage = stage + self._tasks = {} + self._current_task_id = 0 + super().__init__() + + def render(self, task) -> Text: + if self._stage == "test" or self._trainer.sanity_checking: + return "" + if "red" in task.description and task.id not in self._tasks: + self._tasks[task.id] = "None" + if self._renderable_cache: + self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1] + self._current_task_id = task.id + if "red" in task.description and task.id != self._current_task_id: + return self._tasks[task.id] + _text = "" + if "red" in task.description or "yellow" in task.description: + for k, v in self._trainer.progress_bar_dict.items(): + _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " + text = Text.from_markup(_text, style=None, justify="left") + return text class RichProgressBar(ProgressBarBase): diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index 5f9815bcd718f..1592d74f228e4 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -333,9 +333,13 @@ def print_rich_summary(self): table = Table(title="Model Summary") table.add_column(" ") - table.add_column("Name", arrays[1][1], justify="left", style="cyan", no_wrap=True) - table.add_column("Type", arrays[2][1], style="magenta") - table.add_column("Params", arrays[3][1], justify="right", style="green") + table.add_column("Name", justify="left", style="cyan", no_wrap=True) + table.add_column("Type", style="magenta") + table.add_column("Params", justify="right", style="green") + + if self._model.example_input_array is not None: + table.add_column("In sizes", justify="right", style="green") + table.add_column("Out sizes", justify="right", style="green") rows = list(zip(*(arr[1] for arr in arrays))) for row in rows: From cbbf3bf3cf3deddb8197d4fc643e23e73d6fd6ac Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 21:15:23 +0530 Subject: [PATCH 12/22] Update Styles --- .../callbacks/progress/rich_progress.py | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 60bb4192e0cb8..5590a81be80e0 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from datetime import timedelta +from typing import Dict from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException if _RICH_AVAILABLE: - from rich.console import Console + from rich.console import Console, RenderableType from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn from rich.text import Text @@ -35,11 +36,11 @@ def render(self, task) -> Text: return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}") class BatchesProcessedColumn(ProgressColumn): - def render(self, task) -> Text: + def render(self, task) -> RenderableType: return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}") class ProcessingSpeedColumn(ProgressColumn): - def render(self, task) -> Text: + def render(self, task) -> RenderableType: task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00" return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") @@ -54,23 +55,31 @@ def __init__(self, trainer, stage): super().__init__() def render(self, task) -> Text: - if self._stage == "test" or self._trainer.sanity_checking: + if self._trainer.sanity_checking: return "" - if "red" in task.description and task.id not in self._tasks: + if self._trainer.training and task.id not in self._tasks: self._tasks[task.id] = "None" if self._renderable_cache: self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1] self._current_task_id = task.id - if "red" in task.description and task.id != self._current_task_id: + if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] _text = "" - if "red" in task.description or "yellow" in task.description: - for k, v in self._trainer.progress_bar_dict.items(): - _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " + for k, v in self._trainer.progress_bar_dict.items(): + _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " text = Text.from_markup(_text, style=None, justify="left") return text +STYLES: Dict[str, str] = { + "train": "red", + "sanity_check": "yellow", + "validate": "yellow", + "test": "yellow", + "predict": "yellow", +} + + class RichProgressBar(ProgressBarBase): def __init__(self, refresh_rate: int = 1): if not _RICH_AVAILABLE: @@ -81,9 +90,12 @@ def __init__(self, refresh_rate: int = 1): self._refresh_rate = refresh_rate self._enabled = True self._total_val_batches = 0 + self.progress = None + self.val_sanity_progress_bar = None self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None + self.predict_progress_bar = None self.console = Console(record=True) @property @@ -122,7 +134,7 @@ def setup(self, trainer, pl_module, stage): def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_sanity_progress_bar = self.progress.add_task( - "[yellow][Validation Sanity Check]", + f"[{STYLES['sanity_check']}][Validation Sanity Check]", total=trainer.num_sanity_val_steps, ) @@ -141,7 +153,7 @@ def on_train_epoch_start(self, trainer, pl_module): total_batches = total_train_batches + self._total_val_batches self.main_progress_bar = self.progress.add_task( - f"[red][Epoch {trainer.current_epoch}]", + f"[{STYLES['train']}][Epoch {trainer.current_epoch}]", total=total_batches, ) @@ -149,7 +161,7 @@ def on_validation_epoch_start(self, trainer, pl_module): 
super().on_validation_epoch_start(trainer, pl_module) if self._total_val_batches > 0: self.val_progress_bar = self.progress.add_task( - "[yellow][Validation]", + f"[{STYLES['validate']}][Validation]", total=self._total_val_batches, ) @@ -161,14 +173,14 @@ def on_validation_epoch_end(self, trainer, pl_module): def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar = self.progress.add_task( - "[yellow][Testing]", + f"[{STYLES['test']}][Testing]", total=self.total_test_batches, ) def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar = self.progress.add_task( - "[red][Predicting]", + f"[{STYLES['predict']}][Predicting]", total=self.total_predict_batches, ) From a94aff34f2e11638d2c219ae156fadd602aa6ab7 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 22:07:06 +0530 Subject: [PATCH 13/22] Add tests --- CHANGELOG.md | 3 ++ tests/callbacks/test_rich_progress_bar.py | 63 +++++++++++++++++++++++ tests/helpers/runif.py | 7 +++ 3 files changed, 73 insertions(+) create mode 100644 tests/callbacks/test_rich_progress_bar.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e3d9382a4271..2371c558b7cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974)) +- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py new file mode 100644 index 0000000000000..4f094fc89792f --- /dev/null +++ b/tests/callbacks/test_rich_progress_bar.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +@RunIf(rich=True) +def test_rich_progress_bar_callback(): + + trainer = Trainer(callbacks=RichProgressBar()) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] + + assert len(progress_bars) == 1 + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + + +@RunIf(rich=True) +@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") +def test_rich_progress_bar(progress_update, tmpdir): + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + max_steps=1, + callbacks=RichProgressBar(), + ) + + trainer.fit(model) + trainer.test(model) + trainer.predict(model) + + assert progress_update.call_count == 6 + + +def test_rich_progress_bar_misconfiguration(): + + with pytest.raises(MisconfigurationException, match="`RichProgressBar` requires `rich` to be installed."): + Trainer(callbacks=RichProgressBar()) diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 739e6ef0b8fd5..6abc2e377e6dc 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -28,6 +28,7 @@ _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _NATIVE_AMP_AVAILABLE, + _RICH_AVAILABLE, _TORCH_QUANTIZE_AVAILABLE, _TPU_AVAILABLE, ) @@ -71,6 +72,7 @@ def __new__( fairscale: bool = False, fairscale_fully_sharded: bool = False, deepspeed: bool = False, + rich: bool = False, **kwargs, ): """ @@ -92,6 +94,7 @@ def __new__( fairscale: if `fairscale` module is required to run the test fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test deepspeed: if `deepspeed` module is required to run the test + rich: if `rich` module is required to run the test kwargs: native pytest.mark.skipif keyword arguments """ conditions = [] @@ -166,6 +169,10 @@ def __new__( conditions.append(not _DEEPSPEED_AVAILABLE) reasons.append("Deepspeed") + if rich: + conditions.append(not _RICH_AVAILABLE) + reasons.append("Rich") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs From 9d9625ec2d6f93bf75018241dc06908007d5f087 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 17 Aug 2021 22:18:49 +0530 Subject: [PATCH 14/22] Fix test --- pytorch_lightning/callbacks/progress/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 2807f009a4ed8..441c79a5ab1c6 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,5 @@ """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 -from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 From cf4964719576b4476dbbf3e1d79751773f6aff99 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 19 Aug 
2021 16:16:37 +0530 Subject: [PATCH 15/22] Add padding for train description --- .../callbacks/progress/rich_progress.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 5590a81be80e0..1871362e081d1 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -116,6 +116,22 @@ def disable(self) -> None: def enable(self) -> None: self._enabled = True + @property + def sanity_check_description(self) -> str: + return "[Validation Sanity Check]" + + @property + def validation_description(self) -> str: + return "[Validation]" + + @property + def test_description(self) -> str: + return "[Testing]" + + @property + def predict_description(self) -> str: + return "[Predicting]" + def setup(self, trainer, pl_module, stage): self.progress = Progress( SpinnerColumn(), @@ -134,7 +150,7 @@ def setup(self, trainer, pl_module, stage): def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_sanity_progress_bar = self.progress.add_task( - f"[{STYLES['sanity_check']}][Validation Sanity Check]", + f"[{STYLES['sanity_check']}]{self.sanity_check_description}", total=trainer.num_sanity_val_steps, ) @@ -152,8 +168,11 @@ def on_train_epoch_start(self, trainer, pl_module): self._total_val_batches = self._total_val_batches * val_checks_per_epoch total_batches = total_train_batches + self._total_val_batches + + train_description = self._get_train_description(trainer.current_epoch) + self.main_progress_bar = self.progress.add_task( - f"[{STYLES['train']}][Epoch {trainer.current_epoch}]", + f"[{STYLES['train']}]{train_description}", total=total_batches, ) @@ -161,7 +180,7 @@ def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) if self._total_val_batches > 0: self.val_progress_bar = self.progress.add_task( - f"[{STYLES['validate']}][Validation]", + f"[{STYLES['validate']}]{self.validation_description}", total=self._total_val_batches, ) @@ -173,14 +192,14 @@ def on_validation_epoch_end(self, trainer, pl_module): def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar = self.progress.add_task( - f"[{STYLES['test']}][Testing]", + f"[{STYLES['test']}]{self.test_description}", total=self.total_test_batches, ) def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar = self.progress.add_task( - f"[{STYLES['predict']}][Predicting]", + f"[{STYLES['predict']}]{self.predict_description}", total=self.total_predict_batches, ) @@ -212,5 +231,16 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da def _should_update(self, current, total): return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + def _get_train_description(self, current_epoch: int) -> str: + train_description = f"[Epoch {current_epoch}]" + if len(self.validation_description) > len(train_description): + # Padding is required to avoid flickering due to the uneven lengths of the "Epoch X" + # and "Validation" bar descriptions + num_digits = len(str(current_epoch)) + required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits + for _ in range(required_padding): + train_description += " " + return train_description + def
teardown(self, trainer, pl_module, stage): self.progress.__exit__(None, None, None) From 6acefa7abba4284dc55f25ea70af343ae0df9645 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 19 Aug 2021 19:27:51 +0530 Subject: [PATCH 16/22] Update progress metrics --- pytorch_lightning/callbacks/progress/progress.py | 9 --------- pytorch_lightning/callbacks/progress/rich_progress.py | 4 ++-- tests/callbacks/test_progress_bar.py | 2 +- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index e9834cee81d7a..be6d799d59fc2 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -56,7 +56,6 @@ class ProgressBar(ProgressBarBase): r""" This is the default progress bar used by Lightning. It prints to `stdout` using the :mod:`tqdm` package and shows up to four different bars: - - **sanity check progress:** the progress during the sanity check run - **main progress:** shows training + validation progress combined. It also accounts for multiple validation runs during training when @@ -64,25 +63,18 @@ class ProgressBar(ProgressBarBase): - **validation progress:** only visible during validation; shows total progress over all validation datasets. - **test progress:** only active when testing; shows total progress over all test datasets. - For infinite datasets, the progress bar never ends. - If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: - Example:: - class LitProgressBar(ProgressBar): - def init_validation_tqdm(self): bar = super().init_validation_tqdm() bar.set_description('running validation ...') return bar - bar = LitProgressBar() trainer = Trainer(callbacks=[bar]) - Args: refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. @@ -97,7 +89,6 @@ def init_validation_tqdm(self): together. This corresponds to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. - """ def __init__(self, refresh_rate: int = 1, process_position: int = 0): diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1871362e081d1..541af10fbe016 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -55,7 +55,7 @@ def __init__(self, trainer, stage): super().__init__() def render(self, task) -> Text: - if self._trainer.sanity_checking: + if self._stage != "fit" or self._trainer.sanity_checking: return "" if self._trainer.training and task.id not in self._tasks: self._tasks[task.id] = "None" @@ -81,7 +81,7 @@ def render(self, task) -> Text: class RichProgressBar(ProgressBarBase): - def __init__(self, refresh_rate: int = 1): + def __init__(self, refresh_rate: float = 1.0): if not _RICH_AVAILABLE: raise MisconfigurationException( "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." 
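The custom columns in this file all implement rich's `ProgressColumn` protocol, which comes down to a single `render(task)` hook that returns a renderable. A minimal standalone sketch of that protocol (illustrative only — `PercentColumn` and the demo task are made up, not part of the patch):

from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.text import Text


class PercentColumn(ProgressColumn):
    # Hypothetical column: renders a task's completion as a percentage.
    def render(self, task) -> Text:
        pct = 100 * task.completed / task.total if task.total else 0.0
        return Text(f"{pct:5.1f}%")


# Columns are passed positionally when constructing Progress, as in `setup` above.
progress = Progress(TextColumn("[progress.description]{task.description}"), BarColumn(), PercentColumn())
with progress:
    task_id = progress.add_task("[red]demo", total=10)
    for _ in range(10):
        progress.update(task_id, advance=1.0)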
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 1c3176f39a886..378f58d54c0a7 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase -from pytorch_lightning.callbacks.progress import tqdm +from pytorch_lightning.callbacks.progress.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf From cc23c48e01d6b3785b9031e7d0afb0fcf28b140a Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 19 Aug 2021 20:12:07 +0530 Subject: [PATCH 17/22] Remove Model summary rich --- .../callbacks/progress/__init__.py | 2 +- pytorch_lightning/trainer/trainer.py | 5 +- pytorch_lightning/utilities/model_summary.py | 78 ++----------------- 3 files changed, 11 insertions(+), 74 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 441c79a5ab1c6..2807f009a4ed8 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,5 @@ """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 -from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 01e44bfb50439..887cdd46a9db2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator, IPUAccelerator -from pytorch_lightning.callbacks import Callback, RichProgressBar +from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop @@ -1111,9 +1111,8 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: - use_rich = isinstance(self.progress_bar_callback, RichProgressBar) max_depth = ModelSummary.MODES[self.weights_summary] - summarize(ref_model, max_depth=max_depth, use_rich=use_rich) + summarize(ref_model, max_depth=max_depth) # on pretrain routine end self.on_pretrain_routine_end() diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index 1592d74f228e4..896347bd1b30a 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -25,13 +25,9 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.warnings import WarningCache -if _RICH_AVAILABLE: - from rich.console import Console - from 
rich.table import Table - log = logging.getLogger(__name__) warning_cache = WarningCache() @@ -303,7 +299,12 @@ def _forward_example_input(self) -> None: model(input_) model.train(mode) # restore mode of module - def _get_summary_data(self): + def __str__(self): + """ + Makes a summary listing with: + + Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size + """ arrays = [ [" ", list(map(str, range(len(self._layer_summary))))], ["Name", self.layer_names], @@ -314,65 +315,6 @@ def _get_summary_data(self): arrays.append(["In sizes", self.in_sizes]) arrays.append(["Out sizes", self.out_sizes]) - return arrays - - def print_rich_summary(self): - - if not _RICH_AVAILABLE: - raise MisconfigurationException( - "`print_rich_summary` requires `rich` to be installed." " Install it by running `pip install rich`." - ) - - arrays = self._get_summary_data() - total_parameters = self.total_parameters - trainable_parameters = self.trainable_parameters - model_size = self.model_size - - console = Console() - - table = Table(title="Model Summary") - - table.add_column(" ") - table.add_column("Name", justify="left", style="cyan", no_wrap=True) - table.add_column("Type", style="magenta") - table.add_column("Params", justify="right", style="green") - - if self._model.example_input_array is not None: - table.add_column("In sizes", justify="right", style="green") - table.add_column("Out sizes", justify="right", style="green") - - rows = list(zip(*(arr[1] for arr in arrays))) - for row in rows: - table.add_row(*row) - - console.print(table) - - # Formatting - s = "{:<{}}" - - parameters = [] - for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: - parameters.append(s.format(get_human_readable_count(param), 10)) - - grid = Table.grid(expand=True) - grid.add_column() - grid.add_column() - - grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}") - grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") - grid.add_row(f"[bold]Total params[/]: {parameters[2]}") - grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") - - console.print(grid) - - def __str__(self): - """ - Makes a summary listing with: - - Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size - """ - arrays = self._get_summary_data() - total_parameters = self.total_parameters trainable_parameters = self.trainable_parameters model_size = self.model_size @@ -497,7 +439,6 @@ def summarize( lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None, - use_rich: bool = False, ) -> Optional[ModelSummary]: """ Summarize the LightningModule specified by `lightning_module`. 
@@ -529,8 +470,5 @@ def summarize( raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") else: model_summary = ModelSummary(lightning_module, max_depth=max_depth) - if use_rich: - model_summary.print_rich_summary() - else: - log.info("\n" + str(model_summary)) + log.info("\n" + str(model_summary)) return model_summary From f1fa9ee9bda95def26ecf5a7651f516fbeda13d7 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 19 Aug 2021 20:40:59 +0530 Subject: [PATCH 18/22] Update imports --- pytorch_lightning/callbacks/progress/__init__.py | 2 +- tests/callbacks/test_progress_bar.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 2807f009a4ed8..441c79a5ab1c6 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,5 @@ """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 -from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 378f58d54c0a7..1c3176f39a886 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase -from pytorch_lightning.callbacks.progress.progress import tqdm +from pytorch_lightning.callbacks.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf From f106b22124a60f34174ab2d4c7f60cdd087c7c36 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 20 Aug 2021 13:33:31 +0530 Subject: [PATCH 19/22] Address reviews --- .../callbacks/progress/rich_progress.py | 52 +++++++++---------- pytorch_lightning/utilities/imports.py | 2 +- pytorch_lightning/utilities/model_summary.py | 4 +- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 541af10fbe016..51ac774b6c7ed 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from datetime import timedelta -from typing import Dict +from typing import Dict, Optional from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE @@ -87,15 +87,15 @@ def __init__(self, refresh_rate: float = 1.0): "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." 
) super().__init__() - self._refresh_rate = refresh_rate - self._enabled = True - self._total_val_batches = 0 - self.progress = None - self.val_sanity_progress_bar = None - self.main_progress_bar = None - self.val_progress_bar = None - self.test_progress_bar = None - self.predict_progress_bar = None + self._refresh_rate: float = refresh_rate + self._enabled: bool = True + self._total_val_batches: int = 0 + self.progress: Progress = None + self.val_sanity_progress_bar_id: Optional[int] = None + self.main_progress_bar_id: Optional[int] = None + self.val_progress_bar_id: Optional[int] = None + self.test_progress_bar_id: Optional[int] = None + self.predict_progress_bar_id: Optional[int] = None self.console = Console(record=True) @property @@ -149,14 +149,14 @@ def setup(self, trainer, pl_module, stage): def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) - self.val_sanity_progress_bar = self.progress.add_task( + self.val_sanity_progress_bar_id = self.progress.add_task( f"[{STYLES['sanity_check']}]{self.sanity_check_description}", total=trainer.num_sanity_val_steps, ) def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) - self.progress.update(self.val_sanity_progress_bar, visible=False) + self.progress.update(self.val_sanity_progress_bar_id, visible=False) def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) @@ -171,7 +171,7 @@ def on_train_epoch_start(self, trainer, pl_module): train_description = self._get_train_description(trainer.current_epoch) - self.main_progress_bar = self.progress.add_task( + self.main_progress_bar_id = self.progress.add_task( f"[{STYLES['train']}]{train_description}", total=total_batches, ) @@ -179,26 +179,26 @@ def on_train_epoch_start(self, trainer, pl_module): def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) if self._total_val_batches > 0: - self.val_progress_bar = self.progress.add_task( + self.val_progress_bar_id = self.progress.add_task( f"[{STYLES['validate']}]{self.validation_description}", total=self._total_val_batches, ) def on_validation_epoch_end(self, trainer, pl_module): super().on_validation_epoch_end(trainer, pl_module) - if self.val_progress_bar is not None: - self.progress.update(self.val_progress_bar, visible=False) + if self.val_progress_bar_id is not None: + self.progress.update(self.val_progress_bar_id, visible=False) def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) - self.test_progress_bar = self.progress.add_task( + self.test_progress_bar_id = self.progress.add_task( f"[{STYLES['test']}]{self.test_description}", total=self.total_test_batches, ) def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) - self.predict_progress_bar = self.progress.add_task( + self.predict_progress_bar_id = self.progress.add_task( f"[{STYLES['predict']}]{self.predict_description}", total=self.total_predict_batches, ) @@ -206,29 +206,29 @@ def on_predict_epoch_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): - self.progress.update(self.main_progress_bar, advance=1.0) + 
self.progress.update(self.main_progress_bar_id, advance=1.0) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: - self.progress.update(self.val_sanity_progress_bar, advance=1.0) - elif self.val_progress_bar and self._should_update( + self.progress.update(self.val_sanity_progress_bar_id, advance=1.0) + elif self.val_progress_bar_id and self._should_update( self.val_batch_idx, self.total_train_batches + self.total_val_batches ): - self.progress.update(self.main_progress_bar, advance=1.0) - self.progress.update(self.val_progress_bar, advance=1.0) + self.progress.update(self.main_progress_bar_id, advance=1.0) + self.progress.update(self.val_progress_bar_id, advance=1.0) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.test_batch_idx, self.total_test_batches): - self.progress.update(self.test_progress_bar, advance=1.0) + self.progress.update(self.test_progress_bar_id, advance=1.0) def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.predict_batch_idx, self.total_predict_batches): - self.progress.update(self.predict_progress_bar, advance=1.0) + self.progress.update(self.predict_progress_bar_id, advance=1.0) - def _should_update(self, current, total): + def _should_update(self, current, total) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def _get_train_description(self, current_epoch: int) -> str: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 103733417159d..f947d647872ad 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -84,13 +84,13 @@ def _compare_version(package: str, op, version) -> bool: _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") _OMEGACONF_AVAILABLE = _module_available("omegaconf") _POPTORCH_AVAILABLE = _module_available("poptorch") +_RICH_AVAILABLE = _module_available("rich") _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) _TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHVISION_AVAILABLE = _module_available("torchvision") _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0") _TORCHMETRICS_GREATER_EQUAL_0_3 = _compare_version("torchmetrics", operator.ge, "0.3.0") _XLA_AVAILABLE: bool = _module_available("torch_xla") -_RICH_AVAILABLE = _module_available("rich") from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index 896347bd1b30a..d664d4774e870 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -436,9 +436,7 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool: def summarize( - lightning_module: "pl.LightningModule", - mode: Optional[str] = "top", - max_depth: Optional[int] = None, + lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None ) -> Optional[ModelSummary]: """ 
Summarize the LightningModule specified by `lightning_module`. From bc2b659a4bede8747b5d5369181e0a1428b5778e Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 23 Aug 2021 18:44:51 +0530 Subject: [PATCH 20/22] Add docstring --- .../callbacks/progress/rich_progress.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 51ac774b6c7ed..8182933a32079 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -16,7 +16,6 @@ from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException if _RICH_AVAILABLE: from rich.console import Console, RenderableType @@ -81,9 +80,33 @@ def render(self, task) -> Text: class RichProgressBar(ProgressBarBase): + """ + Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_. + + Install it with pip: + + .. code-block:: bash + + pip install rich + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichProgressBar + + trainer = Trainer(callbacks=RichProgressBar()) + + Args: + refresh_rate: the number of updates per second, must be strictly positive + + Raises: + ImportError: + If required `rich` package is not installed on the device. + """ + def __init__(self, refresh_rate: float = 1.0): if not _RICH_AVAILABLE: - raise MisconfigurationException( + raise ImportError( "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." ) super().__init__() From af9c978e49a1c0f7459a3a818598b2ef72d8ce51 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 23 Aug 2021 18:50:30 +0530 Subject: [PATCH 21/22] Update code format --- pytorch_lightning/callbacks/progress/progress.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index be6d799d59fc2..adcea3d581e17 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -56,6 +56,7 @@ class ProgressBar(ProgressBarBase): r""" This is the default progress bar used by Lightning. It prints to `stdout` using the :mod:`tqdm` package and shows up to four different bars: + - **sanity check progress:** the progress during the sanity check run - **main progress:** shows training + validation progress combined. It also accounts for multiple validation runs during training when @@ -63,18 +64,25 @@ class ProgressBar(ProgressBarBase): - **validation progress:** only visible during validation; shows total progress over all validation datasets. - **test progress:** only active when testing; shows total progress over all test datasets. + For infinite datasets, the progress bar never ends.
+ If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the :class:`~pytorch_lightning.trainer.trainer.Trainer`: + Example:: + class LitProgressBar(ProgressBar): + def init_validation_tqdm(self): bar = super().init_validation_tqdm() bar.set_description('running validation ...') return bar + bar = LitProgressBar() trainer = Trainer(callbacks=[bar]) + Args: refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. From 7df6033a19d25bfb778a3d9ef8a6cc2cc516ce5c Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 23 Aug 2021 19:06:40 +0530 Subject: [PATCH 22/22] Update test --- tests/callbacks/test_rich_progress_bar.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 4f094fc89792f..c6f44759ba371 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -17,7 +17,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -57,7 +56,7 @@ def test_rich_progress_bar(progress_update, tmpdir): assert progress_update.call_count == 6 -def test_rich_progress_bar_misconfiguration(): +def test_rich_progress_bar_import_error(): - with pytest.raises(MisconfigurationException, match="`RichProgressBar` requires `rich` to be installed."): + with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."): Trainer(callbacks=RichProgressBar())
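Taken together, the series is exercised end to end as below — a minimal sketch assuming an installation that includes these patches plus `rich`; `DemoModel` and `RandomDataset` are illustrative stand-ins, not part of the patch series:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import RichProgressBar


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class DemoModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Logged with prog_bar=True so it shows up in MetricsTextColumn.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # After PATCH 16, refresh_rate is a float (updates per second).
    trainer = Trainer(max_epochs=2, callbacks=RichProgressBar(refresh_rate=1.0))
    trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=8))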