Skip to content

Commit 47bf2ce

Browse files
committed
Update progress_bar_metrics return _PBAR_DICT
1 parent b6b071f commit 47bf2ce

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
222222
if not trainer.is_global_zero:
223223
self.disable()
224224

225-
def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
225+
def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str, float, Dict[str, float]]]:
226226
r"""
227227
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
228228
Implement this to override the items displayed in the progress bar.

src/pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
7878
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
7979
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
80-
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
80+
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
8181
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
8282
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
8383
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@@ -2735,7 +2735,7 @@ def logged_metrics(self) -> _OUT_DICT:
27352735
return self._logger_connector.logged_metrics
27362736

27372737
@property
2738-
def progress_bar_metrics(self) -> Dict:
2738+
def progress_bar_metrics(self) -> _PBAR_DICT:
27392739
return self._logger_connector.progress_bar_metrics
27402740

27412741
@property

0 commit comments

Comments
 (0)