diff --git a/.github/workflows/ci-pytorch-test-full.yml b/.github/workflows/ci-pytorch-test-full.yml index e4c5ecd9cc0c1..c1506e41b24fd 100644 --- a/.github/workflows/ci-pytorch-test-full.yml +++ b/.github/workflows/ci-pytorch-test-full.yml @@ -16,7 +16,6 @@ jobs: pl-cpu: runs-on: ${{ matrix.os }} - if: github.event.pull_request.draft == false strategy: fail-fast: false matrix: diff --git a/.github/workflows/ci-pytorch-test-slow.yml b/.github/workflows/ci-pytorch-test-slow.yml index c1b2ab2292009..8d990ad4060e9 100644 --- a/.github/workflows/ci-pytorch-test-slow.yml +++ b/.github/workflows/ci-pytorch-test-slow.yml @@ -21,7 +21,6 @@ concurrency: jobs: pl-slow: runs-on: ${{ matrix.os }} - if: github.event.pull_request.draft == false strategy: fail-fast: false matrix: diff --git a/src/pytorch_lightning/callbacks/progress/rich_progress.py b/src/pytorch_lightning/callbacks/progress/rich_progress.py index cb0ce0c52e72d..70e522df72958 100644 --- a/src/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/src/pytorch_lightning/callbacks/progress/rich_progress.py @@ -289,7 +289,6 @@ def _init_progress(self, trainer): self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - auto_refresh=False, disable=self.is_disabled, console=self._console, ) @@ -297,10 +296,6 @@ def _init_progress(self, trainer): # progress has started self._progress_stopped = False - def refresh(self) -> None: - if self.progress: - self.progress.refresh() - def on_train_start(self, trainer, pl_module): self._init_progress(trainer) @@ -319,7 +314,6 @@ def on_sanity_check_start(self, trainer, pl_module): def on_sanity_check_end(self, trainer, pl_module): if self.progress is not None: self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) - self.refresh() def on_train_epoch_start(self, trainer, pl_module): total_batches = self.total_batches_current_epoch @@ -335,8 +329,6 @@ def on_train_epoch_start(self, trainer, pl_module): self.main_progress_bar_id, total=total_batches, description=train_description, visible=True ) - self.refresh() - def on_validation_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int ) -> None: @@ -359,8 +351,6 @@ def on_validation_batch_start( self.total_val_batches_current_dataloader, self.validation_description, visible=False ) - self.refresh() - def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: return self.progress.add_task( @@ -376,7 +366,6 @@ def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> N leftover = current % self.refresh_rate advance = leftover if (current == total and leftover != 0) else self.refresh_rate self.progress.update(progress_bar_id, advance=advance, visible=visible) - self.refresh() def _should_update(self, current: int, total: Union[int, float]) -> bool: return current % self.refresh_rate == 0 or current == total @@ -384,7 +373,6 @@ def _should_update(self, current: int, total: Union[int, float]) -> bool: def on_validation_epoch_end(self, trainer, pl_module): if self.val_progress_bar_id is not None and trainer.state.fn == "fit": self.progress.update(self.val_progress_bar_id, advance=0, visible=False) - self.refresh() def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.state.fn == "fit": @@ -406,7 +394,6 @@ def on_test_batch_start( if self.test_progress_bar_id is not None: self.progress.update(self.test_progress_bar_id, advance=0, visible=False) self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description) - self.refresh() def on_predict_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int @@ -419,12 +406,10 @@ def on_predict_batch_start( self.predict_progress_bar_id = self._add_task( self.total_predict_batches_current_dataloader, self.predict_description ) - self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) self._update_metrics(trainer, pl_module) - self.refresh() def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._update_metrics(trainer, pl_module) @@ -437,15 +422,16 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, if self.main_progress_bar_id is not None: self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) self._update(self.val_progress_bar_id, self.val_batch_idx) - self.refresh() + + # TODO: Find out why an error occurs without refresh here. + if self.progress: + self.progress.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._update(self.test_progress_bar_id, self.test_batch_idx) - self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): self._update(self.predict_progress_bar_id, self.predict_batch_idx) - self.refresh() def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}"