Skip to content

Commit c0f411b

Browse files
committed
Add unit tests
1 parent 6cf1eb4 commit c0f411b

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

pytorch_lightning/utilities/logger.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,17 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[
148148
return metrics
149149

150150

151-
def _name(loggers: List[Any]) -> str:
151+
def _name(loggers: List[Any], separator: Optional[str] = "_") -> str:
152152
if len(loggers) == 1:
153153
return loggers[0].name
154154
else:
155155
# Concatenate names together, removing duplicates and preserving order
156-
return "_".join(dict.fromkeys(str(logger.name) for logger in loggers))
156+
return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))
157157

158158

159-
def _version(loggers: List[Any]) -> Union[int, str]:
159+
def _version(loggers: List[Any], separator: Optional[str] = "_") -> Union[int, str]:
160160
if len(loggers) == 1:
161161
return loggers[0].version
162162
else:
163163
# Concatenate versions together, removing duplicates and preserving order
164-
return "_".join(dict.fromkeys(str(logger.version) for logger in loggers))
164+
return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))

tests/utilities/test_logger.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
import torch
1818

1919
from pytorch_lightning import Trainer
20+
from pytorch_lightning.loggers import CSVLogger
2021
from pytorch_lightning.utilities.logger import (
2122
_add_prefix,
2223
_convert_params,
2324
_flatten_dict,
25+
_name,
2426
_sanitize_callable_params,
2527
_sanitize_params,
28+
_version,
2629
)
2730

2831

@@ -172,3 +175,37 @@ def test_add_prefix():
172175
assert "prefix-metric2" not in metrics
173176
assert metrics["prefix2_prefix-metric1"] == 1
174177
assert metrics["prefix2_prefix-metric2"] == 2
178+
179+
180+
def test_name(tmpdir):
181+
"""Verify names of loggers are concatenated properly."""
182+
logger1 = CSVLogger(tmpdir, name="foo")
183+
logger2 = CSVLogger(tmpdir, name="bar")
184+
logger3 = CSVLogger(tmpdir, name="foo")
185+
logger4 = CSVLogger(tmpdir, name="baz")
186+
loggers = [logger1, logger2, logger3, logger4]
187+
name = _name([])
188+
assert name == ""
189+
name = _name([logger3])
190+
assert name == "foo"
191+
name = _name(loggers)
192+
assert name == "foo_bar_baz"
193+
name = _name(loggers, "-")
194+
assert name == "foo-bar-baz"
195+
196+
197+
def test_version(tmpdir):
198+
"""Verify names of loggers are concatenated properly."""
199+
logger1 = CSVLogger(tmpdir, version=0)
200+
logger2 = CSVLogger(tmpdir, version=2)
201+
logger3 = CSVLogger(tmpdir, version=1)
202+
logger4 = CSVLogger(tmpdir, version=0)
203+
loggers = [logger1, logger2, logger3, logger4]
204+
version = _version([])
205+
assert version == ""
206+
version = _version([logger3])
207+
assert version == 1
208+
version = _version(loggers)
209+
assert version == "0_2_1"
210+
version = _version(loggers, "-")
211+
assert version == "0-2-1"

0 commit comments

Comments
 (0)