Skip to content

Commit f45f368

Browse files
dscarmorohitgr7tchaton
authored andcommitted
fix logged keys in mlflow logger (#4412)
* [#4411] fix gpu_log_memory with mlflow logger * sanitize parenthesis instead of removing for all loggers * apply regex for mlflow key sanitization * replace ',' with '.' typo * add single warning and test Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: chaton <[email protected]> (cherry picked from commit 470e294)
1 parent e59a05d commit f45f368

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

pytorch_lightning/loggers/mlflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
MLflow
1717
------
1818
"""
19+
import re
20+
import warnings
1921
from argparse import Namespace
2022
from time import time
2123
from typing import Any, Dict, Optional, Union
@@ -151,6 +153,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
151153
if isinstance(v, str):
152154
log.warning(f'Discarding metric with string value {k}={v}.')
153155
continue
156+
157+
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
158+
if k != new_k:
159+
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
160+
f"Replacing {k} with {new_k}."))
161+
k = new_k
162+
154163
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
155164

156165
@rank_zero_only

tests/loggers/test_mlflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir):
137137
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
138138

139139
model = EvalModelTemplate()
140-
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
140+
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3,
141+
log_gpu_memory=True)
141142
trainer.fit(model)
142143
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
143144
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')

0 commit comments

Comments
 (0)