Skip to content
Merged
9 changes: 9 additions & 0 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
MLflow
------
"""
import re
import warnings
from argparse import Namespace
from time import time
from typing import Any, Dict, Optional, Union
Expand Down Expand Up @@ -151,6 +153,13 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
if isinstance(v, str):
log.warning(f'Discarding metric with string value {k}={v}.')
continue

new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
if k != new_k:
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
f"Replacing {k} with {new_k}."))
k = new_k

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

@rank_zero_only
Expand Down
3 changes: 2 additions & 1 deletion tests/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3,
log_gpu_memory=True)
trainer.fit(model)
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
Expand Down