Skip to content

Commit cc76d8f

Browse files
carmoccaawaelchli
authored andcommitted
Fix references for ResultCollection.extra and improve str and repr (#8622)
1 parent f64842c commit cc76d8f

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,8 @@ def minimize(self) -> Optional[torch.Tensor]:
377377

378378
@minimize.setter
379379
def minimize(self, loss: Optional[torch.Tensor]) -> None:
380-
if loss is not None:
381-
if not isinstance(loss, torch.Tensor):
382-
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
380+
if loss is not None and not isinstance(loss, torch.Tensor):
381+
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
383382
self._minimize = loss
384383

385384
@property
@@ -388,7 +387,8 @@ def extra(self) -> Dict[str, Any]:
388387
Extras are any keys other than the loss returned by
389388
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
390389
"""
391-
return self.get("_extra", {})
390+
self.setdefault("_extra", {})
391+
return self["_extra"]
392392

393393
@extra.setter
394394
def extra(self, extra: Dict[str, Any]) -> None:
@@ -605,7 +605,16 @@ def cpu(self) -> "ResultCollection":
605605
return self.to(device="cpu")
606606

607607
def __str__(self) -> str:
608-
return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})"
608+
# sample output: `ResultCollection(minimize=1.23, {})`
609+
minimize = f"minimize={self.minimize}, " if self.minimize is not None else ""
610+
# remove empty values
611+
self_str = str({k: v for k, v in self.items() if v})
612+
return f"{self.__class__.__name__}({minimize}{self_str})"
613+
614+
def __repr__(self):
615+
# sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=<SumBackward0>), {'_extra': {}}}`
616+
minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else ""
617+
return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}"
609618

610619
def __getstate__(self, drop_value: bool = True) -> dict:
611620
d = self.__dict__.copy()

tests/core/test_metric_result_integration.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,28 @@ def test_result_metric_integration():
132132

133133
assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}
134134

135+
result.minimize = torch.tensor(1.0)
136+
result.extra = {}
135137
assert str(result) == (
136-
"ResultCollection(True, cpu, {"
138+
"ResultCollection("
139+
"minimize=1.0, "
140+
"{"
137141
"'h.a': ResultMetric('a', value=DummyMetric()), "
138142
"'h.b': ResultMetric('b', value=DummyMetric()), "
139143
"'h.c': ResultMetric('c', value=DummyMetric())"
140144
"})"
141145
)
146+
assert repr(result) == (
147+
"{"
148+
"True, "
149+
"device(type='cpu'), "
150+
"minimize=tensor(1.), "
151+
"{'h.a': ResultMetric('a', value=DummyMetric()), "
152+
"'h.b': ResultMetric('b', value=DummyMetric()), "
153+
"'h.c': ResultMetric('c', value=DummyMetric()), "
154+
"'_extra': {}}"
155+
"}"
156+
)
142157

143158

144159
def test_result_collection_simple_loop():
@@ -332,3 +347,9 @@ def on_save_checkpoint(self, checkpoint) -> None:
332347
gpus=1 if device == "cuda" else 0,
333348
)
334349
trainer.fit(model)
350+
351+
352+
def test_result_collection_extra_reference():
353+
"""Unit-test to check that the `extra` dict reference is properly set."""
354+
rc = ResultCollection(True)
355+
assert rc.extra is rc["_extra"]

0 commit comments

Comments
 (0)