@@ -125,7 +125,7 @@ class RichProgressBar(ProgressBarBase):
125125 trainer = Trainer(callbacks=RichProgressBar())
126126
127127 Args:
128- refresh_rate : the number of updates per second, must be strictly positive
128+ refresh_rate_per_second : the number of updates per second. If refresh_rate is 0, progress bar is disabled.
129129 theme: Contains styles used to stylize the progress bar.
130130
131131 Raises:
@@ -135,15 +135,15 @@ class RichProgressBar(ProgressBarBase):
135135
136136 def __init__ (
137137 self ,
138- refresh_rate : float = 1.0 ,
138+ refresh_rate_per_second : int = 10 ,
139139 theme : RichProgressBarTheme = RichProgressBarTheme (),
140140 ) -> None :
141141 if not _RICH_AVAILABLE :
142142 raise ImportError (
143143 "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
144144 )
145145 super ().__init__ ()
146- self ._refresh_rate : float = refresh_rate
146+ self ._refresh_rate_per_second : int = refresh_rate_per_second
147147 self ._enabled : bool = True
148148 self ._total_val_batches : int = 0
149149 self .progress : Progress = None
@@ -156,12 +156,17 @@ def __init__(
156156 self .theme = theme
157157
158158 @property
159- def refresh_rate (self ) -> int :
160- return self ._refresh_rate
159+ def refresh_rate_per_second (self ) -> float :
160+ """Refresh rate for Rich Progress.
161+
162+ Returns: Refresh rate for Progress Bar.
163+ Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
164+ """
165+ return self ._refresh_rate_per_second if self ._refresh_rate_per_second > 0 else 1
161166
162167 @property
163168 def is_enabled (self ) -> bool :
164- return self ._enabled and self .refresh_rate > 0
169+ return self ._enabled and self ._refresh_rate_per_second > 0
165170
166171 @property
167172 def is_disabled (self ) -> bool :
@@ -189,7 +194,7 @@ def test_description(self) -> str:
189194 def predict_description (self ) -> str :
190195 return "Predicting"
191196
192- def setup (self , trainer , pl_module , stage ):
197+ def setup (self , trainer , pl_module , stage : Optional [ str ] = None ):
193198 self .progress = Progress (
194199 TextColumn ("[progress.description]{task.description}" ),
195200 BarColumn (complete_style = self .theme .progress_bar_complete , finished_style = self .theme .progress_bar_finished ),
@@ -198,8 +203,10 @@ def setup(self, trainer, pl_module, stage):
198203 ProcessingSpeedColumn (style = self .theme .processing_speed ),
199204 MetricsTextColumn (trainer , pl_module , stage ),
200205 console = self .console ,
201- refresh_per_second = self .refresh_rate ,
202- ).__enter__ ()
206+ refresh_per_second = self .refresh_rate_per_second ,
207+ disable = self .is_disabled ,
208+ )
209+ self .progress .start ()
203210
204211 def on_sanity_check_start (self , trainer , pl_module ):
205212 super ().on_sanity_check_start (trainer , pl_module )
@@ -259,31 +266,23 @@ def on_predict_epoch_start(self, trainer, pl_module):
259266
260267 def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
261268 super ().on_train_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
262- if self ._should_update (self .train_batch_idx , self .total_train_batches + self .total_val_batches ):
263- self .progress .update (self .main_progress_bar_id , advance = 1.0 )
269+ self .progress .update (self .main_progress_bar_id , advance = 1.0 )
264270
265271 def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
266272 super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
267273 if trainer .sanity_checking :
268274 self .progress .update (self .val_sanity_progress_bar_id , advance = 1.0 )
269- elif self .val_progress_bar_id and self ._should_update (
270- self .val_batch_idx , self .total_train_batches + self .total_val_batches
271- ):
275+ elif self .val_progress_bar_id :
272276 self .progress .update (self .main_progress_bar_id , advance = 1.0 )
273277 self .progress .update (self .val_progress_bar_id , advance = 1.0 )
274278
275279 def on_test_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
276280 super ().on_test_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
277- if self ._should_update (self .test_batch_idx , self .total_test_batches ):
278- self .progress .update (self .test_progress_bar_id , advance = 1.0 )
281+ self .progress .update (self .test_progress_bar_id , advance = 1.0 )
279282
280283 def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
281284 super ().on_predict_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
282- if self ._should_update (self .predict_batch_idx , self .total_predict_batches ):
283- self .progress .update (self .predict_progress_bar_id , advance = 1.0 )
284-
285- def _should_update (self , current , total ) -> bool :
286- return self .is_enabled and (current % self .refresh_rate == 0 or current == total )
285+ self .progress .update (self .predict_progress_bar_id , advance = 1.0 )
287286
288287 def _get_train_description (self , current_epoch : int ) -> str :
289288 train_description = f"Epoch { current_epoch } "
@@ -296,8 +295,8 @@ def _get_train_description(self, current_epoch: int) -> str:
296295 train_description += " "
297296 return train_description
298297
299- def teardown (self , trainer , pl_module , stage ) :
300- self .progress .__exit__ ( None , None , None )
298+ def teardown (self , trainer , pl_module , stage : Optional [ str ] = None ) -> None :
299+ self .progress .stop ( )
301300
302301 def on_exception (self , trainer , pl_module , exception : BaseException ) -> None :
303302 if isinstance (exception , KeyboardInterrupt ):
0 commit comments