From 224e9cb58e9c2d6cb4b88f2a62f60b279a77977c Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Wed, 19 May 2021 02:25:14 -0700 Subject: [PATCH 1/9] Clear predict_progress_bar in ProgressBar.__getstate__ --- pytorch_lightning/callbacks/progress.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 45e9e55e69bf0..e6132e6f96c8c 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -283,6 +283,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None + self.predict_progress_bar = None def __getstate__(self): # can't pickle the tqdm objects @@ -290,6 +291,7 @@ def __getstate__(self): state['main_progress_bar'] = None state['val_progress_bar'] = None state['test_progress_bar'] = None + state['predict_progress_bar'] = None return state @property From 6c8fea05a6a64945dc5efe97d41edaabc3f09f58 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:48:06 +0200 Subject: [PATCH 2/9] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 2a33fbf0c1455..cdbeefc3efb1e 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -175,6 +175,11 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 1 == progress_bar.test_batch_idx assert 1 == progress_bar.test_progress_bar.total assert 1 == progress_bar.test_progress_bar.n + + trainer.predict(model, model.test_dataloader()) + assert 1 == progress_bar.predict_batch_idx + assert 1 == progress_bar.predict_progress_bar.total + assert 1 == progress_bar.predict_progress_bar.n @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) From c594a50aa610f3c6965bde2af9af720241e5fad1 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:56:22 +0200 Subject: [PATCH 3/9] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index cdbeefc3efb1e..3bb63f4ef0c6d 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -175,11 +175,6 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 1 == progress_bar.test_batch_idx assert 1 == progress_bar.test_progress_bar.total assert 1 == progress_bar.test_progress_bar.n - - trainer.predict(model, model.test_dataloader()) - assert 1 == progress_bar.predict_batch_idx - assert 1 == progress_bar.predict_progress_bar.total - assert 1 == progress_bar.predict_progress_bar.n @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) @@ -487,3 +482,17 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() + +def test_progbar_pickle(): + bar = ProgressBar() + trainer = Trainer(fast_dev_run=True, callbacks=[bar, limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, max_steps=1) + model = BoringModel() + pickle.dumps(bar) + trainer.fit(model) + pickle.dumps(bar) + trainer.test(model) + pickle.dumps(bar) + trainer.predict(model, model.test_dataloader()) + pickle.dumps(bar) From d49cc7c33920239fa701e2f3ef21a375aaacbab6 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:57:02 +0200 Subject: [PATCH 4/9] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 3bb63f4ef0c6d..83bb392c20e61 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -485,9 +485,11 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): def test_progbar_pickle(): bar = ProgressBar() - trainer = Trainer(fast_dev_run=True, callbacks=[bar, limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, max_steps=1) + trainer = Trainer(fast_dev_run=True, callbacks=[bar], + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1) model = BoringModel() pickle.dumps(bar) trainer.fit(model) From ac378959987e5881679199b7b22640bc3a58a9ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 May 2021 10:58:43 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_progress_bar.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 83bb392c20e61..bfd16969b0dda 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -482,14 +482,18 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() - + + def test_progbar_pickle(): bar = ProgressBar() - trainer = Trainer(fast_dev_run=True, callbacks=[bar], - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - max_steps=1) + trainer = Trainer( + fast_dev_run=True, + callbacks=[bar], + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1 + ) model = BoringModel() pickle.dumps(bar) trainer.fit(model) From e679b5bbac78a3aa04fcd8168e06cd7c16e5213e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 19 May 2021 16:44:20 +0200 Subject: [PATCH 6/9] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82ce08594b310..0e23f6a9efab5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608)) + + - Fixed recursive passing of `wrong_type` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) From 355e4c9a5e11e4818d61ada4ef7cfdb734981b71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 19 May 2021 16:46:08 +0200 Subject: [PATCH 7/9] Update tests/callbacks/test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index bfd16969b0dda..3620ffe6ec03e 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -500,5 +500,5 @@ def test_progbar_pickle(): pickle.dumps(bar) trainer.test(model) pickle.dumps(bar) - trainer.predict(model, model.test_dataloader()) + trainer.predict(model) pickle.dumps(bar) From f51c33077d72a5e78d4223e0c4f7947d1293ec0d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 20 May 2021 03:10:12 +0200 Subject: [PATCH 8/9] Missing import --- tests/callbacks/test_progress_bar.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 3620ffe6ec03e..cab5722612831 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import pickle import sys from typing import Optional, Union from unittest import mock @@ -486,15 +487,9 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): def test_progbar_pickle(): bar = ProgressBar() - trainer = Trainer( - fast_dev_run=True, - callbacks=[bar], - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - max_steps=1 - ) + trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) model = BoringModel() + pickle.dumps(bar) trainer.fit(model) pickle.dumps(bar) From 0e597f58761c8f724a8b2760f895508dec71b8e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 20 May 2021 03:11:33 +0200 Subject: [PATCH 9/9] Change name --- tests/callbacks/test_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index cab5722612831..6ab7b9f7415ba 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -485,7 +485,7 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): tqdm_write.assert_not_called() -def test_progbar_pickle(): +def test_progress_bar_can_be_pickled(): bar = ProgressBar() trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) model = BoringModel()