From 2f3c494c0ae767fa139df815f70194bb01bcbc4a Mon Sep 17 00:00:00 2001 From: ryanking13 Date: Mon, 24 May 2021 20:26:44 +0900 Subject: [PATCH 1/4] Check progress bar existence before printing --- pytorch_lightning/callbacks/progress.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index e6132e6f96c8c..0fe05ff812e20 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -473,12 +473,14 @@ def print( ): active_progress_bar = None - if not self.main_progress_bar.disable: + if self.main_progress_bar is not None and not self.main_progress_bar.disable: active_progress_bar = self.main_progress_bar - elif not self.val_progress_bar.disable: + elif self.val_progress_bar is not None and not self.val_progress_bar.disable: active_progress_bar = self.val_progress_bar - elif not self.test_progress_bar.disable: + elif self.test_progress_bar is not None and not self.test_progress_bar.disable: active_progress_bar = self.test_progress_bar + elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable: + active_progress_bar = self.predict_progress_bar if active_progress_bar is not None: s = sep.join(map(str, args)) From a494e37bc0972a07cbfa3c69a102960807b450f8 Mon Sep 17 00:00:00 2001 From: ryanking13 Date: Mon, 24 May 2021 20:54:33 +0900 Subject: [PATCH 2/4] Add tests for predict_progres_bar --- tests/callbacks/test_progress_bar.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 6ab7b9f7415ba..f00240ffc6eb1 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -433,6 +433,10 @@ def test_step(self, *args, **kwargs): self.print("test_step") return super().test_step(*args, **kwargs) + def predict_step(self, *args, **kwargs): + self.print("predict_step") + return super().predict_step(*args, **kwargs) + @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") def test_progress_bar_print(tqdm_write, tmpdir): @@ -445,16 +449,19 @@ def test_progress_bar_print(tqdm_write, tmpdir): limit_train_batches=1, limit_val_batches=1, limit_test_batches=1, + limit_predict_batches=1, max_steps=1, callbacks=[bar], ) trainer.fit(model) trainer.test(model) - assert tqdm_write.call_count == 3 + trainer.predict(model) + assert tqdm_write.call_count == 4 assert tqdm_write.call_args_list == [ call("training_step", end="", file=None, nolock=False), call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), call("test_step", end=os.linesep, file=None, nolock=False), + call("predict_step", end=os.linesep, file=None, nolock=False), ] @@ -470,17 +477,20 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): limit_train_batches=1, limit_val_batches=1, limit_test_batches=1, + limit_predict_batches=1, max_steps=1, callbacks=[bar], ) bar.disable() trainer.fit(model) - trainer.test(model) + trainer.test(model, verbose=False) + trainer.predict(model) mock_print.assert_has_calls([ call("training_step", end=""), call("validation_step", file=ANY), call("test_step"), + call("predict_step"), ]) tqdm_write.assert_not_called() From 2da5c2b45954c0823d17ec6f1ca5a1fe1a0a847d Mon Sep 17 00:00:00 2001 From: ryanking13 Date: Mon, 24 May 2021 20:55:56 +0900 Subject: [PATCH 3/4] Add tests for progress_bar printing without training --- tests/callbacks/test_progress_bar.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index f00240ffc6eb1..f4f8f34c1b4c1 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -465,6 +465,32 @@ def test_progress_bar_print(tqdm_write, tmpdir): ] +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print_no_train(tqdm_write, tmpdir): + """ Test that printing in the LightningModule redirects arguments to the progress bar without training. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + max_steps=1, + callbacks=[bar], + ) + + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + assert tqdm_write.call_count == 3 + assert tqdm_write.call_args_list == [ + call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), + call("test_step", end=os.linesep, file=None, nolock=False), + call("predict_step", end=os.linesep, file=None, nolock=False), + ] + + @mock.patch('builtins.print') @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): From 75c7d935b9f331d31f48e5f49bd4c5cb43c82e1d Mon Sep 17 00:00:00 2001 From: ryanking13 Date: Mon, 24 May 2021 22:18:18 +0900 Subject: [PATCH 4/4] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1239f349e8f5f..4e5892d03734c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed setting correct `DistribType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492)) +- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674)) + + ## [1.3.1] - 2021-05-11 ### Fixed