From f53d2fb5fc198146c2d1ba7a760122bbc43dd60e Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 25 Oct 2021 19:59:37 +0530 Subject: [PATCH 01/10] Rename ProgressBar to TQDMProgressBar --- pytorch_lightning/callbacks/__init__.py | 3 ++- .../callbacks/progress/__init__.py | 3 ++- .../callbacks/progress/progress.py | 23 +++++++++++++++++++ .../callbacks/progress/tqdm_progress.py | 4 ++-- 4 files changed, 29 insertions(+), 4 deletions(-) create mode 100644 pytorch_lightning/callbacks/progress/progress.py diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index b94fa969f6ac9..f47bc115ece51 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_summary import ModelSummary from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter -from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary @@ -52,4 +52,5 @@ "RichProgressBar", "StochasticWeightAveraging", "Timer", + "TQDMProgressBar", ] diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 3fa7b1afe6b44..6ccc181b95c21 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,6 @@ """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 -from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401 diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py new file mode 100644 index 0000000000000..a14daefe8af38 --- /dev/null +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar +from pytorch_lightning.utilities import rank_zero_deprecation + + +class ProgressBar(TQDMProgressBar): + + rank_zero_deprecation( + "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." + " It has been renamed to `TQDMProgressBar` instead." + ) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 7f3b902925c6c..672d9d893ad61 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -54,7 +54,7 @@ def format_num(n) -> str: return n -class ProgressBar(ProgressBarBase): +class TQDMProgressBar(ProgressBarBase): r""" This is the default progress bar used by Lightning. It prints to ``stdout`` using the :mod:`tqdm` package and shows up to four different bars: @@ -75,7 +75,7 @@ class ProgressBar(ProgressBarBase): Example: - >>> class LitProgressBar(ProgressBar): + >>> class LitProgressBar(TQDMProgressBar): ... def init_validation_tqdm(self): ... bar = super().init_validation_tqdm() ... bar.set_description('running validation ...') From 5bfca26cbf9c7b4892e7dd2ceab4fb5c2e55f6b8 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 25 Oct 2021 20:31:12 +0530 Subject: [PATCH 02/10] Add test & update changelog --- CHANGELOG.md | 4 ++++ tests/deprecated_api/test_remove_1-7.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 655484292ee59..48107fe32ccad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -408,6 +408,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924)) + +- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index edf85c11766e3..298e76f5850de 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -387,3 +387,10 @@ def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir): def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir): with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"): _ = XLAStatsMonitor() + + +def test_v1_7_0_progress_bar(): + + _soft_unimport_module("pytorch_lightning.callbacks.progress") + with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): + from pytorch_lightning.callbacks.progress import ProgressBar # noqa: F401 From d8a5a6ea2cda1fda4e17ca6ab30253eb774e88c3 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 28 Oct 2021 15:47:43 +0530 Subject: [PATCH 03/10] Update test --- pytorch_lightning/callbacks/progress/progress.py | 11 ++++++----- tests/deprecated_api/test_remove_1-7.py | 5 ++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index a14daefe8af38..1a1ddccb7ac8b 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -16,8 +16,9 @@ class ProgressBar(TQDMProgressBar): - - rank_zero_deprecation( - "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." - " It has been renamed to `TQDMProgressBar` instead." - ) + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + super().__init__(refresh_rate=refresh_rate, process_position=process_position) + rank_zero_deprecation( + "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." + " It has been renamed to `TQDMProgressBar` instead." + ) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 4040b946a626f..dfe326f332b00 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -19,6 +19,7 @@ from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor +from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger from tests.callbacks.test_callbacks import OldStatefulCallback @@ -391,10 +392,8 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir): def test_v1_7_0_progress_bar(): - - _soft_unimport_module("pytorch_lightning.callbacks.progress") with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): - from pytorch_lightning.callbacks.progress import ProgressBar # noqa: F401 + _ = ProgressBar() def test_v1_7_0_deprecated_max_steps_none(tmpdir): From 634a3b2c07bbcdd99b3307e696b278e0e1c4eb5a Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 28 Oct 2021 22:31:16 +0530 Subject: [PATCH 04/10] Update --- pytorch_lightning/trainer/connectors/callback_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 2f63e65340760..8809254599a84 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -20,9 +20,9 @@ GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, - ProgressBar, ProgressBarBase, RichProgressBar, + TQDMProgressBar, ) from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer @@ -230,7 +230,7 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0): if len(progress_bars) == 1: progress_bar_callback = progress_bars[0] elif refresh_rate > 0: - progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position) + progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) self.trainer.callbacks.append(progress_bar_callback) else: progress_bar_callback = None From 6043efcbb60521f3998acadbea523378eb68f3e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Oct 2021 17:04:52 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-7.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index e927baec78a8d..16c511b6effd9 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -19,8 +19,8 @@ from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor -from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor +from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger from tests.callbacks.test_callbacks import OldStatefulCallback From 9f6faa7ad7da347fbaa9142cc7ad49b14e6e2e00 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 28 Oct 2021 23:51:17 +0530 Subject: [PATCH 06/10] Update pytorch_lightning/callbacks/progress/progress.py Co-authored-by: ananthsub --- pytorch_lightning/callbacks/progress/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index 1a1ddccb7ac8b..e1898a1bec1b1 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -16,8 +16,8 @@ class ProgressBar(TQDMProgressBar): - def __init__(self, refresh_rate: int = 1, process_position: int = 0): - super().__init__(refresh_rate=refresh_rate, process_position=process_position) + def __init__(self, *args, **kwargs) + super().__init__(*args, **kwargs) rank_zero_deprecation( "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." " It has been renamed to `TQDMProgressBar` instead." From 7efbb60ad419c3f7db5bcbdd68c7543acfa1181b Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 29 Oct 2021 00:06:03 +0530 Subject: [PATCH 07/10] Update pytorch_lightning/callbacks/progress/progress.py --- pytorch_lightning/callbacks/progress/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index e1898a1bec1b1..73dc60dc632c8 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -16,7 +16,7 @@ class ProgressBar(TQDMProgressBar): - def __init__(self, *args, **kwargs) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) rank_zero_deprecation( "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." From b2ca8b38578c629c6253fa9925bc83dd1682a1f8 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 29 Oct 2021 17:39:58 +0530 Subject: [PATCH 08/10] Update --- docs/source/common/trainer.rst | 2 +- pyproject.toml | 1 + ...gress_bar.py => test_tqdm_progress_bar.py} | 60 +++++++++---------- .../connectors/test_callback_connector.py | 8 +-- .../logging_/test_train_loop_logging.py | 4 +- 5 files changed, 38 insertions(+), 37 deletions(-) rename tests/callbacks/{test_progress_bar.py => test_tqdm_progress_bar.py} (93%) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 3d1f4aae5ddb9..006e14b64db47 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1240,7 +1240,7 @@ See the :doc:`profiler documentation <../advanced/profiler>`. for more details. progress_bar_refresh_rate ^^^^^^^^^^^^^^^^^^^^^^^^^ ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7. -Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate`` +Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate`` directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar, pass ``enable_progress_bar = False`` to the Trainer. diff --git a/pyproject.toml b/pyproject.toml index 98c0962da8445..a2b83ae93e6aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ module = [ "pytorch_lightning.callbacks.gradient_accumulation_scheduler", "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_summary", + "pytorch_lightning.callbacks.progress", "pytorch_lightning.callbacks.pruning", "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.core.optimizer", diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py similarity index 93% rename from tests/callbacks/test_progress_bar.py rename to tests/callbacks/test_tqdm_progress_bar.py index 9cbf89b64faf7..b92fb18d54ccd 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -23,7 +23,7 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -37,12 +37,12 @@ ([], None), ([], 1), ([], 2), - ([ProgressBar(refresh_rate=1)], 0), - ([ProgressBar(refresh_rate=2)], 0), - ([ProgressBar(refresh_rate=2)], 1), + ([TQDMProgressBar(refresh_rate=1)], 0), + ([TQDMProgressBar(refresh_rate=2)], 0), + ([TQDMProgressBar(refresh_rate=2)], 1), ], ) -def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): +def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): """Test different ways the progress bar can be turned on.""" trainer = Trainer( @@ -63,7 +63,7 @@ def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): "callbacks,refresh_rate,enable_progress_bar", [([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)], ) -def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool): +def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool): """Test different ways the progress bar can be turned off.""" trainer = Trainer( @@ -73,19 +73,19 @@ def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int enable_progress_bar=enable_progress_bar, ) - progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)] assert 0 == len(progress_bars) assert not trainer.progress_bar_callback -def test_progress_bar_misconfiguration(): +def test_tqdm_progress_bar_misconfiguration(): """Test that Trainer doesn't accept multiple progress bars.""" - callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath="../trainer")] + callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")] with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"): Trainer(callbacks=callbacks) -def test_progress_bar_totals(tmpdir): +def test_tqdm_progress_bar_totals(tmpdir): """Test that the progress finishes with the correct total steps processed.""" model = BoringModel() @@ -138,7 +138,7 @@ def test_progress_bar_totals(tmpdir): assert bar.test_batch_idx == k -def test_progress_bar_fast_dev_run(tmpdir): +def test_tqdm_progress_bar_fast_dev_run(tmpdir): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -174,12 +174,12 @@ def test_progress_bar_fast_dev_run(tmpdir): @pytest.mark.parametrize("refresh_rate", [0, 1, 50]) -def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int): +def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int): """Test that the three progress bars get correctly updated when using different refresh rates.""" model = BoringModel() - class CurrentProgressBar(ProgressBar): + class CurrentProgressBar(TQDMProgressBar): train_batches_seen = 0 val_batches_seen = 0 @@ -239,7 +239,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): """Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.""" - class CurrentProgressBar(ProgressBar): + class CurrentProgressBar(TQDMProgressBar): val_pbar_total = 0 sanity_pbar_total = 0 @@ -271,7 +271,7 @@ def on_validation_epoch_end(self, *args): assert progress_bar.val_pbar_total == limit_val_batches -def test_progress_bar_default_value(tmpdir): +def test_tqdm_progress_bar_default_value(tmpdir): """Test that a value of None defaults to refresh rate 1.""" trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 1 @@ -281,7 +281,7 @@ def test_progress_bar_default_value(tmpdir): @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) -def test_progress_bar_value_on_colab(tmpdir): +def test_tqdm_progress_bar_value_on_colab(tmpdir): """Test that Trainer will override the default in Google COLAB.""" trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 20 @@ -293,7 +293,7 @@ def test_progress_bar_value_on_colab(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 19 -class MockedUpdateProgressBars(ProgressBar): +class MockedUpdateProgressBars(TQDMProgressBar): """Mocks the update method once bars get initializied.""" def _mock_bar_update(self, bar): @@ -428,10 +428,10 @@ def predict_step(self, *args, **kwargs): @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print(tqdm_write, tmpdir): +def test_tqdm_progress_bar_print(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -455,10 +455,10 @@ def test_progress_bar_print(tqdm_write, tmpdir): @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print_no_train(tqdm_write, tmpdir): +def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar without training.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -482,10 +482,10 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir): @mock.patch("builtins.print") @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): +def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): """Test that printing in LightningModule goes through built-in print function when progress bar is disabled.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -507,8 +507,8 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): tqdm_write.assert_not_called() -def test_progress_bar_can_be_pickled(): - bar = ProgressBar() +def test_tqdm_progress_bar_can_be_pickled(): + bar = TQDMProgressBar() trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) model = BoringModel() @@ -522,14 +522,14 @@ def test_progress_bar_can_be_pickled(): @RunIf(min_gpus=2, special=True) -def test_progress_bar_max_val_check_interval_0(tmpdir): +def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir): _test_progress_bar_max_val_check_interval( tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2 ) @RunIf(min_gpus=2, special=True) -def test_progress_bar_max_val_check_interval_1(tmpdir): +def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir): _test_progress_bar_max_val_check_interval( tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5 ) @@ -567,7 +567,7 @@ def _test_progress_bar_max_val_check_interval( def test_get_progress_bar_metrics(tmpdir: str): - class TestProgressBar(ProgressBar): + class TestProgressBar(TQDMProgressBar): def get_metrics(self, trainer: Trainer, model: LightningModule): items = super().get_metrics(trainer, model) items.pop("v_num", None) @@ -588,9 +588,9 @@ def get_metrics(self, trainer: Trainer, model: LightningModule): assert "v_num" not in standard_metrics.keys() -def test_progress_bar_main_bar_resume(): +def test_tqdm_progress_bar_main_bar_resume(): """Test that the progress bar can resume its counters based on the Trainer state.""" - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Mock() model = Mock() diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 949723b022750..7ec238acf5682 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -22,7 +22,7 @@ LearningRateMonitor, ModelCheckpoint, ModelSummary, - ProgressBar, + TQDMProgressBar, ) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector from tests.helpers import BoringModel @@ -35,7 +35,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model_summary = ModelSummary() early_stopping = EarlyStopping() lr_monitor = LearningRateMonitor() - progress_bar = ProgressBar() + progress_bar = TQDMProgressBar() # no model reference trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) @@ -155,7 +155,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): return trainer early_stopping = EarlyStopping() - progress_bar = ProgressBar() + progress_bar = TQDMProgressBar() lr_monitor = LearningRateMonitor() grad_accumulation = GradientAccumulationScheduler({1: 1}) @@ -199,7 +199,7 @@ def test_attach_model_callbacks_override_info(caplog): """Test that the logs contain the info about overriding callbacks returned by configure_callbacks.""" model = LightningModule() model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()] - trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) + trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()]) trainer.model = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 888bbe8e75108..cd3e024535d90 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -24,7 +24,7 @@ from torchmetrics import Accuracy from pytorch_lightning import callbacks, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset @@ -475,7 +475,7 @@ def on_train_epoch_end(self, *_): ) self.on_train_epoch_end_called = True - class TestProgressBar(ProgressBar): + class TestProgressBar(TQDMProgressBar): def get_metrics(self, trainer: Trainer, model: LightningModule): items = super().get_metrics(trainer, model) items.pop("v_num", None) From 2d4d60934d1bb7ffcda5995ae7706c508e2ae64c Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 29 Oct 2021 17:44:19 +0530 Subject: [PATCH 09/10] update typehint --- pytorch_lightning/callbacks/progress/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py index 73dc60dc632c8..e13e612805f03 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -16,7 +16,7 @@ class ProgressBar(TQDMProgressBar): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) rank_zero_deprecation( "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." From 29199d747d274d9b86085062eacdcb36cb56c4a1 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 1 Nov 2021 15:24:18 +0530 Subject: [PATCH 10/10] Update --- pytorch_lightning/trainer/connectors/callback_connector.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 8809254599a84..4d41734ed90e6 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -82,14 +82,14 @@ def on_trainer_init( if process_position != 0: rank_zero_deprecation( f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed" - " in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with" + " in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with" " `process_position` directly to the Trainer's `callbacks` argument instead." ) if progress_bar_refresh_rate is not None: rank_zero_deprecation( f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and" - " will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with" + " will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with" " `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress" " bar pass `enable_progress_bar = False` to the Trainer." ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0bbb1a7c37838..933d6a4dad4cb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -294,7 +294,7 @@ def __init__( .. deprecated:: v1.5 ``process_position`` has been deprecated in v1.5 and will be removed in v1.7. - Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``process_position`` + Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position`` directly to the Trainer's ``callbacks`` argument instead. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. @@ -303,7 +303,7 @@ def __init__( .. deprecated:: v1.5 ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7. - Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate`` + Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate`` directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar, pass ``enable_progress_bar = False`` to the Trainer.