diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a57607722662..1af563f740fec 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -457,6 +457,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))
 
+- Fixed bug where progress bar was not being disabled when not in rank zero during predict ([#11377](https://github.com/PyTorchLightning/pytorch-lightning/pull/11377))
+
+
 - Fixed `SimpleProfiler` summary ([#11414](https://github.com/PyTorchLightning/pytorch-lightning/pull/11414))
 
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 4babd823e82d5..5c35bf122bf0f 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -131,11 +131,7 @@ def total_predict_batches(self) -> Union[int, float]:
         return sum(self.trainer.num_predict_batches)
 
     def disable(self) -> None:
-        """You should provide a way to disable the progress bar.
-
-        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the
-        output on processes that have a rank different from 0, e.g., in multi-node training.
-        """
+        """You should provide a way to disable the progress bar."""
         raise NotImplementedError
 
     def enable(self) -> None:
@@ -153,6 +149,8 @@ def print(self, *args: Any, **kwargs: Any) -> None:
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self._trainer = trainer
+        if not trainer.is_global_zero:
+            self.disable()
 
     def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
         r"""
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6cf6dd51cf04a..89c4061dbeaff 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1280,9 +1280,6 @@ def _pre_training_routine(self):
     def _run_train(self) -> None:
         self._pre_training_routine()
 
-        if not self.is_global_zero and self.progress_bar_callback is not None:
-            self.progress_bar_callback.disable()
-
         self._run_sanity_check()
 
         # enable train mode
@@ -1294,9 +1291,6 @@ def _run_train(self) -> None:
         self.fit_loop.run()
 
     def _run_evaluate(self) -> _EVALUATE_OUTPUT:
-        if not self.is_global_zero and self.progress_bar_callback is not None:
-            self.progress_bar_callback.disable()
-
         assert self.evaluating
 
         # reload dataloaders
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index e484e1cb5b32f..1b7faf8c3ae88 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -17,7 +17,7 @@
 from collections import defaultdict
 from typing import Union
 from unittest import mock
-from unittest.mock import ANY, call
+from unittest.mock import ANY, call, PropertyMock
 
 import pytest
 import torch
@@ -618,3 +618,30 @@ def test_step(self, batch, batch_idx):
     trainer.test(model, verbose=False)
 
     assert pbar.calls["test"] == []
+
+
+@mock.patch("pytorch_lightning.trainer.trainer.Trainer.is_global_zero", new_callable=PropertyMock, return_value=False)
+def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
+    """Test that the progress bar is disabled when not in global rank zero."""
+    progress_bar = TQDMProgressBar()
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[progress_bar],
+        fast_dev_run=True,
+    )
+
+    progress_bar.enable()
+    trainer.fit(model)
+    assert progress_bar.is_disabled
+
+    progress_bar.enable()
+    trainer.predict(model)
+    assert progress_bar.is_disabled
+
+    progress_bar.enable()
+    trainer.validate(model)
+    assert progress_bar.is_disabled
+
+    progress_bar.enable()
+    trainer.test(model)
+    assert progress_bar.is_disabled
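
For reviewers, a minimal sketch (not part of this diff) of the contract the reworked `setup` hook relies on: every `ProgressBarBase` subclass must implement `disable()`, since the base method raises `NotImplementedError`, and after this change `setup` calls it automatically on any process where `trainer.is_global_zero` is `False`. Because `setup` runs for every entry point, `trainer.predict` is now covered along with fit/validate/test, which is what the new test asserts. The `LitProgressBar` class and its printing logic below are illustrative assumptions, not code from this PR:

```python
import sys

from pytorch_lightning.callbacks import ProgressBarBase


class LitProgressBar(ProgressBarBase):
    """Illustrative bar that prints a plain percentage instead of using tqdm."""

    def __init__(self) -> None:
        super().__init__()
        self._enabled = True

    def disable(self) -> None:
        # Invoked by `ProgressBarBase.setup` on every non-rank-zero process.
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    @property
    def is_disabled(self) -> bool:
        return not self._enabled

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Let the base class keep its batch counters up to date.
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        if self._enabled:
            percent = 100.0 * self.train_batch_idx / self.total_train_batches
            sys.stdout.write(f"\rtraining: {percent:.01f}%")
            sys.stdout.flush()
```

Moving the rank check out of `Trainer._run_train`/`_run_evaluate` and into the callback's own `setup` also means the behavior no longer depends on which internal run method a given entry point happens to take.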