Skip to content

Commit d06c0fa

Browse files
williamFalcon authored and SeanNaren committed
Adds shortcut for path to log (#4573)
* added log_dir shortcut to trainer properties for writing logs * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut * added log_dir shortcut (cherry picked from commit 09a5169)
1 parent c007e32 commit d06c0fa

File tree

7 files changed

+170
-1
lines changed

7 files changed

+170
-1
lines changed

.pyrightconfig.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"pytorch_lightning/trainer/training_tricks.py",
3131
"pytorch_lightning/trainer/batch_size_scaling.py",
3232
"pytorch_lightning/trainer/distrib_data_parallel.py",
33+
"pytorch_lightning/trainer/properties.py",
3334
"pytorch_lightning/trainer/lr_scheduler_connector.py",
3435
"pytorch_lightning/trainer/training_loop_temp.py",
3536
"pytorch_lightning/trainer/connectors/checkpoint_connector.py",

docs/source/loggers.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ but you can pass to the :class:`~pytorch_lightning.trainer.trainer.Trainer` any
1919

2020
Read more about :ref:`logging` options.
2121

22+
To log arbitrary artifacts like images or audio samples use the `trainer.log_dir` property to resolve
23+
the path.
24+
25+
.. code-block:: python
26+
27+
def training_step(self, batch, batch_idx):
28+
img = ...
29+
log_image(img, self.trainer.log_dir)
30+
2231
Comet.ml
2332
========
2433

docs/source/trainer.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,17 @@ The metrics sent to the logger (visualizer).
16661666
logged_metrics = trainer.logged_metrics
16671667
assert logged_metrics['a_val'] == 2
16681668
1669+
log_dir
1670+
*******
1671+
The directory for the current experiment. Use it to save artifacts such as images.
1672+
1673+
.. code-block:: python
1674+
1675+
def training_step(self, batch, batch_idx):
1676+
img = ...
1677+
save_img(img, self.trainer.log_dir)
1678+
1679+
16691680
16701681
is_global_zero
16711682
**************

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def __init_ckpt_dir(self, filepath, dirpath, filename, save_top_k):
292292
if dirpath and self._fs.protocol == 'file':
293293
dirpath = os.path.realpath(dirpath)
294294

295-
self.dirpath = dirpath or None
295+
self.dirpath: Union[str, None] = dirpath or None
296296
self.filename = filename or None
297297

298298
def __init_monitor_mode(self, monitor, mode):

pytorch_lightning/trainer/properties.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from pytorch_lightning.utilities import argparse_utils
2727
from pytorch_lightning.utilities.cloud_io import get_filesystem
2828
from pytorch_lightning.utilities.model_utils import is_overridden
29+
from pytorch_lightning.accelerators.accelerator import Accelerator
30+
from pytorch_lightning.loggers.base import LightningLoggerBase
31+
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
2932

3033

3134
class TrainerProperties(ABC):
@@ -44,10 +47,30 @@ class TrainerProperties(ABC):
4447
limit_val_batches: int
4548
_default_root_dir: str
4649
_weights_save_path: str
50+
default_root_path: str
51+
accelerator_backend: Accelerator
52+
logger: LightningLoggerBase
4753
model_connector: ModelConnector
4854
checkpoint_connector: CheckpointConnector
4955
callbacks: List[Callback]
5056

57+
@property
def log_dir(self):
    """Directory where the current experiment writes its logs and artifacts.

    Resolution order:

    1. the parent of the checkpoint callback's ``dirpath`` (checkpoints are
       saved one level below the log directory),
    2. the logger's directory — ``log_dir`` for a ``TensorBoardLogger``
       (versioned per run), otherwise the logger's ``save_dir``,
    3. the trainer's ``_default_root_dir`` as a last resort.

    In distributed runs the resolved path is broadcast through the
    accelerator backend so every process agrees on the same directory
    (presumably rank 0 is the source of truth — semantics live in
    ``Accelerator.broadcast``).
    """
    if self.checkpoint_callback is not None:
        # checkpoints live in a subdirectory of the log dir, so strip
        # the last path component to get the experiment directory
        log_dir = os.path.split(self.checkpoint_callback.dirpath)[0]
    elif self.logger is not None:
        # TensorBoardLogger exposes the per-version directory as `log_dir`;
        # other loggers only expose their root `save_dir`
        if isinstance(self.logger, TensorBoardLogger):
            log_dir = self.logger.log_dir
        else:
            log_dir = self.logger.save_dir
    else:
        log_dir = self._default_root_dir

    if self.accelerator_backend is not None:
        # keep the path consistent across all processes
        log_dir = self.accelerator_backend.broadcast(log_dir)
    return log_dir
73+
5174
@property
5275
def use_amp(self) -> bool:
5376
return self.precision == 16

tests/trainer/properties/__init__.py

Whitespace-only changes.
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import torch
16+
import pytest
17+
from tests.base.boring_model import BoringModel, RandomDataset
18+
from pytorch_lightning import Trainer
19+
from pytorch_lightning.utilities import APEX_AVAILABLE
20+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
21+
22+
23+
def test_logdir(tmpdir):
    """Trainer.log_dir resolves to the versioned logger directory when the
    default logger and checkpoint callback are both active."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            step_loss = self.loss(batch, prediction)

            # default TensorBoard layout: <root>/lightning_logs/version_0
            assert self.trainer.log_dir == os.path.join(
                self.trainer.default_root_dir, 'lightning_logs', 'version_0'
            )
            return {"loss": step_loss}

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
    )
    trainer.fit(TestModel())
49+
def test_logdir_no_checkpoint_cb(tmpdir):
    """Trainer.log_dir still resolves to the versioned logger directory
    when checkpointing is disabled."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            step_loss = self.loss(batch, prediction)
            # with no checkpoint callback, the logger dictates the path
            assert self.trainer.log_dir == os.path.join(
                self.trainer.default_root_dir, 'lightning_logs', 'version_0'
            )
            return {"loss": step_loss}

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        checkpoint_callback=False,
    )
    trainer.fit(TestModel())
75+
def test_logdir_no_logger(tmpdir):
    """Trainer.log_dir falls back to the default root dir when logging
    is disabled (checkpoint callback still active)."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            step_loss = self.loss(batch, prediction)
            # no logger: the path is just the trainer's root directory
            assert self.trainer.log_dir == os.path.join(self.trainer.default_root_dir)
            return {"loss": step_loss}

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        logger=False,
    )
    trainer.fit(TestModel())
101+
def test_logdir_no_logger_no_checkpoint(tmpdir):
    """Trainer.log_dir falls back to the default root dir when both the
    logger and the checkpoint callback are disabled."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            step_loss = self.loss(batch, prediction)
            # nothing overrides the path: expect the trainer's root directory
            assert self.trainer.log_dir == os.path.join(self.trainer.default_root_dir)
            return {"loss": step_loss}

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        logger=False,
        checkpoint_callback=False,
    )
    trainer.fit(TestModel())

0 commit comments

Comments
 (0)