Skip to content

Commit 21fd56e

Browse files
archsyscallBorda
authored andcommitted
FIX-5311: Cast to string _flatten_dict (#5354)
* fix * params * add test * add another types * chlog Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit 6536ea4)
1 parent 9610ea8 commit 21fd56e

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8282

8383
- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
8484

85+
- Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354))
86+
8587

8688
## [1.1.2] - 2020-12-23
8789

pytorch_lightning/loggers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _sanitize_callable(val):
208208
return {key: _sanitize_callable(val) for key, val in params.items()}
209209

210210
@staticmethod
211-
def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
211+
def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]:
212212
"""
213213
Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
214214
@@ -224,12 +224,15 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any
224224
{'a/b': 'c'}
225225
>>> LightningLoggerBase._flatten_dict({'a': {'b': 123}})
226226
{'a/b': 123}
227+
>>> LightningLoggerBase._flatten_dict({5: {'a': 123}})
228+
{'5/a': 123}
227229
"""
228230

229231
def _dict_generator(input_dict, prefixes=None):
230232
prefixes = prefixes[:] if prefixes else []
231233
if isinstance(input_dict, MutableMapping):
232234
for key, value in input_dict.items():
235+
key = str(key)
233236
if isinstance(value, (MutableMapping, Namespace)):
234237
value = vars(value) if isinstance(value, Namespace) else value
235238
for d in _dict_generator(value, prefixes + [key]):

tests/loggers/test_tensorboard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_tensorboard_named_version(tmpdir):
102102
expected_version = "2020-02-05-162402"
103103

104104
logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version)
105-
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
105+
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
106106

107107
assert logger.version == expected_version
108108
assert os.listdir(tmpdir / name) == [expected_version]
@@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir):
113113
def test_tensorboard_no_name(tmpdir, name):
114114
"""Verify that None or empty name works"""
115115
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
116-
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
116+
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written
117117
assert logger.root_dir == tmpdir
118118
assert os.listdir(tmpdir / "version_0")
119119

0 commit comments

Comments
 (0)