2828from pytorch_lightning .utilities .metrics import metrics_to_scalars
2929
3030# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
31- # todo ( tchaton) Resolve this typing bug in python 3.6
31+ # TODO(@ tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
3232_METRIC = Any # Union[Metric, torch.Tensor]
3333_METRIC_COLLECTION = Union [_METRIC , Mapping [str , _METRIC ]]
3434
@@ -202,23 +202,8 @@ def __repr__(self) -> str:
202202 state += f", cumulated_batch_size={ self .cumulated_batch_size } "
203203 return f"{ self .__class__ .__name__ } ({ state } )"
204204
205-
206- class _ResultMetricSerializationHelper (dict ):
207- """
208- Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need
209- a class to differentiate between the cases after converting to state dict when saving its state.
210- """
211-
212-
213- class _ResultMetricCollectionSerializationHelper (dict ):
214- """
215- Since several ``ResultCollection`` can hold inside a ``ResultMetricCollection``, we need
216- a class to differentiate between the cases after converting to state dict when saving its state.
217- """
218-
219- def __init__ (self , * args , metadata : Optional [_Metadata ] = None ) -> None :
220- super ().__init__ (* args )
221- self .meta = metadata
205+ def __getstate__ (self ) -> dict :
206+ return {** super ().__getstate__ (), '_class' : self .__class__ .__name__ }
222207
223208
224209class ResultMetricCollection (dict ):
@@ -234,6 +219,31 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
234219 super ().__init__ (* args )
235220 self .meta = metadata
236221
222+ def __getstate__ (self ) -> dict :
223+
224+ def getstate (item : ResultMetric ) -> dict :
225+ return item .__getstate__ ()
226+
227+ items = apply_to_collection (dict (self ), (ResultMetric , ResultMetricCollection ), getstate )
228+ return {"items" : items , "meta" : self .meta , "_class" : self .__class__ .__name__ }
229+
230+ def __setstate__ (self , state : dict ) -> None :
231+ self .meta = state ["meta" ]
232+
233+ def setstate (item : dict ) -> Union [Dict [str , ResultMetric ], ResultMetric , Any ]:
234+ # recurse through dictionaries to set the state. can't use `apply_to_collection`
235+ # as it does not recurse items of the same type.
236+ if not isinstance (item , dict ):
237+ return item
238+ if item .get ('_class' ) == ResultMetric .__name__ :
239+ result_metric = ResultMetric (item ['meta' ], item ['is_tensor' ])
240+ result_metric .__setstate__ (item )
241+ return result_metric
242+ return {k : setstate (v ) for k , v in item .items ()}
243+
244+ items = setstate (state ["items" ])
245+ self .update (items )
246+
237247
238248class ResultCollection (dict ):
239249 """
@@ -353,10 +363,6 @@ def log(
353363 )
354364 )
355365
356- # the reduce function was drop while saving a checkpoint.
357- if key in self and self [key ].meta .sync .fn is None :
358- self [key ].meta .sync .fn = meta .sync .fn
359-
360366 if key not in self :
361367 self .register_key (key , meta , value )
362368 elif meta != self [key ].meta :
@@ -424,9 +430,7 @@ def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
424430 for _ , result_metric in self .valid_items ():
425431
426432 # extract forward_cache or computed from the ResultMetric. ignore when the output is None
427- value = apply_to_collection (
428- result_metric , ResultMetric , self ._get_cache , on_step , include_none = False , wrong_dtype = ResultCollection
429- )
433+ value = apply_to_collection (result_metric , ResultMetric , self ._get_cache , on_step , include_none = False )
430434
431435 # check if the collection is empty
432436 has_tensor = False
@@ -525,60 +529,45 @@ def __str__(self) -> str:
525529 return f'{ self .__class__ .__name__ } ({ self .training } , { self .device } , { repr (self )} )'
526530
527531 def __getstate__ (self ) -> dict :
528- d = self .__dict__ .copy ()
529532 # can't deepcopy tensors with grad_fn
530- minimize = d .get ('_minimize' )
531- if minimize is not None :
532- d ['_minimize' ] = minimize .detach ()
533- return d
534-
535- def state_dict (self ):
536-
537- def to_state_dict (
538- item : Union [ResultMetric , ResultMetricCollection ]
539- ) -> Union [_ResultMetricSerializationHelper , _ResultMetricCollectionSerializationHelper ]:
540- if isinstance (item , ResultMetricCollection ):
541- return _ResultMetricCollectionSerializationHelper (
542- apply_to_collection (item , ResultMetric , to_state_dict ), metadata = item .meta
543- )
544- state = item .__getstate__ ()
545- state ["meta" ].sync .fn = None
546- return _ResultMetricSerializationHelper (** item .__getstate__ ())
533+ minimize = None
534+ if self .minimize is not None :
535+ minimize = self .minimize .detach ()
547536
537+ # all the items should be either `ResultMetric`s or `ResultMetricCollection`s
538+ items = {k : v .__getstate__ () for k , v in self .items ()}
548539 return {
549- k : apply_to_collection (v , (ResultMetric , ResultMetricCollection ), to_state_dict )
550- for k , v in self .items ()
540+ 'training' : self .training ,
541+ 'device' : self .device ,
542+ 'minimize' : minimize ,
543+ 'batch_size' : self .batch_size ,
544+ 'items' : items ,
551545 }
552546
553- def load_state_dict (self , state_dict : Dict [str , Any ], sync_fn : Optional [Callable ] = None ) -> None :
554-
555- def to_result_metric_collection (item : _ResultMetricCollectionSerializationHelper ) -> ResultCollection :
556- result_metric_collection = ResultMetricCollection ()
557- result_metric_collection .update (item )
558-
559- def _to_device (item : ResultMetric ) -> ResultMetric :
560- return item .to (self .device )
561-
562- result_metric_collection = apply_to_collection (result_metric_collection , ResultMetric , _to_device )
563- result_metric_collection .meta = item .meta
564- result_metric_collection .meta .sync .fn = sync_fn
565- return result_metric_collection
566-
567- def to_result_metric (item : _ResultMetricSerializationHelper ) -> ResultMetric :
568- result_metric = ResultMetric (item ["meta" ], item ["is_tensor" ])
569- result_metric .__dict__ .update (item )
570- result_metric .meta .sync .fn = sync_fn
571- return result_metric .to (self .device )
572-
573- state_dict = {
574- k : apply_to_collection (v , _ResultMetricCollectionSerializationHelper , to_result_metric_collection )
575- for k , v in state_dict .items ()
576- }
577- result_metric_collection = {k : v .meta for k , v in state_dict .items () if isinstance (v , ResultMetricCollection )}
578- state_dict = {
579- k : apply_to_collection (v , _ResultMetricSerializationHelper , to_result_metric )
580- for k , v in state_dict .items ()
581- }
582- self .update (state_dict )
583- for k , meta in result_metric_collection .items ():
584- self [k ].meta = meta
547+ def __setstate__ (self , state : dict ) -> None :
548+ self .training = state ['training' ]
549+ self .device = state ['device' ]
550+ self ._minimize = state ['minimize' ]
551+ self ._batch_size = state ['batch_size' ]
552+
553+ def setstate (item : dict ) -> Union [ResultMetric , ResultMetricCollection ]:
554+ if not isinstance (item , dict ):
555+ raise ValueError (f'Unexpected value: { item } ' )
556+ cls = item ['_class' ]
557+ if cls == ResultMetric .__name__ :
558+ result_metric = ResultMetric (item ['meta' ], item ['is_tensor' ])
559+ elif cls == ResultMetricCollection .__name__ :
560+ result_metric = ResultMetricCollection ()
561+ else :
562+ raise ValueError (f"Unexpected class name: { cls } " )
563+ result_metric .__setstate__ (item )
564+ return result_metric
565+
566+ items = {k : setstate (v ) for k , v in state ['items' ].items ()}
567+ self .update (items )
568+
569+ def state_dict (self ) -> dict :
570+ return self .__getstate__ ()
571+
572+ def load_state_dict (self , state_dict : dict ) -> None :
573+ self .__setstate__ (state_dict )
0 commit comments