Commit 7be120c

Update code
1 parent 03e605f commit 7be120c

File tree

  • pytorch_lightning/trainer/connectors/logger_connector

1 file changed: +66 -77 lines changed

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 66 additions & 77 deletions
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities.metrics import metrics_to_scalars

 # re-define the ones from pytorch_lightning.utilities.types without the `Number` type
-# todo (tchaton) Resolve this typing bug in python 3.6
+# TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
 _METRIC = Any  # Union[Metric, torch.Tensor]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]

@@ -202,23 +202,8 @@ def __repr__(self) -> str:
         state += f", cumulated_batch_size={self.cumulated_batch_size}"
         return f"{self.__class__.__name__}({state})"

-
-class _ResultMetricSerializationHelper(dict):
-    """
-    Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need
-    a class to differentiate between the cases after converting to state dict when saving its state.
-    """
-
-
-class _ResultMetricCollectionSerializationHelper(dict):
-    """
-    Since several ``ResultCollection`` can hold inside a ``ResultMetricCollection``, we need
-    a class to differentiate between the cases after converting to state dict when saving its state.
-    """
-
-    def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
-        super().__init__(*args)
-        self.meta = metadata
+    def __getstate__(self) -> dict:
+        return {**super().__getstate__(), '_class': self.__class__.__name__}


 class ResultMetricCollection(dict):
@@ -234,6 +219,31 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
         super().__init__(*args)
         self.meta = metadata

+    def __getstate__(self) -> dict:
+
+        def getstate(item: ResultMetric) -> dict:
+            return item.__getstate__()
+
+        items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
+        return {"items": items, "meta": self.meta, "_class": self.__class__.__name__}
+
+    def __setstate__(self, state: dict) -> None:
+        self.meta = state["meta"]
+
+        def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
+            # recurse through dictionaries to set the state. can't use `apply_to_collection`
+            # as it does not recurse items of the same type.
+            if not isinstance(item, dict):
+                return item
+            if item.get('_class') == ResultMetric.__name__:
+                result_metric = ResultMetric(item['meta'], item['is_tensor'])
+                result_metric.__setstate__(item)
+                return result_metric
+            return {k: setstate(v) for k, v in item.items()}
+
+        items = setstate(state["items"])
+        self.update(items)
+

 class ResultCollection(dict):
     """
@@ -353,10 +363,6 @@ def log(
             )
         )

-        # the reduce function was drop while saving a checkpoint.
-        if key in self and self[key].meta.sync.fn is None:
-            self[key].meta.sync.fn = meta.sync.fn
-
         if key not in self:
             self.register_key(key, meta, value)
         elif meta != self[key].meta:
@@ -424,9 +430,7 @@ def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
         for _, result_metric in self.valid_items():

             # extract forward_cache or computed from the ResultMetric. ignore when the output is None
-            value = apply_to_collection(
-                result_metric, ResultMetric, self._get_cache, on_step, include_none=False, wrong_dtype=ResultCollection
-            )
+            value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)

             # check if the collection is empty
             has_tensor = False
@@ -525,60 +529,45 @@ def __str__(self) -> str:
         return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})'

     def __getstate__(self) -> dict:
-        d = self.__dict__.copy()
         # can't deepcopy tensors with grad_fn
-        minimize = d.get('_minimize')
-        if minimize is not None:
-            d['_minimize'] = minimize.detach()
-        return d
-
-    def state_dict(self):
-
-        def to_state_dict(
-            item: Union[ResultMetric, ResultMetricCollection]
-        ) -> Union[_ResultMetricSerializationHelper, _ResultMetricCollectionSerializationHelper]:
-            if isinstance(item, ResultMetricCollection):
-                return _ResultMetricCollectionSerializationHelper(
-                    apply_to_collection(item, ResultMetric, to_state_dict), metadata=item.meta
-                )
-            state = item.__getstate__()
-            state["meta"].sync.fn = None
-            return _ResultMetricSerializationHelper(**item.__getstate__())
+        minimize = None
+        if self.minimize is not None:
+            minimize = self.minimize.detach()

+        # all the items should be either `ResultMetric`s or `ResultMetricCollection`s
+        items = {k: v.__getstate__() for k, v in self.items()}
         return {
-            k: apply_to_collection(v, (ResultMetric, ResultMetricCollection), to_state_dict)
-            for k, v in self.items()
+            'training': self.training,
+            'device': self.device,
+            'minimize': minimize,
+            'batch_size': self.batch_size,
+            'items': items,
         }

-    def load_state_dict(self, state_dict: Dict[str, Any], sync_fn: Optional[Callable] = None) -> None:
-
-        def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection:
-            result_metric_collection = ResultMetricCollection()
-            result_metric_collection.update(item)
-
-            def _to_device(item: ResultMetric) -> ResultMetric:
-                return item.to(self.device)
-
-            result_metric_collection = apply_to_collection(result_metric_collection, ResultMetric, _to_device)
-            result_metric_collection.meta = item.meta
-            result_metric_collection.meta.sync.fn = sync_fn
-            return result_metric_collection
-
-        def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric:
-            result_metric = ResultMetric(item["meta"], item["is_tensor"])
-            result_metric.__dict__.update(item)
-            result_metric.meta.sync.fn = sync_fn
-            return result_metric.to(self.device)
-
-        state_dict = {
-            k: apply_to_collection(v, _ResultMetricCollectionSerializationHelper, to_result_metric_collection)
-            for k, v in state_dict.items()
-        }
-        result_metric_collection = {k: v.meta for k, v in state_dict.items() if isinstance(v, ResultMetricCollection)}
-        state_dict = {
-            k: apply_to_collection(v, _ResultMetricSerializationHelper, to_result_metric)
-            for k, v in state_dict.items()
-        }
-        self.update(state_dict)
-        for k, meta in result_metric_collection.items():
-            self[k].meta = meta
+    def __setstate__(self, state: dict) -> None:
+        self.training = state['training']
+        self.device = state['device']
+        self._minimize = state['minimize']
+        self._batch_size = state['batch_size']
+
+        def setstate(item: dict) -> Union[ResultMetric, ResultMetricCollection]:
+            if not isinstance(item, dict):
+                raise ValueError(f'Unexpected value: {item}')
+            cls = item['_class']
+            if cls == ResultMetric.__name__:
+                result_metric = ResultMetric(item['meta'], item['is_tensor'])
+            elif cls == ResultMetricCollection.__name__:
+                result_metric = ResultMetricCollection()
+            else:
+                raise ValueError(f"Unexpected class name: {cls}")
+            result_metric.__setstate__(item)
+            return result_metric
+
+        items = {k: setstate(v) for k, v in state['items'].items()}
+        self.update(items)
+
+    def state_dict(self) -> dict:
+        return self.__getstate__()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        self.__setstate__(state_dict)
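
Net effect of the commit: the _ResultMetricSerializationHelper / _ResultMetricCollectionSerializationHelper classes are removed, and serialization now flows through __getstate__/__setstate__ hooks that tag every entry's state with a '_class' key so the right type can be rebuilt on load; state_dict() and load_state_dict() become thin wrappers around those hooks. Below is a minimal, standalone sketch of the same tagging pattern; the Leaf and Collection classes are hypothetical stand-ins for illustration, not the Lightning classes themselves.

from typing import Any, Dict


class Leaf:
    """Hypothetical stand-in for a metric-like leaf value."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __getstate__(self) -> dict:
        # tag the state with the class name so loading can dispatch on it
        return {"value": self.value, "_class": self.__class__.__name__}

    def __setstate__(self, state: dict) -> None:
        self.value = state["value"]


class Collection(dict):
    """Hypothetical stand-in for a dict-based collection of leaves."""

    def __getstate__(self) -> dict:
        # serialize every entry through its own __getstate__ hook
        items = {k: v.__getstate__() for k, v in self.items()}
        return {"items": items, "_class": self.__class__.__name__}

    def __setstate__(self, state: dict) -> None:
        def setstate(item: Dict[str, Any]) -> Any:
            # rebuild the right type based on the '_class' tag
            if item.get("_class") == Leaf.__name__:
                leaf = Leaf(0.0)
                leaf.__setstate__(item)
                return leaf
            raise ValueError(f"Unexpected class name: {item.get('_class')}")

        self.update({k: setstate(v) for k, v in state["items"].items()})

    # thin wrappers, mirroring the shape of the commit's state_dict/load_state_dict
    def state_dict(self) -> dict:
        return self.__getstate__()

    def load_state_dict(self, state_dict: dict) -> None:
        self.__setstate__(state_dict)


if __name__ == "__main__":
    src = Collection(loss=Leaf(0.25))
    state = src.state_dict()  # plain nested dicts: picklable and checkpoint-friendly
    dst = Collection()
    dst.load_state_dict(state)
    assert isinstance(dst["loss"], Leaf) and dst["loss"].value == 0.25

Because the saved state is plain nested dicts tagged with class names, it can be written to a checkpoint and restored without the dedicated helper wrapper classes that this commit deletes.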
