Commit 538e743

kaushikb11 and tchaton authored
feat: Add Rich Progress Bar (#8929)
Co-authored-by: thomas chaton <[email protected]>
1 parent 1e4d892 · commit 538e743

File tree

11 files changed (+552, -178 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))


+- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
+
+
 ### Changed

 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
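
A minimal usage sketch for the feature this entry describes (hedged: the constructor options added in #8929 are not visible in this excerpt, so the callback is shown with defaults only, and the `rich` package is assumed to be installed)::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import RichProgressBar

    # Passing the callback explicitly replaces the default tqdm-based bar.
    trainer = Trainer(callbacks=[RichProgressBar()])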

pytorch_lightning/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,7 @@
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
-from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
 from pytorch_lightning.callbacks.pruning import ModelPruning
 from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
 from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
@@ -45,4 +45,5 @@
     "QuantizationAwareTraining",
     "StochasticWeightAveraging",
     "Timer",
+    "RichProgressBar",
 ]
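
Because the re-export above puts `RichProgressBar` on the package root, both public import paths resolve to the same class; a quick sanity check, assuming only what this diff shows::

    from pytorch_lightning.callbacks import RichProgressBar
    from pytorch_lightning.callbacks.progress import RichProgressBar as _RichProgressBar

    assert RichProgressBar is _RichProgressBar  # one class, two import paths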

pytorch_lightning/callbacks/progress/__init__.py (file name inferred from its imports; the rendered page dropped it)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Progress Bars
+=============
+
+Use or override one of the progress bar callbacks.
+
+"""
+from pytorch_lightning.callbacks.progress.base import ProgressBarBase  # noqa: F401
+from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm  # noqa: F401
+from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar  # noqa: F401
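
The package docstring says to "use or override one of the progress bar callbacks". A hedged sketch of the override route for the tqdm-based bar (``init_validation_tqdm`` is the hook that bar exposes in this version; it is not part of this diff, so treat the exact name as an assumption)::

    from pytorch_lightning.callbacks.progress import ProgressBar, tqdm

    class QuietValidationBar(ProgressBar):
        # Hide the validation bar while leaving training output unchanged.
        def init_validation_tqdm(self) -> tqdm:
            return tqdm(disable=True)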

pytorch_lightning/callbacks/progress/base.py (file name inferred from the package ``__init__`` above)

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.callbacks import Callback
+
+
+class ProgressBarBase(Callback):
+    r"""
+    The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
+    that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+    Implement your own custom progress bars with this as the base class.
+
+    Example::
+
+        import sys
+
+        class LitProgressBar(ProgressBarBase):
+
+            def __init__(self):
+                super().__init__()  # don't forget this :)
+                self._enabled = True
+
+            def enable(self):
+                self._enabled = True
+
+            def disable(self):
+                self._enabled = False
+
+            def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+                super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)  # don't forget this :)
+                percent = (self.train_batch_idx / self.total_train_batches) * 100
+                sys.stdout.flush()
+                sys.stdout.write(f'{percent:.01f} percent complete \r')
+
+        bar = LitProgressBar()
+        trainer = Trainer(callbacks=[bar])
+
+    """
+
+    def __init__(self):
+        self._trainer = None
+        self._train_batch_idx = 0
+        self._val_batch_idx = 0
+        self._test_batch_idx = 0
+        self._predict_batch_idx = 0
+
+    @property
+    def trainer(self):
+        return self._trainer
+
+    @property
+    def train_batch_idx(self) -> int:
+        """
+        The current batch index being processed during training.
+        Use this to update your progress bar.
+        """
+        return self._train_batch_idx
+
+    @property
+    def val_batch_idx(self) -> int:
+        """
+        The current batch index being processed during validation.
+        Use this to update your progress bar.
+        """
+        return self._val_batch_idx
+
+    @property
+    def test_batch_idx(self) -> int:
+        """
+        The current batch index being processed during testing.
+        Use this to update your progress bar.
+        """
+        return self._test_batch_idx
+
+    @property
+    def predict_batch_idx(self) -> int:
+        """
+        The current batch index being processed during prediction.
+        Use this to update your progress bar.
+        """
+        return self._predict_batch_idx
+
+    @property
+    def total_train_batches(self) -> int:
+        """
+        The total number of training batches, which may change from epoch to epoch.
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+        training dataloader is of infinite size.
+        """
+        return self.trainer.num_training_batches
+
+    @property
+    def total_val_batches(self) -> int:
+        """
+        The total number of validation batches, which may change from epoch to epoch.
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+        validation dataloader is of infinite size.
+        """
+        total_val_batches = 0
+        if self.trainer.enable_validation:
+            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
+            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
+
+        return total_val_batches
+
+    @property
+    def total_test_batches(self) -> int:
+        """
+        The total number of testing batches, which may change from epoch to epoch.
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+        test dataloader is of infinite size.
+        """
+        return sum(self.trainer.num_test_batches)
+
+    @property
+    def total_predict_batches(self) -> int:
+        """
+        The total number of prediction batches, which may change from epoch to epoch.
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+        predict dataloader is of infinite size.
+        """
+        return sum(self.trainer.num_predict_batches)
+
+    def disable(self):
+        """
+        You should provide a way to disable the progress bar.
+        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
+        output on processes that have a rank different from 0, e.g., in multi-node training.
+        """
+        raise NotImplementedError
+
+    def enable(self):
+        """
+        You should provide a way to enable the progress bar.
+        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
+        routines like the :ref:`learning rate finder <advanced/lr_finder:Learning Rate Finder>`
+        to temporarily enable and disable the main progress bar.
+        """
+        raise NotImplementedError
+
+    def print(self, *args, **kwargs):
+        """
+        You should provide a way to print without breaking the progress bar.
+        """
+        print(*args, **kwargs)
+
+    def on_init_end(self, trainer):
+        self._trainer = trainer
+
+    def on_train_start(self, trainer, pl_module):
+        self._train_batch_idx = trainer.fit_loop.batch_idx
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        self._train_batch_idx = 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        self._train_batch_idx += 1
+
+    def on_validation_start(self, trainer, pl_module):
+        self._val_batch_idx = 0
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        self._val_batch_idx += 1
+
+    def on_test_start(self, trainer, pl_module):
+        self._test_batch_idx = 0
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        self._test_batch_idx += 1
+
+    def on_predict_epoch_start(self, trainer, pl_module):
+        self._predict_batch_idx = 0
+
+    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        self._predict_batch_idx += 1
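
The hook wiring above extends directly to validation; a self-contained sketch that combines the batch counters with the ``total_*`` properties (a hedged example, not the library's own bar; note that ``total_val_batches`` already accounts for ``check_val_every_n_epoch``, per the property above)::

    import sys

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBarBase

    class PercentBar(ProgressBarBase):

        def __init__(self):
            super().__init__()
            self._enabled = True

        def enable(self):
            self._enabled = True

        def disable(self):
            self._enabled = False

        def _write(self, phase, idx, total):
            # `total` is 0 on epochs where validation is skipped.
            if self._enabled and total:
                sys.stdout.write(f"{phase}: {idx / total * 100:.01f}% \r")
                sys.stdout.flush()

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
            self._write("train", self.train_batch_idx, self.total_train_batches)

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
            self._write("val", self.val_batch_idx, self.total_val_batches)

    trainer = Trainer(callbacks=[PercentBar()])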
