From 6c8fea05a6a64945dc5efe97d41edaabc3f09f58 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:48:06 +0200 Subject: [PATCH 1/4] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 2a33fbf0c1455..cdbeefc3efb1e 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -175,6 +175,11 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 1 == progress_bar.test_batch_idx assert 1 == progress_bar.test_progress_bar.total assert 1 == progress_bar.test_progress_bar.n + + trainer.predict(model, model.test_dataloader()) + assert 1 == progress_bar.predict_batch_idx + assert 1 == progress_bar.predict_progress_bar.total + assert 1 == progress_bar.predict_progress_bar.n @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) From c594a50aa610f3c6965bde2af9af720241e5fad1 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:56:22 +0200 Subject: [PATCH 2/4] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index cdbeefc3efb1e..3bb63f4ef0c6d 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -175,11 +175,6 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 1 == progress_bar.test_batch_idx assert 1 == progress_bar.test_progress_bar.total assert 1 == progress_bar.test_progress_bar.n - - trainer.predict(model, model.test_dataloader()) - assert 1 == progress_bar.predict_batch_idx - assert 1 == progress_bar.predict_progress_bar.total - assert 1 == progress_bar.predict_progress_bar.n @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) @@ -487,3 +482,17 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() + +def test_progbar_pickle(): + bar = ProgressBar() + trainer = Trainer(fast_dev_run=True, callbacks=[bar, limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, max_steps=1) + model = BoringModel() + pickle.dumps(bar) + trainer.fit(model) + pickle.dumps(bar) + trainer.test(model) + pickle.dumps(bar) + trainer.predict(model, model.test_dataloader()) + pickle.dumps(bar) From d49cc7c33920239fa701e2f3ef21a375aaacbab6 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 19 May 2021 12:57:02 +0200 Subject: [PATCH 3/4] Update test_progress_bar.py --- tests/callbacks/test_progress_bar.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 3bb63f4ef0c6d..83bb392c20e61 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -485,9 +485,11 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): def test_progbar_pickle(): bar = ProgressBar() - trainer = Trainer(fast_dev_run=True, callbacks=[bar, limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, max_steps=1) + trainer = Trainer(fast_dev_run=True, callbacks=[bar], + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1) model = BoringModel() pickle.dumps(bar) trainer.fit(model) From ac378959987e5881679199b7b22640bc3a58a9ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 May 2021 10:58:43 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_progress_bar.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 83bb392c20e61..bfd16969b0dda 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -482,14 +482,18 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): call("test_step"), ]) tqdm_write.assert_not_called() - + + def test_progbar_pickle(): bar = ProgressBar() - trainer = Trainer(fast_dev_run=True, callbacks=[bar], - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - max_steps=1) + trainer = Trainer( + fast_dev_run=True, + callbacks=[bar], + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + max_steps=1 + ) model = BoringModel() pickle.dumps(bar) trainer.fit(model)