Skip to content

Commit 50198d7

Browse files
fix progress bar restart with fault-tolerant training enabled (#9310)
* reset progress updates
* update docs
* add test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f9132e8 commit 50198d7

File tree

4 files changed

+41
-7
lines changed

4 files changed

+41
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
297297
- Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))
298298

299299

300+
- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))
301+
302+
300303
## [1.4.5] - 2021-08-31
301304

302305
- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/callbacks/progress/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ def on_init_end(self, trainer):
154154
self._trainer = trainer
155155

156156
def on_train_start(self, trainer, pl_module):
157-
self._train_batch_idx = trainer.fit_loop.batch_idx
157+
self._train_batch_idx = 0
158158

159159
def on_train_epoch_start(self, trainer, pl_module):
160-
self._train_batch_idx = 0
160+
self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
161161

162162
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
163163
self._train_batch_idx += 1

pytorch_lightning/callbacks/progress/progress.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def on_train_epoch_start(self, trainer, pl_module):
229229
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
230230
total_val_batches = total_val_batches * val_checks_per_epoch
231231
total_batches = total_train_batches + total_val_batches
232-
reset(self.main_progress_bar, total_batches)
232+
reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
233233
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
234234

235235
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -243,11 +243,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
243243
def on_validation_start(self, trainer, pl_module):
244244
super().on_validation_start(trainer, pl_module)
245245
if trainer.sanity_checking:
246-
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
246+
reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
247247
else:
248248
self._update_bar(self.main_progress_bar) # fill up remaining
249249
self.val_progress_bar = self.init_validation_tqdm()
250-
reset(self.val_progress_bar, self.total_val_batches)
250+
reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
251251

252252
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
253253
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -333,7 +333,8 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
333333
return x
334334

335335

336-
def reset(bar: tqdm, total: Optional[int] = None) -> None:
337-
"""Resets the tqdm bar to 0 progress with a new total, unless it is disabled."""
336+
def reset(bar: tqdm, total: Optional[int] = None, current: int = 0) -> None:
337+
"""Resets the tqdm bar to the desired position and sets a new total, unless it is disabled."""
338338
if not bar.disable:
339339
bar.reset(total=convert_inf(total))
340+
bar.n = current

tests/callbacks/test_progress_bar.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,33 @@ def _test_progress_bar_max_val_check_interval(
558558
total_val_batches = total_val_batches * val_checks_per_epoch
559559
if trainer.is_global_zero:
560560
assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
561+
562+
563+
def test_progress_bar_main_bar_resume():
564+
"""Test that the progress bar can resume its counters based on the Trainer state."""
565+
bar = ProgressBar()
566+
trainer = Mock()
567+
model = Mock()
568+
569+
trainer.sanity_checking = False
570+
trainer.check_val_every_n_epoch = 1
571+
trainer.current_epoch = 1
572+
trainer.num_training_batches = 5
573+
trainer.val_check_batch = 5
574+
trainer.num_val_batches = [3]
575+
trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
576+
577+
bar.on_init_end(trainer)
578+
bar.on_train_start(trainer, model)
579+
bar.on_train_epoch_start(trainer, model)
580+
581+
assert bar.main_progress_bar.n == 3
582+
assert bar.main_progress_bar.total == 8
583+
584+
# bar.on_train_epoch_end(trainer, model)
585+
bar.on_validation_start(trainer, model)
586+
bar.on_validation_epoch_start(trainer, model)
587+
588+
# restarting mid validation epoch is not currently supported
589+
assert bar.val_progress_bar.n == 0
590+
assert bar.val_progress_bar.total == 3

0 commit comments

Comments (0)