Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved `block_ddp_sync_behaviour` out of `TrainingBatchLoop` to loop utilities ([#9192](https://github.com/PyTorchLightning/pytorch-lightning/pull/9192))


- Renamed default `ProgressBar` to `TQDMProgressBar` ([#9229](https://github.com/PyTorchLightning/pytorch-lightning/pull/9229))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/progress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@

"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import tqdm, TQDMProgressBar # noqa: F401
328 changes: 3 additions & 325 deletions pytorch_lightning/callbacks/progress/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,329 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import math
import os
import sys
from typing import Optional, Union
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec("ipywidgets") is not None:
from tqdm.auto import tqdm as _tqdm
else:
from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks.progress.base import ProgressBarBase

_PAD_SIZE = 5


class tqdm(_tqdm):
    """Custom ``tqdm`` progress bar that pads formatted numbers.

    Floating-point and string numbers get trailing zeros appended so the
    rendered width stays constant between refreshes, preventing the bar
    from flickering.
    """

    @staticmethod
    def format_num(n) -> str:
        """Format ``n`` and pad it to ``_PAD_SIZE`` characters when applicable."""
        needs_padding = isinstance(n, (float, str))
        if not isinstance(n, str):
            n = _tqdm.format_num(n)
        # scientific notation and non-padded types are returned untouched
        if not needs_padding or "e" in n:
            return n
        if "." not in n and len(n) < _PAD_SIZE:
            try:
                float(n)
            except ValueError:
                # not numeric at all (e.g. a plain label) -> leave as-is
                return n
            n += "."
        return n + "0" * (_PAD_SIZE - len(n))


class ProgressBar(ProgressBarBase):
    r"""
    This is the default progress bar used by Lightning. It prints to `stdout` using the
    :mod:`tqdm` package and shows up to four different bars:

    - **sanity check progress:** the progress during the sanity check run
    - **main progress:** shows training + validation progress combined. It also accounts for
      multiple validation runs during training when
      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
    - **validation progress:** only visible during validation;
      shows total progress over all validation datasets.
    - **test progress:** only active when testing; shows total progress over all test datasets.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`:

    Example::

        class LitProgressBar(ProgressBar):
            def init_validation_tqdm(self):
                bar = super().init_validation_tqdm()
                bar.set_description('running validation ...')
                return bar

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate:
            Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position:
            Set this to a value greater than ``0`` to offset the progress bars by this many lines.
            This is useful when you have progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    """

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        super().__init__()
        self._refresh_rate = refresh_rate
        self._process_position = process_position
        self._enabled = True
        # The tqdm instances are created lazily in the corresponding ``on_*_start`` hooks.
        self.main_progress_bar = None
        self.val_progress_bar = None
        self.test_progress_bar = None
        self.predict_progress_bar = None

    def __getstate__(self):
        # can't pickle the tqdm objects
        state = self.__dict__.copy()
        state["main_progress_bar"] = None
        state["val_progress_bar"] = None
        state["test_progress_bar"] = None
        state["predict_progress_bar"] = None
        return state

    @property
    def refresh_rate(self) -> int:
        """Number of batches between two display refreshes; ``0`` disables the bars."""
        return self._refresh_rate

    @property
    def process_position(self) -> int:
        """Line offset applied to all bars (useful with multiple processes)."""
        return self._process_position

    @property
    def is_enabled(self) -> bool:
        """Whether the bar is displayed: requires both ``enable()`` and a positive refresh rate."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    def init_sanity_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for the validation sanity run."""
        bar = tqdm(
            desc="Validation sanity check",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training."""
        bar = tqdm(
            desc="Training",
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_predict_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for predicting."""
        bar = tqdm(
            desc="Predicting",
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation."""
        # The main progress bar doesn't exist in `trainer.validate()`
        has_main_bar = self.main_progress_bar is not None
        bar = tqdm(
            desc="Validating",
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing."""
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def on_sanity_check_start(self, trainer, pl_module):
        super().on_sanity_check_start(trainer, pl_module)
        self.val_progress_bar = self.init_sanity_tqdm()
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, trainer, pl_module):
        super().on_sanity_check_end(trainer, pl_module)
        self.main_progress_bar.close()
        self.val_progress_bar.close()

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self.main_progress_bar = self.init_train_tqdm()

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf") and total_val_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
        # the main bar counts train + (possibly repeated) val batches combined
        total_batches = total_train_batches + total_val_batches
        reset(self.main_progress_bar, total_batches)
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        total_batches = self.total_train_batches + self.total_val_batches
        total_batches = convert_inf(total_batches)
        if self._should_update(self.train_batch_idx, total_batches):
            self._update_bar(self.main_progress_bar)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

    def on_validation_start(self, trainer, pl_module):
        super().on_validation_start(trainer, pl_module)
        if trainer.sanity_checking:
            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
        else:
            self._update_bar(self.main_progress_bar)  # fill up remaining
            self.val_progress_bar = self.init_validation_tqdm()
            reset(self.val_progress_bar, self.total_val_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
            self._update_bar(self.val_progress_bar)
            # the main bar also advances because it tracks train + val combined
            self._update_bar(self.main_progress_bar)

    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        if self.main_progress_bar is not None:
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
        self.val_progress_bar.close()

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        self.main_progress_bar.close()

    def on_test_start(self, trainer, pl_module):
        super().on_test_start(trainer, pl_module)
        self.test_progress_bar = self.init_test_tqdm()
        self.test_progress_bar.total = convert_inf(self.total_test_batches)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.test_batch_idx, self.total_test_batches):
            self._update_bar(self.test_progress_bar)

    def on_test_end(self, trainer, pl_module):
        super().on_test_end(trainer, pl_module)
        self.test_progress_bar.close()

    def on_predict_epoch_start(self, trainer, pl_module):
        super().on_predict_epoch_start(trainer, pl_module)
        self.predict_progress_bar = self.init_predict_tqdm()
        self.predict_progress_bar.total = convert_inf(self.total_predict_batches)

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
            self._update_bar(self.predict_progress_bar)

    def on_predict_end(self, trainer, pl_module):
        # NOTE(review): unlike the other ``on_*_end`` hooks here, this one does not
        # call ``super().on_predict_end`` — confirm this is intentional.
        self.predict_progress_bar.close()

    def print(
        self, *args, sep: str = " ", end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
    ):
        """Write a message via the currently active bar so the display is not corrupted."""
        active_progress_bar = None

        if self.main_progress_bar is not None and not self.main_progress_bar.disable:
            active_progress_bar = self.main_progress_bar
        elif self.val_progress_bar is not None and not self.val_progress_bar.disable:
            active_progress_bar = self.val_progress_bar
        elif self.test_progress_bar is not None and not self.test_progress_bar.disable:
            active_progress_bar = self.test_progress_bar
        elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable:
            active_progress_bar = self.predict_progress_bar

        if active_progress_bar is not None:
            s = sep.join(map(str, args))
            active_progress_bar.write(s, end=end, file=file, nolock=nolock)

    def _should_update(self, current, total) -> bool:
        # refresh on every ``refresh_rate``-th batch and always on the final batch
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def _update_bar(self, bar: Optional[tqdm]) -> None:
        """Updates the bar by the refresh rate without overshooting."""
        if bar is None:
            return
        if bar.total is not None:
            delta = min(self.refresh_rate, bar.total - bar.n)
        else:
            # infinite / unknown size
            delta = self.refresh_rate
        if delta > 0:
            bar.update(delta)

def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """The tqdm doesn't support inf/nan values. We have to convert it to None."""
    # ``math.isfinite`` rejects both inf and nan in one check.
    if x is not None and math.isfinite(x):
        return x
    return None


def reset(bar: tqdm, total: Optional[int] = None) -> None:
    """Reset ``bar`` to zero progress with a new ``total``, unless it is disabled."""
    if bar.disable:
        return
    bar.reset(total=convert_inf(total))
# Backwards-compatible alias: ``ProgressBar`` now points to the default
# ``TQDMProgressBar`` implementation (the class was renamed — see PR #9229).
ProgressBar = TQDMProgressBar
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather inherit and add a deprecation warning, see what we did with Metrics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda reference link?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for example #6584

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like this

class ProgressBar(TQDMProgressBar):
    pass

Loading