Skip to content

Commit d4676ee

Browse files
awaelchli
authored and lexierule committed
Fix progress bar print error when called before training (#7674)
* Check progress bar existence before printing * Add tests for predict_progress_bar * Add tests for progress_bar printing without training * Update changelog
1 parent f11c754 commit d4676ee

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616
- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608))
1717
- Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
1818
- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
19+
- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
1920

2021
## [1.3.2] - 2021-05-18
2122

@@ -30,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3031
- Fixed setting correct `DistribType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492))
3132
- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))
3233

33-
3434
## [1.3.1] - 2021-05-11
3535

3636
### Fixed

pytorch_lightning/callbacks/progress.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,14 @@ def print(
473473
):
474474
active_progress_bar = None
475475

476-
if not self.main_progress_bar.disable:
476+
if self.main_progress_bar is not None and not self.main_progress_bar.disable:
477477
active_progress_bar = self.main_progress_bar
478-
elif not self.val_progress_bar.disable:
478+
elif self.val_progress_bar is not None and not self.val_progress_bar.disable:
479479
active_progress_bar = self.val_progress_bar
480-
elif not self.test_progress_bar.disable:
480+
elif self.test_progress_bar is not None and not self.test_progress_bar.disable:
481481
active_progress_bar = self.test_progress_bar
482+
elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable:
483+
active_progress_bar = self.predict_progress_bar
482484

483485
if active_progress_bar is not None:
484486
s = sep.join(map(str, args))

tests/callbacks/test_progress_bar.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ def test_step(self, *args, **kwargs):
433433
self.print("test_step")
434434
return super().test_step(*args, **kwargs)
435435

436+
def predict_step(self, *args, **kwargs):
437+
self.print("predict_step")
438+
return super().predict_step(*args, **kwargs)
439+
436440

437441
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
438442
def test_progress_bar_print(tqdm_write, tmpdir):
@@ -445,16 +449,45 @@ def test_progress_bar_print(tqdm_write, tmpdir):
445449
limit_train_batches=1,
446450
limit_val_batches=1,
447451
limit_test_batches=1,
452+
limit_predict_batches=1,
448453
max_steps=1,
449454
callbacks=[bar],
450455
)
451456
trainer.fit(model)
452457
trainer.test(model)
453-
assert tqdm_write.call_count == 3
458+
trainer.predict(model)
459+
assert tqdm_write.call_count == 4
454460
assert tqdm_write.call_args_list == [
455461
call("training_step", end="", file=None, nolock=False),
456462
call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
457463
call("test_step", end=os.linesep, file=None, nolock=False),
464+
call("predict_step", end=os.linesep, file=None, nolock=False),
465+
]
466+
467+
468+
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
469+
def test_progress_bar_print_no_train(tqdm_write, tmpdir):
470+
""" Test that printing in the LightningModule redirects arguments to the progress bar without training. """
471+
model = PrintModel()
472+
bar = ProgressBar()
473+
trainer = Trainer(
474+
default_root_dir=tmpdir,
475+
num_sanity_val_steps=0,
476+
limit_val_batches=1,
477+
limit_test_batches=1,
478+
limit_predict_batches=1,
479+
max_steps=1,
480+
callbacks=[bar],
481+
)
482+
483+
trainer.validate(model)
484+
trainer.test(model)
485+
trainer.predict(model)
486+
assert tqdm_write.call_count == 3
487+
assert tqdm_write.call_args_list == [
488+
call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
489+
call("test_step", end=os.linesep, file=None, nolock=False),
490+
call("predict_step", end=os.linesep, file=None, nolock=False),
458491
]
459492

460493

@@ -470,17 +503,20 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
470503
limit_train_batches=1,
471504
limit_val_batches=1,
472505
limit_test_batches=1,
506+
limit_predict_batches=1,
473507
max_steps=1,
474508
callbacks=[bar],
475509
)
476510
bar.disable()
477511
trainer.fit(model)
478-
trainer.test(model)
512+
trainer.test(model, verbose=False)
513+
trainer.predict(model)
479514

480515
mock_print.assert_has_calls([
481516
call("training_step", end=""),
482517
call("validation_step", file=ANY),
483518
call("test_step"),
519+
call("predict_step"),
484520
])
485521
tqdm_write.assert_not_called()
486522

0 commit comments

Comments
 (0)