Skip to content

Commit 2da5c2b

Browse files
committed
Add tests for progress_bar printing without training
1 parent a494e37 commit 2da5c2b

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/callbacks/test_progress_bar.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,32 @@ def test_progress_bar_print(tqdm_write, tmpdir):
465465
]
466466

467467

468+
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
469+
def test_progress_bar_print_no_train(tqdm_write, tmpdir):
470+
""" Test that printing in the LightningModule redirects arguments to the progress bar without training. """
471+
model = PrintModel()
472+
bar = ProgressBar()
473+
trainer = Trainer(
474+
default_root_dir=tmpdir,
475+
num_sanity_val_steps=0,
476+
limit_val_batches=1,
477+
limit_test_batches=1,
478+
limit_predict_batches=1,
479+
max_steps=1,
480+
callbacks=[bar],
481+
)
482+
483+
trainer.validate(model)
484+
trainer.test(model)
485+
trainer.predict(model)
486+
assert tqdm_write.call_count == 3
487+
assert tqdm_write.call_args_list == [
488+
call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
489+
call("test_step", end=os.linesep, file=None, nolock=False),
490+
call("predict_step", end=os.linesep, file=None, nolock=False),
491+
]
492+
493+
468494
@mock.patch('builtins.print')
469495
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
470496
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):

0 commit comments

Comments
 (0)