1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from dataclasses import dataclass
1415from datetime import timedelta
15- from typing import Dict , Optional
16+ from typing import Optional , Union
1617
1718from pytorch_lightning .callbacks .progress .base import ProgressBarBase
1819from pytorch_lightning .utilities import _RICH_AVAILABLE
1920
21+ Style = None
2022if _RICH_AVAILABLE :
2123 from rich .console import Console , RenderableType
22- from rich .progress import BarColumn , Progress , ProgressColumn , SpinnerColumn , TextColumn
24+ from rich .progress import BarColumn , Progress , ProgressColumn , TextColumn
25+ from rich .style import Style
2326 from rich .text import Text
2427
2528 class CustomTimeColumn (ProgressColumn ):
2629
2730 # Only refresh twice a second to prevent jitter
2831 max_refresh = 0.5
2932
33+ def __init__ (self , style : Union [str , Style ]) -> None :
34+ self .style = style
35+ super ().__init__ ()
36+
3037 def render (self , task ) -> Text :
3138 elapsed = task .finished_time if task .finished else task .elapsed
3239 remaining = task .time_remaining
3340 elapsed_delta = "-:--:--" if elapsed is None else str (timedelta (seconds = int (elapsed )))
3441 remaining_delta = "-:--:--" if remaining is None else str (timedelta (seconds = int (remaining )))
35- return Text . from_markup (f"[progress.elapsed] { elapsed_delta } < [progress.remaining] { remaining_delta } " )
42+ return Text (f"{ elapsed_delta } • { remaining_delta } " , style = self . style )
3643
3744 class BatchesProcessedColumn (ProgressColumn ):
45+ def __init__ (self , style : Union [str , Style ]):
46+ self .style = style
47+ super ().__init__ ()
48+
3849 def render (self , task ) -> RenderableType :
39- return Text . from_markup (f"[magenta] { int (task .completed )} /{ task .total } " )
50+ return Text (f"{ int (task .completed )} /{ task .total } " , style = self . style )
4051
4152 class ProcessingSpeedColumn (ProgressColumn ):
53+ def __init__ (self , style : Union [str , Style ]):
54+ self .style = style
55+ super ().__init__ ()
56+
4257 def render (self , task ) -> RenderableType :
4358 task_speed = f"{ task .speed :>.2f} " if task .speed is not None else "0.00"
44- return Text . from_markup (f"[progress.data.speed] { task_speed } it/s" )
59+ return Text (f"{ task_speed } it/s" , style = self . style )
4560
4661 class MetricsTextColumn (ProgressColumn ):
4762 """A column containing text."""
@@ -71,19 +86,26 @@ def render(self, task) -> Text:
7186 metrics = self ._trainer .progress_bar_callback .get_metrics (self ._trainer , self ._pl_module )
7287 else :
7388 metrics = self ._trainer .progress_bar_metrics
89+
7490 for k , v in metrics .items ():
7591 _text += f"{ k } : { round (v , 3 ) if isinstance (v , float ) else v } "
7692 text = Text .from_markup (_text , style = None , justify = "left" )
7793 return text
7894
7995
80- STYLES : Dict [str , str ] = {
81- "train" : "red" ,
82- "sanity_check" : "yellow" ,
83- "validate" : "yellow" ,
84- "test" : "yellow" ,
85- "predict" : "yellow" ,
86- }
96+ @dataclass
97+ class RichProgressBarTheme :
98+ """Styles to associate to different base components.
99+
100+ https://rich.readthedocs.io/en/stable/style.html
101+ """
102+
103+ text_color : str = "white"
104+ progress_bar_complete : Union [str , Style ] = "#6206E0"
105+ progress_bar_finished : Union [str , Style ] = "#6206E0"
106+ batch_process : str = "white"
107+ time : str = "grey54"
108+ processing_speed : str = "grey70"
87109
88110
89111class RichProgressBar (ProgressBarBase ):
@@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase):
104126
105127 Args:
106128 refresh_rate: the number of updates per second, must be strictly positive
129+ theme: Contains styles used to stylize the progress bar.
107130
108131 Raises:
109132 ImportError:
110133 If required `rich` package is not installed on the device.
111134 """
112135
113- def __init__ (self , refresh_rate : float = 1.0 ):
136+ def __init__ (
137+ self ,
138+ refresh_rate : float = 1.0 ,
139+ theme : RichProgressBarTheme = RichProgressBarTheme (),
140+ ) -> None :
114141 if not _RICH_AVAILABLE :
115142 raise ImportError (
116143 "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
@@ -126,6 +153,7 @@ def __init__(self, refresh_rate: float = 1.0):
126153 self .test_progress_bar_id : Optional [int ] = None
127154 self .predict_progress_bar_id : Optional [int ] = None
128155 self .console = Console (record = True )
156+ self .theme = theme
129157
130158 @property
131159 def refresh_rate (self ) -> int :
@@ -147,39 +175,36 @@ def enable(self) -> None:
147175
148176 @property
149177 def sanity_check_description (self ) -> str :
150- return "[ Validation Sanity Check] "
178+ return "Validation Sanity Check"
151179
152180 @property
153181 def validation_description (self ) -> str :
154- return "[ Validation] "
182+ return "Validation"
155183
156184 @property
157185 def test_description (self ) -> str :
158- return "[ Testing] "
186+ return "Testing"
159187
160188 @property
161189 def predict_description (self ) -> str :
162- return "[ Predicting] "
190+ return "Predicting"
163191
164192 def setup (self , trainer , pl_module , stage ):
165193 self .progress = Progress (
166- SpinnerColumn (),
167194 TextColumn ("[progress.description]{task.description}" ),
168- BarColumn (),
169- BatchesProcessedColumn (),
170- "[" ,
171- CustomTimeColumn (),
172- ProcessingSpeedColumn (),
195+ BarColumn (complete_style = self .theme .progress_bar_complete , finished_style = self .theme .progress_bar_finished ),
196+ BatchesProcessedColumn (style = self .theme .batch_process ),
197+ CustomTimeColumn (style = self .theme .time ),
198+ ProcessingSpeedColumn (style = self .theme .processing_speed ),
173199 MetricsTextColumn (trainer , pl_module , stage ),
174- "]" ,
175200 console = self .console ,
176201 refresh_per_second = self .refresh_rate ,
177202 ).__enter__ ()
178203
179204 def on_sanity_check_start (self , trainer , pl_module ):
180205 super ().on_sanity_check_start (trainer , pl_module )
181206 self .val_sanity_progress_bar_id = self .progress .add_task (
182- f"[{ STYLES [ 'sanity_check' ] } ]{ self .sanity_check_description } " ,
207+ f"[{ self . theme . text_color } ]{ self .sanity_check_description } " ,
183208 total = trainer .num_sanity_val_steps ,
184209 )
185210
@@ -201,15 +226,15 @@ def on_train_epoch_start(self, trainer, pl_module):
201226 train_description = self ._get_train_description (trainer .current_epoch )
202227
203228 self .main_progress_bar_id = self .progress .add_task (
204- f"[{ STYLES [ 'train' ] } ]{ train_description } " ,
229+ f"[{ self . theme . text_color } ]{ train_description } " ,
205230 total = total_batches ,
206231 )
207232
208233 def on_validation_epoch_start (self , trainer , pl_module ):
209234 super ().on_validation_epoch_start (trainer , pl_module )
210235 if self ._total_val_batches > 0 :
211236 self .val_progress_bar_id = self .progress .add_task (
212- f"[{ STYLES [ 'validate' ] } ]{ self .validation_description } " ,
237+ f"[{ self . theme . text_color } ]{ self .validation_description } " ,
213238 total = self ._total_val_batches ,
214239 )
215240
@@ -221,14 +246,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
221246 def on_test_epoch_start (self , trainer , pl_module ):
222247 super ().on_train_epoch_start (trainer , pl_module )
223248 self .test_progress_bar_id = self .progress .add_task (
224- f"[{ STYLES [ 'test' ] } ]{ self .test_description } " ,
249+ f"[{ self . theme . text_color } ]{ self .test_description } " ,
225250 total = self .total_test_batches ,
226251 )
227252
228253 def on_predict_epoch_start (self , trainer , pl_module ):
229254 super ().on_predict_epoch_start (trainer , pl_module )
230255 self .predict_progress_bar_id = self .progress .add_task (
231- f"[{ STYLES [ 'predict' ] } ]{ self .predict_description } " ,
256+ f"[{ self . theme . text_color } ]{ self .predict_description } " ,
232257 total = self .total_predict_batches ,
233258 )
234259
@@ -261,7 +286,7 @@ def _should_update(self, current, total) -> bool:
261286 return self .is_enabled and (current % self .refresh_rate == 0 or current == total )
262287
263288 def _get_train_description (self , current_epoch : int ) -> str :
264- train_description = f"[ Epoch { current_epoch } ] "
289+ train_description = f"Epoch { current_epoch } "
265290 if len (self .validation_description ) > len (train_description ):
266291 # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
267292 # and "Validation" Bar description
@@ -273,3 +298,7 @@ def _get_train_description(self, current_epoch: int) -> str:
273298
274299 def teardown (self , trainer , pl_module , stage ):
275300 self .progress .__exit__ (None , None , None )
301+
302+ def on_exception (self , trainer , pl_module , exception : BaseException ) -> None :
303+ if isinstance (exception , KeyboardInterrupt ):
304+ self .progress .stop ()
0 commit comments