From 2573470d4b2acff2a6e66471b18f8d272f3f0421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 May 2021 00:01:15 +0200 Subject: [PATCH] Clear predict_progress_bar in ProgressBar.__getstate__ (#7608) Co-authored-by: Yifu Wang Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 8 ++++++++ pytorch_lightning/callbacks/progress.py | 2 ++ tests/callbacks/test_progress_bar.py | 15 +++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e126d948090c..e63da99daba1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). + + + +- Fixed `ProgressBar` pickling after calling `trainer.predict` ([#7608](https://github.com/PyTorchLightning/pytorch-lightning/pull/7608)) + + + + ## [1.3.2] - 2021-05-18 ### Changed diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index be9d2f44356f5..451078f1ac2e4 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -283,6 +283,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self.main_progress_bar = None self.val_progress_bar = None self.test_progress_bar = None + self.predict_progress_bar = None def __getstate__(self): # can't pickle the tqdm objects @@ -290,6 +291,7 @@ def __getstate__(self): state['main_progress_bar'] = None state['val_progress_bar'] = None state['test_progress_bar'] = None + state['predict_progress_bar'] = None return state @property diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 76f1e4cb0570f..6edb81975b347 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import pickle import sys from typing import Optional, Union from unittest import mock @@ -482,3 +483,17 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() + + +def test_progress_bar_can_be_pickled(): + bar = ProgressBar() + trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) + model = BoringModel() + + pickle.dumps(bar) + trainer.fit(model) + pickle.dumps(bar) + trainer.test(model) + pickle.dumps(bar) + trainer.predict(model) + pickle.dumps(bar)