 from collections.abc import Generator
 from dataclasses import asdict, dataclass, replace
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torchmetrics import Metric
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 from pytorch_lightning.utilities.data import extract_batch_size
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.warnings import WarningCache

-# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
 # TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
-_METRIC = Any  # Union[Metric, torch.Tensor]
-_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
+_IN_METRIC = Any  # Union[Metric, torch.Tensor] # Do not include scalars as they were converted to tensors
 _OUT_METRIC = Union[torch.Tensor, Dict[str, torch.Tensor]]
 _PBAR_METRIC = Union[float, Dict[str, float]]
 _OUT_DICT = Dict[str, _OUT_METRIC]
@@ -49,12 +46,6 @@ class _METRICS(TypedDict):
 warning_cache = WarningCache()


-class MetricSource(LightningEnum):
-    CALLBACK = "callback"
-    PBAR = "pbar"
-    LOG = "log"
-
-
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
@@ -80,14 +71,15 @@ def _generate_sync_fn(self) -> None:
         """Used to compute the syncing function and cache it."""
         fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
         # save the function as `_fn` as the meta are being re-created and the object references need to match.
-        self._fn = partial(fn, reduce_op=self.op, group=self.group)
+        # ignore typing, bad support for `partial`: mypy/issues/1484
+        self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group)  # type: ignore [arg-type]

     @property
     def __call__(self) -> Any:
         return self._fn

     @staticmethod
-    def no_op(value: Any, *_, **__) -> Any:
+    def no_op(value: Any, *_: Any, **__: Any) -> Any:
         return value

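Note: the hunk above caches the bound sync function on `self._fn` via `functools.partial`, so later calls only pass the value to reduce. A minimal standalone sketch of that pattern; `make_sync_fn` and the module-level `no_op` below are hypothetical stand-ins, not part of this diff:

# Hypothetical sketch of the caching pattern, not the Lightning API.
from functools import partial
from typing import Any, Callable, Optional


def no_op(value: Any, *_: Any, **__: Any) -> Any:
    return value


def make_sync_fn(fn: Optional[Callable], should: bool, op: str = "mean", group: Any = None) -> Callable:
    # fall back to the identity function when syncing is disabled,
    # otherwise pre-bind the reduction arguments so callers only pass the value
    chosen = no_op if fn is None or not should else fn
    return partial(chosen, reduce_op=op, group=group)


sync = make_sync_fn(fn=None, should=False)
assert sync(3.0) == 3.0  # disabled path: the value passes through unchanged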
@@ -125,7 +117,8 @@ def _parse_reduce_fx(self) -> None:
             raise MisconfigurationException(error)

     @property
-    def sync(self) -> Optional[_Sync]:
+    def sync(self) -> _Sync:
+        assert self._sync is not None
         return self._sync

     @sync.setter
@@ -196,7 +189,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         if self.meta.is_mean_reduction:
             self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum)

-    def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+    def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
         if self.is_tensor:
             value = value.float()
             # performance: no need to accumulate on values only logged on_step
@@ -232,7 +225,7 @@ def reset(self) -> None:
             self.value.reset()
         self.has_reset = True

-    def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
+    def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None:
         if self.meta.enable_graph:
             with torch.no_grad():
                 self.update(value, batch_size)
@@ -243,7 +236,7 @@ def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
     def _wrap_compute(self, compute: Any) -> Any:
         # Override to avoid syncing - we handle it ourselves.
         @wraps(compute)
-        def wrapped_func(*args, **kwargs):
+        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
             if not self._update_called:
                 rank_zero_warn(
                     f"The ``compute`` method of metric {self.__class__.__name__}"
@@ -253,8 +246,8 @@ def wrapped_func(*args, **kwargs):
                 )

             # return cached value
-            if self._computed is not None:
-                return self._computed
+            if self._computed is not None:  # type: ignore
+                return self._computed  # type: ignore
             self._computed = compute(*args, **kwargs)
             return self._computed

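Note: `wrapped_func` above memoizes the result of `compute()` in `self._computed` (and warns when `update()` was never called). A standalone sketch of the same memoization pattern; `CachedCompute` is a hypothetical class, not the actual `ResultMetric`:

# Hypothetical sketch of the compute-result caching, not the Lightning class.
from functools import wraps
from typing import Any, Callable, Optional


class CachedCompute:
    def __init__(self, compute: Callable[..., Any]) -> None:
        self._computed: Optional[Any] = None
        self.compute = self._wrap_compute(compute)

    def _wrap_compute(self, compute: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(compute)
        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
            # return the cached value instead of recomputing it
            if self._computed is not None:
                return self._computed
            self._computed = compute(*args, **kwargs)
            return self._computed

        return wrapped_func


metric = CachedCompute(lambda: sum([1, 2, 3]))
assert metric.compute() == 6
assert metric.compute() == 6  # second call returns the cached value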
@@ -293,7 +286,7 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "Resul
         result_metric.__setstate__(state, sync_fn=sync_fn)
         return result_metric

-    def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin":
+    def to(self, *args: Any, **kwargs: Any) -> "ResultMetric":
         self.__dict__.update(
             apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
         )
@@ -309,7 +302,7 @@ class ResultMetricCollection(dict):
     with the same metadata.
     """

-    def __init__(self, *args) -> None:
+    def __init__(self, *args: Any) -> None:
         super().__init__(*args)

     @property
@@ -320,20 +313,12 @@ def __getstate__(self, drop_value: bool = False) -> dict:
         def getstate(item: ResultMetric) -> dict:
             return item.__getstate__(drop_value=drop_value)

-        items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
+        items = apply_to_collection(dict(self), ResultMetric, getstate)
         return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}

     def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
-        def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
-            # recurse through dictionaries to set the state. can't use `apply_to_collection`
-            # as it does not recurse items of the same type.
-            if not isinstance(item, dict):
-                return item
-            if item.get("_class") == ResultMetric.__name__:
-                return ResultMetric._reconstruct(item, sync_fn=sync_fn)
-            return {k: setstate(v) for k, v in item.items()}
-
-        items = setstate(state["items"])
+        # can't use `apply_to_collection` as it does not recurse items of the same type
+        items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
         self.update(items)

     @classmethod
@@ -343,6 +328,9 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "Resul
         return rmc


+_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection]
+
+
 class ResultCollection(dict):
     """
     Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
@@ -364,7 +352,7 @@ class ResultCollection(dict):
     def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
         super().__init__()
         self.training = training
-        self._minimize = None
+        self._minimize: Optional[torch.Tensor] = None
         self._batch_size = torch.tensor(1, device=device)
         self.device: Optional[Union[str, torch.device]] = device

@@ -413,7 +401,7 @@ def extra(self) -> Dict[str, Any]:
 
     @extra.setter
     def extra(self, extra: Dict[str, Any]) -> None:
-        def check_fn(v):
+        def check_fn(v: torch.Tensor) -> torch.Tensor:
             if v.grad_fn is not None:
                 warning_cache.deprecation(
                     f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
@@ -494,7 +482,7 @@ def log(
     def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
         """Create one ResultMetric object per value. Value can be provided as a nested collection"""

-        def fn(v: _METRIC) -> ResultMetric:
+        def fn(v: _IN_METRIC) -> ResultMetric:
             metric = ResultMetric(meta, isinstance(v, torch.Tensor))
             return metric.to(self.device)

@@ -504,7 +492,7 @@ def fn(v: _METRIC) -> ResultMetric:
         self[key] = value

     def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
-        def fn(result_metric, v):
+        def fn(result_metric: ResultMetric, v: ResultMetric) -> None:
             # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
             result_metric.forward(v.to(self.device), self.batch_size)
             result_metric.has_reset = False
@@ -545,7 +533,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str,
         return name, forked_name

     def metrics(self, on_step: bool) -> _METRICS:
-        metrics = {k: {} for k in MetricSource}
+        metrics = _METRICS(callback={}, log={}, pbar={})

         for _, result_metric in self.valid_items():

@@ -559,7 +547,7 @@ def metrics(self, on_step: bool) -> _METRICS:
             # check if the collection is empty
             has_tensor = False

-            def any_tensor(_):
+            def any_tensor(_: Any) -> None:
                 nonlocal has_tensor
                 has_tensor = True

@@ -571,16 +559,16 @@ def any_tensor(_):
 
             # populate logging metrics
             if result_metric.meta.logger:
-                metrics[MetricSource.LOG][forked_name] = value
+                metrics["log"][forked_name] = value

             # populate callback metrics. callback metrics don't take `_step` forked metrics
             if self.training or result_metric.meta.on_epoch and not on_step:
-                metrics[MetricSource.CALLBACK][name] = value
-                metrics[MetricSource.CALLBACK][forked_name] = value
+                metrics["callback"][name] = value
+                metrics["callback"][forked_name] = value

             # populate progress_bar metrics. convert tensors to numbers
             if result_metric.meta.prog_bar:
-                metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
+                metrics["pbar"][forked_name] = metrics_to_scalars(value)

         return metrics

@@ -609,7 +597,7 @@ def extract_batch_size(self, batch: Any) -> None:
         except RecursionError:
             self.batch_size = 1

-    def to(self, *args, **kwargs) -> "ResultCollection":
+    def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

@@ -641,7 +629,7 @@ def __str__(self) -> str:
         self_str = str({k: v for k, v in self.items() if v})
         return f"{self.__class__.__name__}({minimize}{self_str})"

-    def __repr__(self):
+    def __repr__(self) -> str:
         # sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=<SumBackward0>), {'_extra': {}}}`
         minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else ""
         return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}"
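For context: the `metrics()` hunk above replaces the removed `MetricSource` enum keys with plain string keys into the `_METRICS` TypedDict, whose class body is referenced by the `@@ -49,12 +46,6 @@` hunk header but not shown in this diff. A minimal sketch of the implied shape; the field types below are assumptions based on the `_OUT_METRIC`/`_PBAR_METRIC` aliases at the top of the file, and `TypedDict` is imported from `typing` for brevity (the real file may import it from `typing_extensions` for older Pythons):

# Assumed shape of the TypedDict; only the keys ("callback", "log", "pbar") are confirmed by the diff.
from typing import Dict, TypedDict, Union

import torch

_OUT_METRIC = Union[torch.Tensor, Dict[str, torch.Tensor]]
_PBAR_METRIC = Union[float, Dict[str, float]]


class _METRICS(TypedDict):
    callback: Dict[str, _OUT_METRIC]
    log: Dict[str, _OUT_METRIC]
    pbar: Dict[str, _PBAR_METRIC]


metrics = _METRICS(callback={}, log={}, pbar={})
metrics["log"]["train_loss_step"] = torch.tensor(0.25)  # string keys replace the removed MetricSource enum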