Merged
26 commits
fefa23d
Add Rich Progress Bar
kaushikb11 Aug 15, 2021
6be12ed
Add CustomTimeColumn
kaushikb11 Aug 15, 2021
fa42556
Add BatchesProcessedColumn & ProcessingSpeedColumn
kaushikb11 Aug 15, 2021
86e4c51
Add support for Testing Bar
kaushikb11 Aug 16, 2021
c5a5824
Add updates for validation bar
kaushikb11 Aug 16, 2021
93a6e8b
Class refactor
kaushikb11 Aug 16, 2021
4794814
Update
kaushikb11 Aug 16, 2021
1abb6ed
Add support for display per epoch
kaushikb11 Aug 17, 2021
ca33976
Update Sanity & predict bar
kaushikb11 Aug 17, 2021
251cc46
Add rich for Model Summary
kaushikb11 Aug 17, 2021
539ba08
Update model summary
kaushikb11 Aug 17, 2021
cbbf3bf
Update Styles
kaushikb11 Aug 17, 2021
a94aff3
Add tests
kaushikb11 Aug 17, 2021
9d9625e
Fix test
kaushikb11 Aug 17, 2021
cf49647
Add padding for train description
kaushikb11 Aug 19, 2021
6acefa7
Update progress metrics
kaushikb11 Aug 19, 2021
cc23c48
Remove Model summary rich
kaushikb11 Aug 19, 2021
f1fa9ee
Update imports
kaushikb11 Aug 19, 2021
f106b22
Address reviews
kaushikb11 Aug 20, 2021
bc2b659
Add docstring
kaushikb11 Aug 23, 2021
6998b49
Merge branch 'master' into add/rich_logging
kaushikb11 Aug 23, 2021
af9c978
Update code format
kaushikb11 Aug 23, 2021
88c6b65
Merge branch 'add/rich_logging' of https://github.com/kaushikb11/pyto…
kaushikb11 Aug 23, 2021
7df6033
Update test
kaushikb11 Aug 23, 2021
266bd66
Merge branch 'master' into add/rich_logging
tchaton Aug 23, 2021
721a7a6
Merge branch 'master' into add/rich_logging
kaushikb11 Aug 23, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -69,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))


- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
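For context, a minimal usage sketch of the new callback (not part of the diff): the import path follows the `__init__.py` changes below, and calling the constructor with no arguments is an assumption about its defaults.

# Illustrative sketch, not code from this PR. Assumes the `rich` package is
# installed and that RichProgressBar can be constructed with default arguments.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

trainer = Trainer(callbacks=[RichProgressBar()])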
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
@@ -45,4 +45,5 @@
"QuantizationAwareTraining",
"StochasticWeightAveraging",
"Timer",
"RichProgressBar",
]
23 changes: 23 additions & 0 deletions pytorch_lightning/callbacks/progress/__init__.py
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Progress Bars
=============

Use or override one of the progress bar callbacks.

"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
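To illustrate the "use or override" note in the module docstring above (not part of the diff): the existing tqdm-based ProgressBar can simply be configured and passed to the Trainer. The `refresh_rate` argument is part of the existing ProgressBar API; the chosen value here is only an example.

# Illustrative sketch, not code from this PR: configure the tqdm-based bar
# to redraw every 20 batches instead of every batch.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBar

trainer = Trainer(callbacks=[ProgressBar(refresh_rate=20)])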
181 changes: 181 additions & 0 deletions pytorch_lightning/callbacks/progress/base.py
@@ -0,0 +1,181 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks import Callback


class ProgressBarBase(Callback):
r"""
The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
Implement your custom progress bars with this as the base class.

Example::

import sys

class LitProgressBar(ProgressBarBase):

def __init__(self):
super().__init__() # don't forget this :)
self.enable = True

def disable(self):
self.enable = False

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)  # don't forget this :)
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])

"""

def __init__(self):

self._trainer = None
self._train_batch_idx = 0
self._val_batch_idx = 0
self._test_batch_idx = 0
self._predict_batch_idx = 0

@property
def trainer(self):
return self._trainer

@property
def train_batch_idx(self) -> int:
"""
The current batch index being processed during training.
Use this to update your progress bar.
"""
return self._train_batch_idx

@property
def val_batch_idx(self) -> int:
"""
The current batch index being processed during validation.
Use this to update your progress bar.
"""
return self._val_batch_idx

@property
def test_batch_idx(self) -> int:
"""
The current batch index being processed during testing.
Use this to update your progress bar.
"""
return self._test_batch_idx

@property
def predict_batch_idx(self) -> int:
"""
The current batch index being processed during prediction.
Use this to update your progress bar.
"""
return self._predict_batch_idx

@property
def total_train_batches(self) -> int:
"""
The total number of training batches during training, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
training dataloader is of infinite size.
"""
return self.trainer.num_training_batches

@property
def total_val_batches(self) -> int:
"""
The total number of validation batches during validation, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
total_val_batches = 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0

return total_val_batches

@property
def total_test_batches(self) -> int:
"""
The total number of testing batches during testing, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
return sum(self.trainer.num_test_batches)

@property
def total_predict_batches(self) -> int:
"""
The total number of prediction batches during prediction, which may change from epoch to epoch.
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
predict dataloader is of infinite size.
"""
return sum(self.trainer.num_predict_batches)

def disable(self):
"""
You should provide a way to disable the progress bar.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
output on processes that have a rank different from 0, e.g., in multi-node training.
"""
raise NotImplementedError

def enable(self):
"""
You should provide a way to enable the progress bar.
The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder <advanced/lr_finder:Learning Rate Finder>`
to temporarily enable and disable the main progress bar.
"""
raise NotImplementedError

def print(self, *args, **kwargs):
"""
You should provide a way to print without breaking the progress bar.
"""
print(*args, **kwargs)

def on_init_end(self, trainer):
self._trainer = trainer

def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.fit_loop.batch_idx

def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._train_batch_idx += 1

def on_validation_start(self, trainer, pl_module):
self._val_batch_idx = 0

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._val_batch_idx += 1

def on_test_start(self, trainer, pl_module):
self._test_batch_idx = 0

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._test_batch_idx += 1

def on_predict_epoch_start(self, trainer, pl_module):
self._predict_batch_idx = 0

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._predict_batch_idx += 1
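To complement the docstring example above, here is a minimal custom bar that implements both abstract hooks, ``disable()`` and ``enable()``, and reads the counters this base class maintains. This is an illustrative sketch, not code from the PR; the class and attribute names are made up for the example.

# Illustrative sketch, not part of the diff: a bare-bones custom bar built on
# the counters ProgressBarBase tracks in the hooks defined above.
import sys

from pytorch_lightning.callbacks.progress import ProgressBarBase


class MinimalBar(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self._enabled = True

    def disable(self):
        # Called by the Trainer to silence output, e.g. on non-zero ranks.
        self._enabled = False

    def enable(self):
        self._enabled = True

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._enabled:
            sys.stdout.write(f"train batch {self.train_batch_idx}/{self.total_train_batches}\r")
            sys.stdout.flush()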