Skip to content

Commit 4b357c0

Browse files
authored
Merge branch 'master' into 1.1.3-release
2 parents 9ce2cd3 + 6536ea4 commit 4b357c0

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929
- Fixed existence check for hparams not using underlying filesystem ([#5250](https://github.com/PyTorchLightning/pytorch-lightning/pull/5250))
3030
- Fixed `LightningOptimizer` AMP bug ([#5191](https://github.com/PyTorchLightning/pytorch-lightning/pull/5191))
3131

32+
- Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354))
33+
3234

3335
## [1.1.2] - 2020-12-23
3436

pytorch_lightning/loggers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _sanitize_callable(val):
207207
return {key: _sanitize_callable(val) for key, val in params.items()}
208208

209209
@staticmethod
210-
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
210+
def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]:
211211
"""
212212
Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
213213
@@ -223,12 +223,15 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any
223223
{'a/b': 'c'}
224224
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
225225
{'a/b': 123}
226+
>>> LightningLoggerBase._flatten_dict({5: {'a': 123}})
227+
{'5/a': 123}
226228
"""
227229

228230
def _dict_generator(input_dict, prefixes=None):
229231
prefixes = prefixes[:] if prefixes else []
230232
if isinstance(input_dict, MutableMapping):
231233
for key, value in input_dict.items():
234+
key = str(key)
232235
if isinstance(value, (MutableMapping, Namespace)):
233236
value = vars(value) if isinstance(value, Namespace) else value
234237
for d in _dict_generator(value, prefixes + [key]):

tests/loggers/test_tensorboard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from omegaconf import OmegaConf
2323
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
2424

25-
from pytorch_lightning import Trainer, seed_everything
25+
from pytorch_lightning import seed_everything, Trainer
2626
from pytorch_lightning.loggers import TensorBoardLogger
2727
from tests.base import BoringModel, EvalModelTemplate
2828

@@ -102,7 +102,7 @@ def test_tensorboard_named_version(tmpdir):
102102
expected_version = "2020-02-05-162402"
103103

104104
logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version)
105-
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
105+
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
106106

107107
assert logger.version == expected_version
108108
assert os.listdir(tmpdir / name) == [expected_version]
@@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir):
113113
def test_tensorboard_no_name(tmpdir, name):
114114
"""Verify that None or empty name works"""
115115
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
116-
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
116+
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
117117
assert logger.root_dir == tmpdir
118118
assert os.listdir(tmpdir / "version_0")
119119

0 commit comments

Comments
 (0)