1 change: 1 addition & 0 deletions .pyrightconfig.json
@@ -30,6 +30,7 @@
"pytorch_lightning/trainer/training_tricks.py",
"pytorch_lightning/trainer/batch_size_scaling.py",
"pytorch_lightning/trainer/distrib_data_parallel.py",
"pytorch_lightning/trainer/properties.py",
"pytorch_lightning/trainer/lr_scheduler_connector.py",
"pytorch_lightning/trainer/training_loop_temp.py",
"pytorch_lightning/trainer/connectors/checkpoint_connector.py",
9 changes: 9 additions & 0 deletions docs/source/loggers.rst
@@ -19,6 +19,15 @@ but you can pass to the :class:`~pytorch_lightning.trainer.trainer.Trainer` any

Read more about :ref:`logging` options.

To log arbitrary artifacts such as images or audio samples, use the ``trainer.log_dir`` property to resolve
the path.

.. code-block:: python

def training_step(self, batch, batch_idx):
img = ...
log_image(img, self.trainer.log_dir)
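
For example, a more concrete variant that writes the image into the resolved directory with
torchvision's ``save_image`` (torchvision and the filename pattern here are illustrative
assumptions, not part of this change):

.. code-block:: python

    import os
    from torchvision.utils import save_image

    def training_step(self, batch, batch_idx):
        img = ...  # e.g. a (C, H, W) image tensor produced during training
        # trainer.log_dir resolves to the current experiment directory
        save_image(img, os.path.join(self.trainer.log_dir, f"sample_{batch_idx}.png"))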

Comet.ml
========

Expand Down
11 changes: 11 additions & 0 deletions docs/source/trainer.rst
@@ -1666,6 +1666,17 @@ The metrics sent to the logger (visualizer).
logged_metrics = trainer.logged_metrics
assert logged_metrics['a_val'] == 2

log_dir
*******
The directory for the current experiment. Use it to save images and other artifacts; in distributed
training the path is broadcast so that all processes resolve the same directory.

.. code-block:: python

def training_step(self, batch, batch_idx):
img = ...
save_img(img, self.trainer.log_dir)



is_global_zero
**************
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -291,7 +291,7 @@ def __init_ckpt_dir(self, filepath, dirpath, filename, save_top_k):
if dirpath and self._fs.protocol == 'file':
dirpath = os.path.realpath(dirpath)

self.dirpath = dirpath or None
self.dirpath: Union[str, None] = dirpath or None
self.filename = filename or None

def __init_monitor_mode(self, monitor, mode):
23 changes: 23 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -26,6 +26,9 @@
from pytorch_lightning.utilities import argparse_utils
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger


class TrainerProperties(ABC):
@@ -44,10 +47,30 @@ class TrainerProperties(ABC):
limit_val_batches: int
_default_root_dir: str
_weights_save_path: str
default_root_path: str
accelerator_backend: Accelerator
logger: LightningLoggerBase
model_connector: ModelConnector
checkpoint_connector: CheckpointConnector
callbacks: List[Callback]

@property
def log_dir(self):
if self.checkpoint_callback is not None:
dir = self.checkpoint_callback.dirpath
dir = os.path.split(dir)[0]
elif self.logger is not None:
if isinstance(self.logger, TensorBoardLogger):
dir = self.logger.log_dir
else:
dir = self.logger.save_dir
else:
dir = self._default_root_dir

if self.accelerator_backend is not None:
dir = self.accelerator_backend.broadcast(dir)
return dir
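
In effect the property prefers the checkpoint callback's parent directory, then the logger's
directory (``log_dir`` for TensorBoard, ``save_dir`` otherwise), and finally falls back to
``_default_root_dir``; the broadcast keeps all distributed ranks on the same path. A standalone
sketch of that fallback order (the function and parameter names are illustrative, not library API):

.. code-block:: python

    import os
    from typing import Optional

    def resolve_log_dir(
        checkpoint_dirpath: Optional[str],  # ModelCheckpoint.dirpath, e.g. ".../version_0/checkpoints"
        logger_dir: Optional[str],          # TensorBoardLogger.log_dir or logger.save_dir
        default_root_dir: str,              # Trainer(default_root_dir=...)
    ) -> str:
        if checkpoint_dirpath is not None:
            # The experiment directory is the parent of the checkpoint directory.
            return os.path.split(checkpoint_dirpath)[0]
        if logger_dir is not None:
            return logger_dir
        return default_root_dir

    # resolve_log_dir("/tmp/exp/lightning_logs/version_0/checkpoints", None, "/tmp/exp")
    # -> "/tmp/exp/lightning_logs/version_0"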

@property
def use_amp(self) -> bool:
return self.precision == 16
125 changes: 125 additions & 0 deletions tests/trainer/properties/log_dir.py
@@ -0,0 +1,125 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel


def test_logdir(tmpdir):
"""
Tests that the path is correct when both the checkpoint callback and a logger are used
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)

expected = os.path.join(self.trainer.default_root_dir, 'lightning_logs', 'version_0')
assert self.trainer.log_dir == expected
return {"loss": loss}

model = TestModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
)

trainer.fit(model)


def test_logdir_no_checkpoint_cb(tmpdir):
"""
Tests that the path is correct when the checkpoint callback is disabled
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
expected = os.path.join(self.trainer.default_root_dir, 'lightning_logs', 'version_0')
assert self.trainer.log_dir == expected
return {"loss": loss}

model = TestModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
checkpoint_callback=False
)

trainer.fit(model)


def test_logdir_no_logger(tmpdir):
"""
Tests that the path is correct even when there is no logger
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
expected = os.path.join(self.trainer.default_root_dir)
assert self.trainer.log_dir == expected
return {"loss": loss}

model = TestModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
logger=False,
)

trainer.fit(model)


def test_logdir_no_logger_no_checkpoint(tmpdir):
"""
Tests that the path is correct when there is neither a logger nor a checkpoint callback
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
expected = os.path.join(self.trainer.default_root_dir)
assert self.trainer.log_dir == expected
return {"loss": loss}

model = TestModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
logger=False,
checkpoint_callback=False
)

trainer.fit(model)
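
The four tests above differ only in the ``logger`` and ``checkpoint_callback`` flags. As a design
note, a hedged sketch of an equivalent consolidated version using ``pytest.mark.parametrize``
(same assumptions as the tests above; ``test_logdir_resolution`` is a hypothetical name):

.. code-block:: python

    import os
    import pytest
    from pytorch_lightning import Trainer
    from tests.base.boring_model import BoringModel

    @pytest.mark.parametrize("logger", [True, False])
    @pytest.mark.parametrize("checkpoint_callback", [True, False])
    def test_logdir_resolution(tmpdir, logger, checkpoint_callback):
        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                output = self.layer(batch)
                loss = self.loss(batch, output)
                if logger:
                    # With a logger, log_dir is the versioned experiment directory.
                    expected = os.path.join(self.trainer.default_root_dir, "lightning_logs", "version_0")
                else:
                    # Without a logger, log_dir falls back to default_root_dir.
                    expected = self.trainer.default_root_dir
                assert self.trainer.log_dir == expected
                return {"loss": loss}

        trainer = Trainer(
            default_root_dir=tmpdir,
            limit_train_batches=2,
            limit_val_batches=2,
            max_epochs=1,
            logger=logger,
            checkpoint_callback=checkpoint_callback,
        )
        trainer.fit(TestModel())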