1 change: 0 additions & 1 deletion .github/workflows/ci-pytorch-test-full.yml
@@ -16,7 +16,6 @@ jobs:
 
   pl-cpu:
     runs-on: ${{ matrix.os }}
-    if: github.event.pull_request.draft == false
     strategy:
       fail-fast: false
       matrix:
1 change: 0 additions & 1 deletion .github/workflows/ci-pytorch-test-slow.yml
@@ -21,7 +21,6 @@ concurrency:
 jobs:
   pl-slow:
     runs-on: ${{ matrix.os }}
-    if: github.event.pull_request.draft == false
     strategy:
       fail-fast: false
       matrix:
22 changes: 4 additions & 18 deletions src/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -289,18 +289,13 @@ def _init_progress(self, trainer):
             self.progress = CustomProgress(
                 *self.configure_columns(trainer),
                 self._metric_component,
-                auto_refresh=False,
                 disable=self.is_disabled,
                 console=self._console,
             )
             self.progress.start()
             # progress has started
             self._progress_stopped = False
 
-    def refresh(self) -> None:
-        if self.progress:
-            self.progress.refresh()
-
     def on_train_start(self, trainer, pl_module):
         self._init_progress(trainer)
 
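Dropping `auto_refresh=False` hands redrawing back to rich's built-in auto-refresh, which repaints the live display on a background timer, so the thin `refresh()` wrapper deleted above is no longer needed. The following standalone sketch (not part of this PR, plain `rich` only) contrasts the two modes; `refresh_per_second` is rich's own knob, 10 by default.

import time

from rich.progress import Progress

# auto_refresh=True (rich's default): a background timer repaints the bar
# roughly `refresh_per_second` times per second, so update() alone is enough.
with Progress(auto_refresh=True, refresh_per_second=10) as progress:
    task = progress.add_task("auto-refreshed", total=50)
    for _ in range(50):
        progress.update(task, advance=1)
        time.sleep(0.01)

# auto_refresh=False (the value removed by this diff): nothing is drawn until
# refresh() is called explicitly -- the pattern the deleted self.refresh()
# calls implemented.
with Progress(auto_refresh=False) as progress:
    task = progress.add_task("manually refreshed", total=50)
    for _ in range(50):
        progress.update(task, advance=1)
        progress.refresh()
        time.sleep(0.01)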
@@ -319,7 +314,6 @@ def on_sanity_check_start(self, trainer, pl_module):
     def on_sanity_check_end(self, trainer, pl_module):
         if self.progress is not None:
             self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
-        self.refresh()
 
     def on_train_epoch_start(self, trainer, pl_module):
         total_batches = self.total_batches_current_epoch
@@ -335,8 +329,6 @@ def on_train_epoch_start(self, trainer, pl_module):
                 self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
             )
 
-        self.refresh()
-
     def on_validation_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
     ) -> None:
@@ -359,8 +351,6 @@ def on_validation_batch_start(
                 self.total_val_batches_current_dataloader, self.validation_description, visible=False
             )
 
-        self.refresh()
-
     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
         if self.progress is not None:
             return self.progress.add_task(
@@ -376,15 +366,13 @@ def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> N
             leftover = current % self.refresh_rate
             advance = leftover if (current == total and leftover != 0) else self.refresh_rate
             self.progress.update(progress_bar_id, advance=advance, visible=visible)
-            self.refresh()
 
     def _should_update(self, current: int, total: Union[int, float]) -> bool:
         return current % self.refresh_rate == 0 or current == total
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
             self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
-            self.refresh()
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.state.fn == "fit":
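The `_update` arithmetic in the hunk above only redraws every `refresh_rate` batches and then tops the bar up with the leftover on the final batch. A standalone sketch of that arithmetic (the helper name `advance_for` and the numbers are illustrative, not from the PR):

def advance_for(current: int, total: int, refresh_rate: int) -> int:
    # Mirrors _should_update: only redraw every `refresh_rate` batches or on the last one.
    if not (current % refresh_rate == 0 or current == total):
        return 0
    leftover = current % refresh_rate
    # On the final batch, advance only by the leftover so the bar lands exactly on `total`.
    return leftover if (current == total and leftover != 0) else refresh_rate


assert advance_for(20, total=50, refresh_rate=20) == 20  # regular tick
assert advance_for(30, total=50, refresh_rate=20) == 0   # skipped: not a multiple, not the last batch
assert advance_for(40, total=50, refresh_rate=20) == 20  # regular tick
assert advance_for(50, total=50, refresh_rate=20) == 10  # final partial tick: 20 + 20 + 10 == 50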
@@ -406,7 +394,6 @@ def on_test_batch_start(
         if self.test_progress_bar_id is not None:
             self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
         self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description)
-        self.refresh()
 
     def on_predict_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
@@ -419,12 +406,10 @@
         self.predict_progress_bar_id = self._add_task(
             self.total_predict_batches_current_dataloader, self.predict_description
         )
-        self.refresh()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
         self._update_metrics(trainer, pl_module)
-        self.refresh()
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._update_metrics(trainer, pl_module)
@@ -437,15 +422,16 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
         if self.main_progress_bar_id is not None:
             self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
         self._update(self.val_progress_bar_id, self.val_batch_idx)
-        self.refresh()
+
+        # TODO: Find out why an error occurs without refresh here.
+        if self.progress:
+            self.progress.refresh()
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._update(self.test_progress_bar_id, self.test_batch_idx)
-        self.refresh()
 
     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._update(self.predict_progress_bar_id, self.predict_batch_idx)
-        self.refresh()
 
     def _get_train_description(self, current_epoch: int) -> str:
         train_description = f"Epoch {current_epoch}"
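For context, a minimal usage sketch of the callback this diff touches (`model` is assumed to be any LightningModule; the sketch is not part of the PR). With rich's auto-refresh back in charge, only `on_validation_batch_end` keeps an explicit `progress.refresh()`, per the TODO above.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

# refresh_rate=1 makes _update() advance the bar on every batch; larger values
# batch the advances as shown in the advance_for() sketch above.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[RichProgressBar(refresh_rate=1)],
)
# trainer.fit(model)  # `model` is any LightningModule; redrawing is left to
#                     # rich's auto-refresh except for the explicit refresh
#                     # kept in on_validation_batch_end.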