Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))


- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ See the :doc:`profiler documentation <../advanced/profiler>`. for more details.
progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ module = [
"pytorch_lightning.callbacks.gradient_accumulation_scheduler",
"pytorch_lightning.callbacks.lr_monitor",
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.progress",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
"pytorch_lightning.core.optimizer",
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
Expand Down Expand Up @@ -52,4 +52,5 @@
"RichProgressBar",
"StochasticWeightAveraging",
"Timer",
"TQDMProgressBar",
]
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/progress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@

"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401
24 changes: 24 additions & 0 deletions pytorch_lightning/callbacks/progress/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning.utilities import rank_zero_deprecation


class ProgressBar(TQDMProgressBar):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
rank_zero_deprecation(
"`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7."
" It has been renamed to `TQDMProgressBar` instead."
)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def format_num(n) -> str:
return n


class ProgressBar(ProgressBarBase):
class TQDMProgressBar(ProgressBarBase):
r"""
This is the default progress bar used by Lightning. It prints to ``stdout`` using the
:mod:`tqdm` package and shows up to four different bars:
Expand All @@ -75,7 +75,7 @@ class ProgressBar(ProgressBarBase):

Example:

>>> class LitProgressBar(ProgressBar):
>>> class LitProgressBar(TQDMProgressBar):
... def init_validation_tqdm(self):
... bar = super().init_validation_tqdm()
... bar.set_description('running validation ...')
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
ProgressBar,
ProgressBarBase,
RichProgressBar,
TQDMProgressBar,
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
Expand Down Expand Up @@ -82,14 +82,14 @@ def on_trainer_init(
if process_position != 0:
rank_zero_deprecation(
f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
" `process_position` directly to the Trainer's `callbacks` argument instead."
)

if progress_bar_refresh_rate is not None:
rank_zero_deprecation(
f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
" `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress"
" bar pass `enable_progress_bar = False` to the Trainer."
)
Expand Down Expand Up @@ -230,7 +230,7 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0):
if len(progress_bars) == 1:
progress_bar_callback = progress_bars[0]
elif refresh_rate > 0:
progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position)
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
else:
progress_bar_callback = None
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __init__(

.. deprecated:: v1.5
``process_position`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``process_position``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position``
directly to the Trainer's ``callbacks`` argument instead.

progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
Expand All @@ -302,7 +302,7 @@ def __init__(

.. deprecated:: v1.5
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -37,12 +37,12 @@
([], None),
([], 1),
([], 2),
([ProgressBar(refresh_rate=1)], 0),
([ProgressBar(refresh_rate=2)], 0),
([ProgressBar(refresh_rate=2)], 1),
([TQDMProgressBar(refresh_rate=1)], 0),
([TQDMProgressBar(refresh_rate=2)], 0),
([TQDMProgressBar(refresh_rate=2)], 1),
],
)
def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
"""Test different ways the progress bar can be turned on."""

trainer = Trainer(
Expand All @@ -63,7 +63,7 @@ def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
"callbacks,refresh_rate,enable_progress_bar",
[([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)],
)
def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
"""Test different ways the progress bar can be turned off."""

trainer = Trainer(
Expand All @@ -73,19 +73,19 @@ def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int
enable_progress_bar=enable_progress_bar,
)

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]
progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)]
assert 0 == len(progress_bars)
assert not trainer.progress_bar_callback


def test_progress_bar_misconfiguration():
def test_tqdm_progress_bar_misconfiguration():
"""Test that Trainer doesn't accept multiple progress bars."""
callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath="../trainer")]
callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
Trainer(callbacks=callbacks)


def test_progress_bar_totals(tmpdir):
def test_tqdm_progress_bar_totals(tmpdir):
"""Test that the progress finishes with the correct total steps processed."""

model = BoringModel()
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_progress_bar_totals(tmpdir):
assert bar.test_batch_idx == k


def test_progress_bar_fast_dev_run(tmpdir):
def test_tqdm_progress_bar_fast_dev_run(tmpdir):
model = BoringModel()

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down Expand Up @@ -174,12 +174,12 @@ def test_progress_bar_fast_dev_run(tmpdir):


@pytest.mark.parametrize("refresh_rate", [0, 1, 50])
def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
"""Test that the three progress bars get correctly updated when using different refresh rates."""

model = BoringModel()

class CurrentProgressBar(ProgressBar):
class CurrentProgressBar(TQDMProgressBar):

train_batches_seen = 0
val_batches_seen = 0
Expand Down Expand Up @@ -239,7 +239,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
"""Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument."""

class CurrentProgressBar(ProgressBar):
class CurrentProgressBar(TQDMProgressBar):
val_pbar_total = 0
sanity_pbar_total = 0

Expand Down Expand Up @@ -271,7 +271,7 @@ def on_validation_epoch_end(self, *args):
assert progress_bar.val_pbar_total == limit_val_batches


def test_progress_bar_default_value(tmpdir):
def test_tqdm_progress_bar_default_value(tmpdir):
"""Test that a value of None defaults to refresh rate 1."""
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 1
Expand All @@ -281,7 +281,7 @@ def test_progress_bar_default_value(tmpdir):


@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
def test_progress_bar_value_on_colab(tmpdir):
def test_tqdm_progress_bar_value_on_colab(tmpdir):
"""Test that Trainer will override the default in Google COLAB."""
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 20
Expand All @@ -293,7 +293,7 @@ def test_progress_bar_value_on_colab(tmpdir):
assert trainer.progress_bar_callback.refresh_rate == 19


class MockedUpdateProgressBars(ProgressBar):
class MockedUpdateProgressBars(TQDMProgressBar):
"""Mocks the update method once bars get initializied."""

def _mock_bar_update(self, bar):
Expand Down Expand Up @@ -428,10 +428,10 @@ def predict_step(self, *args, **kwargs):


@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print(tqdm_write, tmpdir):
def test_tqdm_progress_bar_print(tqdm_write, tmpdir):
"""Test that printing in the LightningModule redirects arguments to the progress bar."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
Expand All @@ -455,10 +455,10 @@ def test_progress_bar_print(tqdm_write, tmpdir):


@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print_no_train(tqdm_write, tmpdir):
def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir):
"""Test that printing in the LightningModule redirects arguments to the progress bar without training."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
Expand All @@ -482,10 +482,10 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir):

@mock.patch("builtins.print")
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
"""Test that printing in LightningModule goes through built-in print function when progress bar is disabled."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
Expand All @@ -507,8 +507,8 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
tqdm_write.assert_not_called()


def test_progress_bar_can_be_pickled():
bar = ProgressBar()
def test_tqdm_progress_bar_can_be_pickled():
bar = TQDMProgressBar()
trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1)
model = BoringModel()

Expand All @@ -522,14 +522,14 @@ def test_progress_bar_can_be_pickled():


@RunIf(min_gpus=2, special=True)
def test_progress_bar_max_val_check_interval_0(tmpdir):
def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2
)


@RunIf(min_gpus=2, special=True)
def test_progress_bar_max_val_check_interval_1(tmpdir):
def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5
)
Expand Down Expand Up @@ -567,7 +567,7 @@ def _test_progress_bar_max_val_check_interval(


def test_get_progress_bar_metrics(tmpdir: str):
class TestProgressBar(ProgressBar):
class TestProgressBar(TQDMProgressBar):
def get_metrics(self, trainer: Trainer, model: LightningModule):
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
Expand All @@ -588,9 +588,9 @@ def get_metrics(self, trainer: Trainer, model: LightningModule):
assert "v_num" not in standard_metrics.keys()


def test_progress_bar_main_bar_resume():
def test_tqdm_progress_bar_main_bar_resume():
"""Test that the progress bar can resume its counters based on the Trainer state."""
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Mock()
model = Mock()

Expand Down
6 changes: 6 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from tests.callbacks.test_callbacks import OldStatefulCallback
Expand Down Expand Up @@ -391,6 +392,11 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
_ = XLAStatsMonitor()


def test_v1_7_0_progress_bar():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
_ = ProgressBar()


def test_v1_7_0_deprecated_max_steps_none(tmpdir):
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
_ = Trainer(max_steps=None)
Expand Down
Loading