From 3680c9dd699942c84271b8e8c8e592ee1d803e6d Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 13:16:49 +0900 Subject: [PATCH 1/5] fix --- pytorch_lightning/loggers/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index a27998366b671..921551a189ca6 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -229,6 +229,7 @@ def _dict_generator(input_dict, prefixes=None): prefixes = prefixes[:] if prefixes else [] if isinstance(input_dict, MutableMapping): for key, value in input_dict.items(): + key = str(key) if isinstance(value, (MutableMapping, Namespace)): value = vars(value) if isinstance(value, Namespace) else value for d in _dict_generator(value, prefixes + [key]): From 06c0efc511b012f0201079a6de378eb0cc587f21 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 13:22:29 +0900 Subject: [PATCH 2/5] params --- pytorch_lightning/loggers/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 921551a189ca6..ac7ab3e023bdb 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -207,7 +207,7 @@ def _sanitize_callable(val): return {key: _sanitize_callable(val) for key, val in params.items()} @staticmethod - def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]: + def _flatten_dict(params: Dict[Any, Any], delimiter: str = '/') -> Dict[str, Any]: """ Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. @@ -223,6 +223,8 @@ def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any {'a/b': 'c'} >>> LightningLoggerBase._flatten_dict({'a': {'b': 123}}) {'a/b': 123} + >>> LightningLoggerBase._flatten_dict({5: {'a': 123}}) + {'5/a': 123} """ def _dict_generator(input_dict, prefixes=None): From ebc7afd68a7c9ad859ea3a6abe18120ee77ef714 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 13:50:27 +0900 Subject: [PATCH 3/5] add test --- tests/loggers/test_tensorboard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 15a024003ebf0..178ad4ad32b57 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir): def test_tensorboard_no_name(tmpdir, name): """Verify that None or empty name works""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) - logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written + logger.log_hyperparams({"a": 1, "b": 2, 123: 3}) # Force data to be written assert logger.root_dir == tmpdir assert os.listdir(tmpdir / "version_0") From 43ef95456efd1a6ad0625b9c6dfd4a962fa846b2 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 16:51:07 +0900 Subject: [PATCH 4/5] add another types --- tests/loggers/test_tensorboard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 178ad4ad32b57..fa5c711357ba3 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -22,7 +22,7 @@ from omegaconf import OmegaConf from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.loggers import TensorBoardLogger from tests.base import BoringModel, EvalModelTemplate @@ -102,7 +102,7 @@ def test_tensorboard_named_version(tmpdir): expected_version = "2020-02-05-162402" logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version) - logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.version == expected_version assert os.listdir(tmpdir / name) == [expected_version] @@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir): def test_tensorboard_no_name(tmpdir, name): """Verify that None or empty name works""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) - logger.log_hyperparams({"a": 1, "b": 2, 123: 3}) # Force data to be written + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.root_dir == tmpdir assert os.listdir(tmpdir / "version_0") From 4fb5040efba9ba18054dfadbc08240533fe8577c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 5 Jan 2021 09:12:35 +0100 Subject: [PATCH 5/5] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9801f56f6f1bf..804429f41174d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) +- Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354)) + ## [1.1.2] - 2020-12-23