Skip to content

Commit dfae734

Browse files
awaelchlijustusschockpre-commit-ci[bot]tchaton
authored
sanitize arrays when logging as hyperparameters in TensorBoardLogger (#9031)
Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas chaton <[email protected]>
1 parent 1feec8c commit dfae734

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6363
- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
6464

6565

66+
- Added sanitization of tensors when they get logged as hyperparameters in `TensorBoardLogger` ([#9031](https://github.com/PyTorchLightning/pytorch-lightning/pull/9031))
67+
68+
6669
- Added `InterBatchParallelDataFetcher` ([#9020](https://github.com/PyTorchLightning/pytorch-lightning/pull/9020))
6770

6871

pytorch_lightning/loggers/tensorboard.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from argparse import Namespace
2222
from typing import Any, Dict, Optional, Union
2323

24+
import numpy as np
2425
import torch
2526
from torch.utils.tensorboard import SummaryWriter
2627
from torch.utils.tensorboard.summary import hparams
@@ -286,6 +287,12 @@ def _get_next_version(self):
286287

287288
return max(existing_versions) + 1
288289

290+
@staticmethod
291+
def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
292+
params = LightningLoggerBase._sanitize_params(params)
293+
# logging of arrays with dimension > 1 is not supported, sanitize as string
294+
return {k: str(v) if isinstance(v, (torch.Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()}
295+
289296
def __getstate__(self):
290297
state = self.__dict__.copy()
291298
state["_experiment"] = None

tests/loggers/test_tensorboard.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from argparse import Namespace
1818
from unittest import mock
1919

20+
import numpy as np
2021
import pytest
2122
import torch
2223
import yaml
@@ -178,6 +179,8 @@ def test_tensorboard_log_hyperparams(tmpdir):
178179
"list": [1, 2, 3],
179180
"namespace": Namespace(foo=Namespace(bar="buzz")),
180181
"layer": torch.nn.BatchNorm1d,
182+
"tensor": torch.empty(2, 2, 2),
183+
"array": np.empty([2, 2, 2]),
181184
}
182185
logger.log_hyperparams(hparams)
183186

@@ -193,6 +196,8 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
193196
"list": [1, 2, 3],
194197
"namespace": Namespace(foo=Namespace(bar="buzz")),
195198
"layer": torch.nn.BatchNorm1d,
199+
"tensor": torch.empty(2, 2, 2),
200+
"array": np.empty([2, 2, 2]),
196201
}
197202
metrics = {"abc": torch.tensor([0.54])}
198203
logger.log_hyperparams(hparams, metrics)

0 commit comments

Comments
 (0)