
Commit 73fca23

Add typing for ResultCollection [3/3] (#9271)
1 parent 50198d7 commit 73fca23


5 files changed: +53 −66 lines changed


pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ module = [
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.loops.closure",
     "pytorch_lightning.trainer.evaluation_loop",
-    "pytorch_lightning.trainer.connectors.logger_connector.fx_validator",
-    "pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
+    "pytorch_lightning.trainer.connectors.logger_connector.*",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
     "pytorch_lightning.utilities.apply_func",

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 32 additions & 44 deletions
@@ -14,7 +14,7 @@
 from collections.abc import Generator
 from dataclasses import asdict, dataclass, replace
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torchmetrics import Metric
@@ -24,16 +24,13 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 from pytorch_lightning.utilities.data import extract_batch_size
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.warnings import WarningCache

-# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
 # TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
-_METRIC = Any  # Union[Metric, torch.Tensor]
-_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
+_IN_METRIC = Any  # Union[Metric, torch.Tensor]  # Do not include scalars as they were converted to tensors
 _OUT_METRIC = Union[torch.Tensor, Dict[str, torch.Tensor]]
 _PBAR_METRIC = Union[float, Dict[str, float]]
 _OUT_DICT = Dict[str, _OUT_METRIC]
@@ -49,12 +46,6 @@ class _METRICS(TypedDict):
 warning_cache = WarningCache()


-class MetricSource(LightningEnum):
-    CALLBACK = "callback"
-    PBAR = "pbar"
-    LOG = "log"
-
-
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
@@ -80,14 +71,15 @@ def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
         # save the function as `_fn` as the meta are being re-created and the object references need to match.
-        self._fn = partial(fn, reduce_op=self.op, group=self.group)
+        # ignore typing, bad support for `partial`: mypy/issues/1484
+        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore [arg-type]

     @property
     def __call__(self) -> Any:
         return self._fn

     @staticmethod
-    def no_op(value: Any, *_, **__) -> Any:
+    def no_op(value: Any, *_: Any, **__: Any) -> Any:
         return value

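The `Callable` annotation and the `type: ignore` on `_fn` work around mypy's limited support for `functools.partial` (mypy/issues/1484, referenced in the added comment): the partial's argument types cannot be fully checked, so the attribute is annotated explicitly and the residual complaint is silenced. A minimal, self-contained sketch of the same pattern, where the `SyncSketch` class and `fake_reduce` helper are illustrative and not part of the library:

from functools import partial
from typing import Any, Callable, Optional


def fake_reduce(value: Any, reduce_op: str = "sum") -> Any:
    # stand-in for a real distributed reduction
    return value


class SyncSketch:
    def __init__(self, fn: Optional[Callable] = None, should: bool = False) -> None:
        self.fn = fn
        self.should = should
        self._generate_sync_fn()

    def _generate_sync_fn(self) -> None:
        # fall back to a no-op when syncing is disabled or no fn was provided
        fn = self.no_op if self.fn is None or not self.should else self.fn
        # explicit `Callable` annotation: mypy cannot infer much about partial objects
        self._fn: Callable = partial(fn, reduce_op="mean")  # type: ignore[arg-type]

    @staticmethod
    def no_op(value: Any, *_: Any, **__: Any) -> Any:
        return value


sync = SyncSketch(fn=fake_reduce, should=True)
assert sync._fn(3.0) == 3.0  # calls fake_reduce(3.0, reduce_op="mean")
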
@@ -125,7 +117,8 @@ def _parse_reduce_fx(self) -> None:
             raise MisconfigurationException(error)

     @property
-    def sync(self) -> Optional[_Sync]:
+    def sync(self) -> _Sync:
+        assert self._sync is not None
         return self._sync

     @sync.setter
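Dropping `Optional` from the `sync` getter relies on the added `assert`: at runtime it documents the invariant that `_sync` has been set before it is read, and for mypy it narrows `Optional[_Sync]` to `_Sync`, so callers such as `meta.sync.fn` need no extra `None` handling. The same idiom in isolation, with purely illustrative names:

from typing import Optional


class Holder:
    def __init__(self) -> None:
        self._value: Optional[int] = None

    @property
    def value(self) -> int:
        # documents the invariant and narrows Optional[int] to int for the type checker
        assert self._value is not None
        return self._value

    @value.setter
    def value(self, value: int) -> None:
        self._value = value


h = Holder()
h.value = 5
print(h.value + 1)  # 6; no Optional handling needed at the call site
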
@@ -196,7 +189,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         if self.meta.is_mean_reduction:
             self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum)

-    def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+    def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
         if self.is_tensor:
             value = value.float()
             # performance: no need to accumulate on values only logged on_step
@@ -232,7 +225,7 @@ def reset(self) -> None:
             self.value.reset()
         self.has_reset = True

-    def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+    def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
         if self.meta.enable_graph:
             with torch.no_grad():
                 self.update(value, batch_size)
@@ -243,7 +236,7 @@ def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
     def _wrap_compute(self, compute: Any) -> Any:
         # Override to avoid syncing - we handle it ourselves.
         @wraps(compute)
-        def wrapped_func(*args, **kwargs):
+        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
             if not self._update_called:
                 rank_zero_warn(
                     f"The ``compute`` method of metric {self.__class__.__name__}"
@@ -253,8 +246,8 @@ def wrapped_func(*args, **kwargs):
                 )

             # return cached value
-            if self._computed is not None:
-                return self._computed
+            if self._computed is not None:  # type: ignore
+                return self._computed  # type: ignore
             self._computed = compute(*args, **kwargs)
             return self._computed

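`wrapped_func` only gains annotations here; the behaviour is unchanged: warn if `update` was never called, serve the cached `_computed` value when it exists, otherwise compute and cache it. A hedged, standalone sketch of that caching pattern outside torchmetrics (the `CachedCompute` class is illustrative, not the library's implementation):

import warnings
from functools import wraps
from typing import Any, Callable, Optional


class CachedCompute:
    def __init__(self, compute: Callable[..., Any]) -> None:
        self._update_called = False
        self._computed: Optional[Any] = None
        self.compute = self._wrap_compute(compute)

    def update(self) -> None:
        self._update_called = True

    def _wrap_compute(self, compute: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(compute)
        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
            if not self._update_called:
                warnings.warn("`compute` called before `update`")
            # return the cached value if a previous call already computed it
            if self._computed is not None:
                return self._computed
            self._computed = compute(*args, **kwargs)
            return self._computed

        return wrapped_func


metric = CachedCompute(lambda: sum(range(5)))
metric.update()
assert metric.compute() == 10  # computed once
assert metric.compute() == 10  # served from the cache on the second call
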
@@ -293,7 +286,7 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "Resul
         result_metric.__setstate__(state, sync_fn=sync_fn)
         return result_metric

-    def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+    def to(self, *args: Any, **kwargs: Any) -> "ResultMetric":
         self.__dict__.update(
             apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
         )
@@ -309,7 +302,7 @@ class ResultMetricCollection(dict):
     with the same metadata.
     """

-    def __init__(self, *args) -> None:
+    def __init__(self, *args: Any) -> None:
         super().__init__(*args)

     @property
@@ -320,20 +313,12 @@ def __getstate__(self, drop_value: bool = False) -> dict:
         def getstate(item: ResultMetric) -> dict:
             return item.__getstate__(drop_value=drop_value)

-        items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
+        items = apply_to_collection(dict(self), ResultMetric, getstate)
         return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}

     def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
-        def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
-            # recurse through dictionaries to set the state. can't use `apply_to_collection`
-            # as it does not recurse items of the same type.
-            if not isinstance(item, dict):
-                return item
-            if item.get("_class") == ResultMetric.__name__:
-                return ResultMetric._reconstruct(item, sync_fn=sync_fn)
-            return {k: setstate(v) for k, v in item.items()}
-
-        items = setstate(state["items"])
+        # can't use `apply_to_collection` as it does not recurse items of the same type
+        items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
         self.update(items)

     @classmethod
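The rewritten `__setstate__` drops the recursive `setstate` helper: a `ResultMetricCollection` only ever holds `ResultMetric` entries one level deep, so every value in `state["items"]` can be rebuilt directly with a dict comprehension. A rough sketch of the same round-trip with stand-in classes (these are not the library's types):

from typing import Callable, Optional


class MetricStub:
    # stand-in for ResultMetric; only carries a value
    def __init__(self, value: float) -> None:
        self.value = value

    def __getstate__(self) -> dict:
        return {"_class": "MetricStub", "value": self.value}

    @classmethod
    def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "MetricStub":
        return cls(state["value"])


class CollectionStub(dict):
    def __getstate__(self) -> dict:
        return {"items": {k: v.__getstate__() for k, v in self.items()}, "_class": self.__class__.__name__}

    def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
        # the collection is only one level deep, so no recursion is needed
        items = {k: MetricStub._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
        self.update(items)


original = CollectionStub({"loss": MetricStub(0.5)})
restored = CollectionStub()
restored.__setstate__(original.__getstate__())
assert restored["loss"].value == 0.5
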
@@ -343,6 +328,9 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "Resul
         return rmc


+_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection]
+
+
 class ResultCollection(dict):
     """
     Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
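With `MetricSource` removed and `_METRIC_COLLECTION` now defined in terms of `ResultMetricCollection`, the aliases in `result.py` separate input from output shapes: `_IN_METRIC` is what gets logged (a `Metric` or tensor; plain numbers are converted to tensors before they reach the collection), `_METRIC_COLLECTION` additionally admits a `ResultMetricCollection`, and `_OUT_METRIC`/`_OUT_DICT` describe what `metrics()` returns. A short orientation sketch of the input side; `to_tensor` is an illustrative helper, not the library's conversion code:

from typing import Any, Dict, Union

import torch

_IN_METRIC = Any  # Union[Metric, torch.Tensor]; scalars are converted to tensors earlier
_OUT_METRIC = Union[torch.Tensor, Dict[str, torch.Tensor]]
_OUT_DICT = Dict[str, _OUT_METRIC]


def to_tensor(value: Union[int, float, torch.Tensor]) -> torch.Tensor:
    # sketch of the scalar-to-tensor conversion assumed to happen before values are recorded
    return value if isinstance(value, torch.Tensor) else torch.tensor(float(value))


logged: Dict[str, _IN_METRIC] = {"loss": to_tensor(0.25), "acc": to_tensor(1)}
out: _OUT_DICT = dict(logged)
print(out["loss"])  # tensor(0.2500)
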
@@ -364,7 +352,7 @@ class ResultCollection(dict):
     def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
         super().__init__()
         self.training = training
-        self._minimize = None
+        self._minimize: Optional[torch.Tensor] = None
         self._batch_size = torch.tensor(1, device=device)
         self.device: Optional[Union[str, torch.device]] = device

@@ -413,7 +401,7 @@ def extra(self) -> Dict[str, Any]:

     @extra.setter
     def extra(self, extra: Dict[str, Any]) -> None:
-        def check_fn(v):
+        def check_fn(v: torch.Tensor) -> torch.Tensor:
             if v.grad_fn is not None:
                 warning_cache.deprecation(
                     f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
@@ -494,7 +482,7 @@ def log(
     def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
         """Create one ResultMetric object per value. Value can be provided as a nested collection"""

-        def fn(v: _METRIC) -> ResultMetric:
+        def fn(v: _IN_METRIC) -> ResultMetric:
             metric = ResultMetric(meta, isinstance(v, torch.Tensor))
             return metric.to(self.device)

@@ -504,7 +492,7 @@ def fn(v: _METRIC) -> ResultMetric:
         self[key] = value

     def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
-        def fn(result_metric, v):
+        def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
             # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
             result_metric.forward(v.to(self.device), self.batch_size)
             result_metric.has_reset = False
@@ -545,7 +533,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str,
         return name, forked_name

     def metrics(self, on_step: bool) -> _METRICS:
-        metrics = {k: {} for k in MetricSource}
+        metrics = _METRICS(callback={}, log={}, pbar={})

         for _, result_metric in self.valid_items():

@@ -559,7 +547,7 @@ def metrics(self, on_step: bool) -> _METRICS:
             # check if the collection is empty
             has_tensor = False

-            def any_tensor(_):
+            def any_tensor(_: Any) -> None:
                 nonlocal has_tensor
                 has_tensor = True

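`any_tensor` is a small closure trick: `apply_to_collection` visits every tensor leaf of the logged value, and the closure only flips a `nonlocal` flag, so an empty collection (for example a dict whose tensors were all filtered out for this step) can be skipped. A standalone sketch of the pattern; `visit_tensors` is a simplified stand-in for `apply_to_collection`:

from typing import Any, Callable

import torch


def visit_tensors(data: Any, fn: Callable[[torch.Tensor], None]) -> None:
    # simplified stand-in: walk dicts/lists/tuples and call fn on every tensor leaf
    if isinstance(data, torch.Tensor):
        fn(data)
    elif isinstance(data, dict):
        for v in data.values():
            visit_tensors(v, fn)
    elif isinstance(data, (list, tuple)):
        for v in data:
            visit_tensors(v, fn)


def is_empty(value: Any) -> bool:
    has_tensor = False

    def any_tensor(_: Any) -> None:
        nonlocal has_tensor
        has_tensor = True

    visit_tensors(value, any_tensor)
    return not has_tensor


print(is_empty({"a": {}}))                        # True: nothing left to log
print(is_empty({"a": {"b": torch.tensor(1.0)}}))  # False: a tensor is present
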
@@ -571,16 +559,16 @@ def any_tensor(_):

             # populate logging metrics
             if result_metric.meta.logger:
-                metrics[MetricSource.LOG][forked_name] = value
+                metrics["log"][forked_name] = value

             # populate callback metrics. callback metrics don't take `_step` forked metrics
             if self.training or result_metric.meta.on_epoch and not on_step:
-                metrics[MetricSource.CALLBACK][name] = value
-                metrics[MetricSource.CALLBACK][forked_name] = value
+                metrics["callback"][name] = value
+                metrics["callback"][forked_name] = value

             # populate progress_bar metrics. convert tensors to numbers
             if result_metric.meta.prog_bar:
-                metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
+                metrics["pbar"][forked_name] = metrics_to_scalars(value)

         return metrics

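Removing `MetricSource` means the three buckets returned by `metrics()` are now plain string keys of the `_METRICS` TypedDict declared near the top of this file (visible in one of the hunk headers above), so mypy can check both the construction `_METRICS(callback={}, log={}, pbar={})` and every `metrics["log"][...]` access. A minimal sketch of how a TypedDict provides typed, enum-free keys; the value types below follow the aliases in this file, but the definition is a paraphrase rather than the exact library code:

from typing import Dict, Union

import torch

try:
    from typing import TypedDict  # Python 3.8+
except ImportError:
    from typing_extensions import TypedDict

_OUT_METRIC = Union[torch.Tensor, Dict[str, torch.Tensor]]
_PBAR_METRIC = Union[float, Dict[str, float]]


class _METRICS(TypedDict):
    callback: Dict[str, _OUT_METRIC]
    log: Dict[str, _OUT_METRIC]
    pbar: Dict[str, _PBAR_METRIC]


metrics = _METRICS(callback={}, log={}, pbar={})
metrics["log"]["loss"] = torch.tensor(0.1)  # key and value type are both checked
metrics["pbar"]["loss"] = 0.1               # progress-bar metrics are plain numbers
# metrics["logg"] would be rejected by mypy: "logg" is not a key of _METRICS
print(metrics["log"], metrics["pbar"])
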
@@ -609,7 +597,7 @@ def extract_batch_size(self, batch: Any) -> None:
         except RecursionError:
             self.batch_size = 1

-    def to(self, *args, **kwargs) -> "ResultCollection":
+    def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

@@ -641,7 +629,7 @@ def __str__(self) -> str:
         self_str = str({k: v for k, v in self.items() if v})
         return f"{self.__class__.__name__}({minimize}{self_str})"

-    def __repr__(self):
+    def __repr__(self) -> str:
         # sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=<SumBackward0>), {'_extra': {}}}`
         minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else ""
         return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}"

pytorch_lightning/utilities/types.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,6 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
 """
-from numbers import Number
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union

@@ -25,7 +24,8 @@
 from torch.utils.data import DataLoader
 from torchmetrics import Metric

-_METRIC = Union[Metric, torch.Tensor, Number]
+_NUMBER = Union[int, float]
+_METRIC = Union[Metric, torch.Tensor, _NUMBER]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
 STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
 EPOCH_OUTPUT = List[STEP_OUTPUT]
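Here `numbers.Number` is replaced by an explicit `_NUMBER = Union[int, float]`: mypy's stubs give `Number` almost no operations, so annotating with it tends to reject ordinary arithmetic, while `Union[int, float]` keeps the public `_METRIC` alias both permissive and checkable. A small illustration of the practical difference; `half` is an illustrative function only:

from typing import Union

_NUMBER = Union[int, float]


def half(x: _NUMBER) -> float:
    # with `x: numbers.Number`, mypy would typically flag `x / 2` as an unsupported operation
    return x / 2


print(half(3), half(2.5))  # 1.5 1.25
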

tests/core/test_metric_result_integration.py

Lines changed: 11 additions & 11 deletions
@@ -27,7 +27,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -81,10 +81,10 @@ def _ddp_test_fn(rank, worldsize):
         result.log("h", "b", metric_b, on_step=False, on_epoch=True)
         result.log("h", "c", metric_c, on_step=True, on_epoch=False)

-        batch_log = result.metrics(True)[MetricSource.LOG]
+        batch_log = result.metrics(True)["log"]
         assert batch_log == {"a_step": i, "c": i}

-    epoch_log = result.metrics(False)[MetricSource.LOG]
+    epoch_log = result.metrics(False)["log"]
     result.reset()

     # assert metric state reset to default values
@@ -124,10 +124,10 @@ def test_result_metric_integration():
         result.log("h", "b", metric_b, on_step=False, on_epoch=True)
         result.log("h", "c", metric_c, on_step=True, on_epoch=False)

-        batch_log = result.metrics(True)[MetricSource.LOG]
+        batch_log = result.metrics(True)["log"]
         assert batch_log == {"a_step": i, "c": i}

-    epoch_log = result.metrics(False)[MetricSource.LOG]
+    epoch_log = result.metrics(False)["log"]
     result.reset()

     # assert metric state reset to default values
@@ -248,7 +248,7 @@ def lightning_log(fx, *args, **kwargs):
         lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
         lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

-        batch_log = result.metrics(on_step=True)[MetricSource.LOG]
+        batch_log = result.metrics(on_step=True)["log"]
         assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
         assert set(batch_log["c_1"]) == {"1", "2"}

@@ -269,12 +269,12 @@ def lightning_log(fx, *args, **kwargs):
     # the sync fn has been kept
     assert result_copy["training_step.a"].meta.sync.fn == new_result["training_step.a"].meta.sync.fn

-    epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
-    epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
+    epoch_log = result.metrics(on_step=False)["log"]
+    epoch_log_copy = result_copy.metrics(on_step=False)["log"]
     assert epoch_log == epoch_log_copy

     lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)
-    epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
+    epoch_log = result.metrics(on_step=False)["log"]
     assert epoch_log == {
         "a_1_epoch": 1,
         "a_epoch": cumulative_sum,
@@ -451,9 +451,9 @@ def on_epoch_end(self) -> None:
         total = sum(range(5)) * num_processes
         metrics = self.results.metrics(on_step=False)
         assert self.results["training_step.tracking"].value == total
-        assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
+        assert metrics["callback"]["tracking"] == self.dummy_metric.compute() == 2
         assert self.results["training_step.tracking_2"].value == total
-        assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
+        assert metrics["callback"]["tracking_2"] == self.dummy_metric.compute() == 2
         self.has_validated_sum = True

     model = ExtendedBoringModel()
