Skip to content

Commit 054bf13

Browse files
jxtngxawaelchlirohitgr7carmocca
authored andcommitted
Fix mypy errors attributed to pytorch_lightning.loggers.tensorboard.py (#13688)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 94cb590 commit 054bf13

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ module = [
6060
"pytorch_lightning.loggers.comet",
6161
"pytorch_lightning.loggers.mlflow",
6262
"pytorch_lightning.loggers.neptune",
63-
"pytorch_lightning.loggers.tensorboard",
6463
"pytorch_lightning.loggers.wandb",
6564
"pytorch_lightning.profilers.advanced",
6665
"pytorch_lightning.profilers.base",

src/pytorch_lightning/loggers/tensorboard.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
sub_dir: Optional[str] = None,
9797
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
9898
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
99-
**kwargs,
99+
**kwargs: Any,
100100
):
101101
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
102102
self._save_dir = save_dir
@@ -108,8 +108,8 @@ def __init__(
108108
self._prefix = prefix
109109
self._fs = get_filesystem(save_dir)
110110

111-
self._experiment = None
112-
self.hparams = {}
111+
self._experiment: Optional["SummaryWriter"] = None
112+
self.hparams: Union[Dict[str, Any], Namespace] = {}
113113
self._kwargs = kwargs
114114

115115
@property
@@ -138,7 +138,7 @@ def log_dir(self) -> str:
138138
return log_dir
139139

140140
@property
141-
def save_dir(self) -> Optional[str]:
141+
def save_dir(self) -> str:
142142
"""Gets the save directory where the TensorBoard experiments are saved.
143143
144144
Returns:
@@ -155,7 +155,7 @@ def sub_dir(self) -> Optional[str]:
155155
"""
156156
return self._sub_dir
157157

158-
@property
158+
@property # type: ignore[misc]
159159
@rank_zero_experiment
160160
def experiment(self) -> SummaryWriter:
161161
r"""
@@ -236,7 +236,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
236236
raise ValueError(m) from ex
237237

238238
@rank_zero_only
239-
def log_graph(self, model: "pl.LightningModule", input_array=None):
239+
def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
240240
if self._log_graph:
241241
if input_array is None:
242242
input_array = model.example_input_array
@@ -281,7 +281,7 @@ def name(self) -> str:
281281
return self._name
282282

283283
@property
284-
def version(self) -> int:
284+
def version(self) -> Union[int, str]:
285285
"""Get the experiment version.
286286
287287
Returns:
@@ -291,7 +291,7 @@ def version(self) -> int:
291291
self._version = self._get_next_version()
292292
return self._version
293293

294-
def _get_next_version(self):
294+
def _get_next_version(self) -> int:
295295
root_dir = self.root_dir
296296

297297
try:
@@ -318,7 +318,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
318318
# logging of arrays with dimension > 1 is not supported, sanitize as string
319319
return {k: str(v) if isinstance(v, (Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
320320

321-
def __getstate__(self):
321+
def __getstate__(self) -> Dict[str, Any]:
322322
state = self.__dict__.copy()
323323
state["_experiment"] = None
324324
return state

0 commit comments

Comments
 (0)