|
17 | 17 | import torch |
18 | 18 |
|
19 | 19 | from pytorch_lightning import Trainer |
| 20 | +from pytorch_lightning.loggers import CSVLogger |
20 | 21 | from pytorch_lightning.utilities.logger import ( |
21 | 22 | _add_prefix, |
22 | 23 | _convert_params, |
23 | 24 | _flatten_dict, |
| 25 | + _name, |
24 | 26 | _sanitize_callable_params, |
25 | 27 | _sanitize_params, |
| 28 | + _version, |
26 | 29 | ) |
27 | 30 |
|
28 | 31 |
|
@@ -172,3 +175,37 @@ def test_add_prefix(): |
172 | 175 | assert "prefix-metric2" not in metrics |
173 | 176 | assert metrics["prefix2_prefix-metric1"] == 1 |
174 | 177 | assert metrics["prefix2_prefix-metric2"] == 2 |
| 178 | + |
| 179 | + |
| 180 | +def test_name(tmpdir): |
| 181 | + """Verify names of loggers are concatenated properly.""" |
| 182 | + logger1 = CSVLogger(tmpdir, name="foo") |
| 183 | + logger2 = CSVLogger(tmpdir, name="bar") |
| 184 | + logger3 = CSVLogger(tmpdir, name="foo") |
| 185 | + logger4 = CSVLogger(tmpdir, name="baz") |
| 186 | + loggers = [logger1, logger2, logger3, logger4] |
| 187 | + name = _name([]) |
| 188 | + assert name == "" |
| 189 | + name = _name([logger3]) |
| 190 | + assert name == "foo" |
| 191 | + name = _name(loggers) |
| 192 | + assert name == "foo_bar_baz" |
| 193 | + name = _name(loggers, "-") |
| 194 | + assert name == "foo-bar-baz" |
| 195 | + |
| 196 | + |
| 197 | +def test_version(tmpdir): |
| 198 | + """Verify names of loggers are concatenated properly.""" |
| 199 | + logger1 = CSVLogger(tmpdir, version=0) |
| 200 | + logger2 = CSVLogger(tmpdir, version=2) |
| 201 | + logger3 = CSVLogger(tmpdir, version=1) |
| 202 | + logger4 = CSVLogger(tmpdir, version=0) |
| 203 | + loggers = [logger1, logger2, logger3, logger4] |
| 204 | + version = _version([]) |
| 205 | + assert version == "" |
| 206 | + version = _version([logger3]) |
| 207 | + assert version == 1 |
| 208 | + version = _version(loggers) |
| 209 | + assert version == "0_2_1" |
| 210 | + version = _version(loggers, "-") |
| 211 | + assert version == "0-2-1" |
0 commit comments