diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 8d509d41d52bf..42b5f7a36641c 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -15,7 +15,7 @@ import os from abc import ABC from argparse import ArgumentParser, Namespace -from typing import List, Optional, Union, Type, TypeVar +from typing import List, Optional, Union, Type, TypeVar, cast from pytorch_lightning.callbacks import Callback, ProgressBarBase, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule @@ -154,6 +154,7 @@ def progress_bar_callback(self): def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. """ ref_model = self.model if not self.data_parallel else self.model.module + ref_model = cast(LightningModule, ref_model) return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) @property