diff --git a/CHANGELOG.md b/CHANGELOG.md index 82ce08594b310..0e23f6a9efab5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed parsing of multiple training dataloaders ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) +- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608)) + + - Fixed recursive passing of `wrong_type` keyword argument in `pytorch_lightning.utilities.apply_to_collection` ([#7433](https://github.com/PyTorchLightning/pytorch-lightning/pull/7433)) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 45e9e55e69bf0..e6132e6f96c8c 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -283,6 +283,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None + self.predict_progress_bar = None def __getstate__(self): # can't pickle the tqdm objects @@ -290,6 +291,7 @@ def __getstate__(self): state['main_progress_bar'] = None state['val_progress_bar'] = None state['test_progress_bar'] = None + state['predict_progress_bar'] = None return state @property diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 2a33fbf0c1455..6ab7b9f7415ba 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import pickle import sys from typing import Optional, Union from unittest import mock @@ -482,3 +483,17 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() + + +def test_progress_bar_can_be_pickled(): + bar = ProgressBar() + trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) + model = BoringModel() + + pickle.dumps(bar) + trainer.fit(model) + pickle.dumps(bar) + trainer.test(model) + pickle.dumps(bar) + trainer.predict(model) + pickle.dumps(bar)