Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed a pickling error when using `RichProgressBar` together with checkpointing ([#15319](https://github.com/Lightning-AI/lightning/pull/15319))
- Fixed the `RichProgressBar` crashing when used with distributed strategies ([#15376](https://github.com/Lightning-AI/lightning/pull/15376))


- Fixed an issue with `RichProgressBar` not resetting the internal state for the sanity check progress ([#15377](https://github.com/Lightning-AI/lightning/pull/15377))
Expand Down
71 changes: 40 additions & 31 deletions src/pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ def __init__(
self._console: Optional[Console] = None
self._console_kwargs = console_kwargs or {}
self._enabled: bool = True
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional["TaskID"] = None
self.progress: Optional[CustomProgress] = None
self.main_progress_bar_id: Optional["TaskID"]
self.val_sanity_progress_bar_id: Optional["TaskID"] = None
self.val_progress_bar_id: Optional["TaskID"]
self.test_progress_bar_id: Optional["TaskID"]
self.predict_progress_bar_id: Optional["TaskID"]
Expand All @@ -279,6 +279,30 @@ def is_enabled(self) -> bool:
def is_disabled(self) -> bool:
return not self.is_enabled

@property
def main_progress_bar(self) -> Task:
assert self.progress is not None
assert self.main_progress_bar_id is not None
return self.progress.tasks[self.main_progress_bar_id]

@property
def val_sanity_check_bar(self) -> Task:
assert self.progress is not None
assert self.val_sanity_progress_bar_id is not None
return self.progress.tasks[self.val_sanity_progress_bar_id]

@property
def val_progress_bar(self) -> Task:
assert self.progress is not None
assert self.val_progress_bar_id is not None
return self.progress.tasks[self.val_progress_bar_id]

@property
def test_progress_bar(self) -> Task:
assert self.progress is not None
assert self.test_progress_bar_id is not None
return self.progress.tasks[self.test_progress_bar_id]

def _update_for_light_colab_theme(self) -> None:
if _detect_light_colab_theme():
attributes = ["description", "batch_progress", "metrics"]
Expand Down Expand Up @@ -336,6 +360,8 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
self.refresh()

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_disabled:
return
total_batches = self.total_batches_current_epoch
train_description = self._get_train_description(trainer.current_epoch)

Expand All @@ -355,9 +381,11 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
def on_validation_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if not self.has_dataloader_changed(dataloader_idx) or self.progress is None:
if self.is_disabled or not self.has_dataloader_changed(dataloader_idx):
return

assert self.progress is not None

if trainer.sanity_checking:
if self.val_sanity_progress_bar_id is not None:
self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
Expand Down Expand Up @@ -397,7 +425,7 @@ def _should_update(self, current: int, total: Union[int, float]) -> bool:
return current % self.refresh_rate == 0 or current == total

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit":
assert self.progress is not None
self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
self.refresh()
Expand All @@ -416,7 +444,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
def on_test_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if not self.has_dataloader_changed(dataloader_idx):
if self.is_disabled or not self.has_dataloader_changed(dataloader_idx):
return

if self.test_progress_bar_id is not None:
Expand All @@ -428,7 +456,7 @@ def on_test_batch_start(
def on_predict_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
if not self.has_dataloader_changed(dataloader_idx):
if self.is_disabled or not self.has_dataloader_changed(dataloader_idx):
return

if self.predict_progress_bar_id is not None:
Expand Down Expand Up @@ -458,8 +486,9 @@ def on_validation_batch_end(
batch_idx: int,
dataloader_idx: int,
) -> None:
if self.is_disabled:
return
if trainer.sanity_checking:
assert self.val_sanity_progress_bar_id is not None
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
Expand All @@ -477,6 +506,8 @@ def on_test_batch_end(
batch_idx: int,
dataloader_idx: int,
) -> None:
if self.is_disabled:
return
assert self.test_progress_bar_id is not None
self._update(self.test_progress_bar_id, self.test_batch_idx)
self.refresh()
Expand All @@ -490,6 +521,8 @@ def on_predict_batch_end(
batch_idx: int,
dataloader_idx: int,
) -> None:
if self.is_disabled:
return
assert self.predict_progress_bar_id is not None
self._update(self.predict_progress_bar_id, self.predict_batch_idx)
self.refresh()
Expand Down Expand Up @@ -528,30 +561,6 @@ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
self._stop_progress()

@property
def val_progress_bar(self) -> Task:
assert self.progress is not None
assert self.val_progress_bar_id is not None
return self.progress.tasks[self.val_progress_bar_id]

@property
def val_sanity_check_bar(self) -> Task:
assert self.progress is not None
assert self.val_sanity_progress_bar_id is not None
return self.progress.tasks[self.val_sanity_progress_bar_id]

@property
def main_progress_bar(self) -> Task:
assert self.progress is not None
assert self.main_progress_bar_id is not None
return self.progress.tasks[self.main_progress_bar_id]

@property
def test_progress_bar(self) -> Task:
assert self.progress is not None
assert self.test_progress_bar_id is not None
return self.progress.tasks[self.test_progress_bar_id]

def configure_columns(self, trainer: "pl.Trainer") -> list:
return [
TextColumn("[progress.description]{task.description}"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def test_rich_progress_bar_padding():
assert len(progress_bar.validation_description) == len(train_description)


@RunIf(rich=True)
def test_rich_progress_bar_can_be_pickled():
bar = RichProgressBar()
trainer = Trainer(
Expand Down Expand Up @@ -475,3 +476,37 @@ def _set_fake_bar_ids():

# the progress object remains in case we need it for the next stage
assert bar.progress is not None


@RunIf(rich=True)
def test_rich_progress_bar_disabled(tmpdir):
"""Test that in a disabled bar there are no updates and no internal progress objects."""
bar = RichProgressBar()
bar.disable()
assert bar.is_disabled

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
limit_predict_batches=2,
max_epochs=1,
enable_model_summary=False,
enable_checkpointing=False,
callbacks=[bar],
)

with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.CustomProgress") as mocked:
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)

mocked.assert_not_called()
assert bar.main_progress_bar_id is None
assert bar.val_sanity_progress_bar_id is None
assert bar.val_progress_bar_id is None
assert bar.test_progress_bar_id is None
assert bar.predict_progress_bar_id is None