Skip to content

Commit 2c3cc74

Browse files
authored
Warn when self.log(..., logger=True) is called without a logger (#15814)
1 parent 4e64391 commit 2c3cc74

File tree: 3 files changed (+28 additions, −5 deletions)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727
- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))
2828

2929

30+
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))
31+
3032
### Changed
3133

3234
- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

src/pytorch_lightning/core/module.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def log(
322322
name: str,
323323
value: _METRIC_COLLECTION,
324324
prog_bar: bool = False,
325-
logger: bool = True,
325+
logger: Optional[bool] = None,
326326
on_step: Optional[bool] = None,
327327
on_epoch: Optional[bool] = None,
328328
reduce_fx: Union[str, Callable] = "mean",
@@ -438,6 +438,16 @@ def log(
438438
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
439439
)
440440

441+
if logger and self.trainer.logger is None:
442+
rank_zero_warn(
443+
f"You called `self.log({name!r}, ..., logger=True)` but have no logger configured. You can enable one"
444+
" by doing `Trainer(logger=ALogger(...))`"
445+
)
446+
if logger is None:
447+
# we could set false here if there's no configured logger, however, we still need to compute the "logged"
448+
# metrics anyway because that's what the evaluation loops use as return value
449+
logger = True
450+
441451
results.log(
442452
self._current_fx_name,
443453
name,
@@ -463,7 +473,7 @@ def log_dict(
463473
self,
464474
dictionary: Mapping[str, _METRIC_COLLECTION],
465475
prog_bar: bool = False,
466-
logger: bool = True,
476+
logger: Optional[bool] = None,
467477
on_step: Optional[bool] = None,
468478
on_epoch: Optional[bool] = None,
469479
reduce_fx: Union[str, Callable] = "mean",

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pytorch_lightning import LightningModule
2323
from pytorch_lightning.callbacks.callback import Callback
2424
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
25+
from pytorch_lightning.loggers import CSVLogger
2526
from pytorch_lightning.trainer import Trainer
2627
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
2728
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
@@ -605,15 +606,25 @@ def test_result_collection_on_tensor_with_mean_reduction():
605606
}
606607

607608

608-
def test_logged_metrics_has_logged_epoch_value(tmpdir):
609+
@pytest.mark.parametrize("logger", (False, True))
610+
def test_logged_metrics_has_logged_epoch_value(tmpdir, logger):
609611
class TestModel(BoringModel):
610612
def training_step(self, batch, batch_idx):
611613
self.log("epoch", -batch_idx, logger=True)
612614
return super().training_step(batch, batch_idx)
613615

614616
model = TestModel()
615-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
616-
trainer.fit(model)
617+
trainer_kwargs = dict(
618+
default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=0, max_epochs=1, logger=False
619+
)
620+
if logger:
621+
trainer_kwargs["logger"] = CSVLogger(tmpdir)
622+
trainer = Trainer(**trainer_kwargs)
623+
if not logger:
624+
with pytest.warns(match=r"log\('epoch', ..., logger=True\)` but have no logger"):
625+
trainer.fit(model)
626+
else:
627+
trainer.fit(model)
617628

618629
# should not get overridden if logged manually
619630
assert trainer.logged_metrics == {"epoch": -1}

Comments (0)