Skip to content

Commit b1e38bf

Browse files
authored
Better errors for logging corner cases (#13164)
1 parent a475010 commit b1e38bf

File tree

5 files changed

+75
-12
lines changed

5 files changed

+75
-12
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4949
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))
5050

5151

52+
- Show a better error message when a `Metric` whose `compute()` does not return a `Tensor` is logged ([#13164](https://github.com/PyTorchLightning/pytorch-lightning/pull/13164))
53+
54+
5255
- Added missing `predict_dataset` argument in `LightningDataModule.from_datasets` to create predict dataloaders ([#12942](https://github.com/PyTorchLightning/pytorch-lightning/pull/12942))
5356

5457

@@ -123,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
123126
- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
124127

125128

129+
- When a multi-element tensor is logged, an error is now raised instead of silently taking the mean of all elements ([#13164](https://github.com/PyTorchLightning/pytorch-lightning/pull/13164))
130+
131+
126132
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
127133

128134

src/pytorch_lightning/core/module.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ def log(
396396
)
397397

398398
value = apply_to_collection(value, numbers.Number, self.__to_tensor)
399+
apply_to_collection(value, torch.Tensor, self.__check_numel_1, name)
399400

400401
if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
401402
# if we started a new epoch (running its first batch) the hook name has changed
@@ -518,11 +519,10 @@ def log_dict(
518519
)
519520

520521
@staticmethod
521-
def __check_not_nested(value: dict, name: str) -> dict:
522+
def __check_not_nested(value: dict, name: str) -> None:
522523
# self-imposed restriction. for simplicity
523524
if any(isinstance(v, dict) for v in value.values()):
524525
raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
525-
return value
526526

527527
@staticmethod
528528
def __check_allowed(v: Any, name: str, value: Any) -> None:
@@ -531,6 +531,14 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
531531
def __to_tensor(self, value: numbers.Number) -> Tensor:
532532
return torch.tensor(value, device=self.device)
533533

534+
@staticmethod
def __check_numel_1(value: torch.Tensor, name: str) -> None:
    """Ensure that a tensor passed to ``self.log`` holds exactly one element.

    Args:
        value: The tensor to validate.
        name: The key the tensor is logged under; only used to build the error message.

    Raises:
        ValueError: If ``value`` does not contain exactly one element.
    """
    # state the condition directly with `!=` rather than negating an equality test
    if torch.numel(value) != 1:
        raise ValueError(
            f"`self.log({name}, {value})` was called, but the tensor must have a single element."
            f" You can try doing `self.log({name}, {value}.mean())`"
        )
541+
534542
def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
535543
"""Override this method to change the default behaviour of ``log_grad_norm``.
536544

src/pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,12 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: # type: ignore[ov
244244
# perform accumulation with reduction
245245
if self.meta.is_mean_reduction:
246246
# do not use `+=` as it doesn't do type promotion
247-
self.value = self.value + value.mean() * batch_size
247+
self.value = self.value + value * batch_size
248248
self.cumulated_batch_size = self.cumulated_batch_size + batch_size
249249
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
250-
self.value = self.meta.reduce_fx(self.value, value.mean())
250+
self.value = self.meta.reduce_fx(self.value, value)
251251
elif self.meta.is_sum_reduction:
252-
self.value = self.value + value.mean()
252+
self.value = self.value + value
253253
else:
254254
value = cast(Metric, value)
255255
self.value = value
@@ -528,8 +528,14 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
528528
result_metric.compute()
529529
result_metric.meta.sync.should = should
530530
cache = result_metric._computed
531-
if cache is not None and not result_metric.meta.enable_graph:
532-
return cache.detach()
531+
if cache is not None:
532+
if not isinstance(cache, torch.Tensor):
533+
raise ValueError(
534+
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
535+
f" Found {cache}"
536+
)
537+
if not result_metric.meta.enable_graph:
538+
return cache.detach()
533539
return cache
534540

535541
def valid_items(self) -> Generator:

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,33 @@ def test_result_metric_max_min(reduce_fx, expected):
629629
rm = _ResultMetric(metadata, is_tensor=True)
630630
rm.update(torch.tensor(expected), 1)
631631
assert rm.compute() == expected
632+
633+
634+
def test_compute_not_a_tensor_raises():
    """Fitting must raise a ``ValueError`` when a logged metric's ``compute()`` returns a non-tensor."""

    class TupleMetric(Metric):
        def update(self):
            pass

        def compute(self):
            # a tuple of two tensors instead of a single tensor
            first, second = torch.tensor(1.0), torch.tensor(2.0)
            return first, second

    class MetricLoggingModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = TupleMetric()

        def on_train_start(self):
            self.log("foo", self.metric)

    # single training batch, no validation, all optional features switched off
    trainer_options = {
        "limit_train_batches": 1,
        "limit_val_batches": 0,
        "max_epochs": 1,
        "enable_progress_bar": False,
        "enable_checkpointing": False,
        "logger": False,
        "enable_model_summary": False,
    }
    module = MetricLoggingModel()
    runner = Trainer(**trainer_options)

    expected_message = r"compute\(\)` return of.*foo' must be a tensor"
    with pytest.raises(ValueError, match=expected_message):
        runner.fit(module)

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,16 @@ class TestModel(BoringModel):
631631
def training_step(self, batch, batch_idx):
632632
self.log("foo/dataloader_idx_0", -1)
633633

634-
trainer = Trainer(default_root_dir=tmpdir)
634+
trainer = Trainer(
635+
default_root_dir=tmpdir,
636+
limit_train_batches=1,
637+
limit_val_batches=0,
638+
max_epochs=1,
639+
enable_progress_bar=False,
640+
enable_checkpointing=False,
641+
logger=False,
642+
enable_model_summary=False,
643+
)
635644
model = TestModel()
636645
with pytest.raises(MisconfigurationException, match="`self.log` with the key `foo/dataloader_idx_0`"):
637646
trainer.fit(model)
@@ -640,7 +649,6 @@ class TestModel(BoringModel):
640649
def training_step(self, batch, batch_idx):
641650
self.log("foo", Accuracy())
642651

643-
trainer = Trainer(default_root_dir=tmpdir)
644652
model = TestModel()
645653
with pytest.raises(MisconfigurationException, match="fix this by setting an attribute for the metric in your"):
646654
trainer.fit(model)
@@ -653,7 +661,6 @@ def __init__(self):
653661
def training_step(self, batch, batch_idx):
654662
self.log("foo", Accuracy())
655663

656-
trainer = Trainer(default_root_dir=tmpdir)
657664
model = TestModel()
658665
with pytest.raises(
659666
MisconfigurationException,
@@ -667,7 +674,6 @@ def training_step(self, *args):
667674
self.log("foo", -1, prog_bar=True)
668675
return super().training_step(*args)
669676

670-
trainer = Trainer(default_root_dir=tmpdir)
671677
model = TestModel()
672678
with pytest.raises(MisconfigurationException, match=r"self.log\(foo, ...\)` twice in `training_step`"):
673679
trainer.fit(model)
@@ -677,11 +683,18 @@ def training_step(self, *args):
677683
self.log("foo", -1, reduce_fx=torch.argmax)
678684
return super().training_step(*args)
679685

680-
trainer = Trainer(default_root_dir=tmpdir)
681686
model = TestModel()
682687
with pytest.raises(MisconfigurationException, match=r"reduce_fx={min,max,mean,sum}\)` are supported"):
683688
trainer.fit(model)
684689

690+
class TestModel(BoringModel):
691+
def on_train_start(self):
692+
self.log("foo", torch.tensor([1.0, 2.0]))
693+
694+
model = TestModel()
695+
with pytest.raises(ValueError, match="tensor must have a single element"):
696+
trainer.fit(model)
697+
685698

686699
def test_sanity_metrics_are_reset(tmpdir):
687700
class TestModel(BoringModel):

0 commit comments

Comments
 (0)