4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -77,6 +77,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


- Fault-tolerant training
* Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))


164 changes: 143 additions & 21 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, replace
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union

@@ -28,7 +28,8 @@
from pytorch_lightning.utilities.metrics import metrics_to_scalars

# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
_METRIC = Union[Metric, torch.Tensor]
# TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
_METRIC = Any # Union[Metric, torch.Tensor]
_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]


@@ -40,11 +40,15 @@ class MetricSource(LightningEnum):

@dataclass
class _Sync:
fn: Callable
fn: Optional[Callable] = None
should: bool = False
op: Optional[str] = None
group: Optional[Any] = None

def __post_init__(self) -> None:
if self.fn is None:
self.fn = self.no_op

@property
def __call__(self) -> Any:
return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op
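
Making `__call__` a property is the subtle, load-bearing part of this change: Python looks special methods up on the type and binds descriptors on access, so `sync(value)` keeps working while `fn` stays an ordinary field that can be set to `None` before pickling and re-injected afterwards. A minimal self-contained sketch of the pattern; only `no_op` and the field names mirror this diff, the reduction lambda is illustrative:

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional

@dataclass
class Sync:  # stand-in for `_Sync`
    fn: Optional[Callable] = None
    should: bool = False
    op: Optional[str] = None
    group: Optional[Any] = None

    def __post_init__(self) -> None:
        if self.fn is None:
            self.fn = self.no_op

    @property
    def __call__(self) -> Any:
        # re-built on every access, so `fn` can be swapped or dropped
        # for pickling without invalidating the instance
        return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op

    @staticmethod
    def no_op(value: Any, *_: Any, **__: Any) -> Any:
        return value

sync = Sync(fn=lambda v, reduce_op=None, group=None: (v, reduce_op), should=True, op="mean")
assert sync(3) == (3, "mean")
sync.should = False
assert sync(3) == 3  # falls back to the identity `no_op`
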
@@ -62,27 +62,42 @@ class _Metadata:
logger: bool = True
on_step: bool = False
on_epoch: bool = True
reduce_fx: Union[str, Callable] = torch.mean
_reduce_fx: Callable = torch.mean
enable_graph: bool = False
dataloader_idx: Optional[int] = None
sync: _Sync = field(default_factory=_Sync)
_sync: Optional[_Sync] = None

def __post_init__(self) -> None:
@property
def reduce_fx(self) -> Callable:
return self._reduce_fx

@reduce_fx.setter
def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
error = (
'Only `self.log(..., reduce_fx={min,max,mean,sum})` is currently supported.'
' Please open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`.'
f' Found: {self.reduce_fx}'
f' Found: {reduce_fx}'
)
if isinstance(self.reduce_fx, str):
reduce_fx = self.reduce_fx.lower()
self._reduce_fx = reduce_fx
if isinstance(reduce_fx, str):
reduce_fx = reduce_fx.lower()
if reduce_fx == 'avg':
reduce_fx = 'mean'
if reduce_fx not in ('min', 'max', 'mean', 'sum'):
raise MisconfigurationException(error)
self.reduce_fx = getattr(torch, reduce_fx)
self._reduce_fx = getattr(torch, reduce_fx)
elif self.is_custom_reduction:
raise MisconfigurationException(error)
self.sync.op = self.reduce_fx.__name__

@property
def sync(self) -> Optional[_Sync]:
return self._sync

@sync.setter
def sync(self, sync: _Sync) -> None:
if sync.op is None:
sync.op = self.reduce_fx.__name__
self._sync = sync
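
Turning `reduce_fx` and `sync` into property pairs moves the validation that used to run once in `__post_init__` onto every assignment, which is what lets `log()` further down construct the `_Metadata` first and attach `reduce_fx` and `sync` afterwards. A rough sketch of the observable behaviour; the import path is the private module from this diff, so treat this as illustration rather than a stable API:

import torch
from pytorch_lightning.trainer.connectors.logger_connector.result import _Metadata, _Sync

meta = _Metadata(fx="training_step", name="loss")
meta.reduce_fx = "avg"               # string aliases are normalized on assignment...
assert meta.reduce_fx is torch.mean  # ...to the corresponding torch reduction

meta.sync = _Sync(should=True)       # the setter derives `op` from `reduce_fx`
assert meta.sync.op == "mean"
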

@property
def forked(self) -> bool:
@@ -113,6 +113,25 @@ def is_min_reduction(self) -> bool:
def is_custom_reduction(self) -> bool:
return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)

def __getstate__(self) -> dict:
# drop the `sync.fn` to avoid potential pickle errors
# need to drop `fn` first otherwise `asdict` produces a `RecursionError`
copy = replace(self, _sync=replace(self.sync, fn=None))
d = asdict(copy)
# remove the entry so it does not clash with the `fn` keyword passed in `__setstate__`
del d['_sync']['fn']
return d

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
d = {**state, '_sync': _Sync(**state['_sync'], fn=sync_fn)}
self.__dict__.update(d)

@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> '_Metadata':
meta = cls(state['fx'], state['name'])
meta.__setstate__(state, sync_fn=sync_fn)
return meta
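
The nested `replace(...)` in `__getstate__` is what keeps pickling safe: `asdict` recurses into the `_Sync` dataclass, so `fn` has to be detached first, and the leftover entry is deleted so that `__setstate__` can pass its own `fn` into `_Sync(**state['_sync'], fn=sync_fn)` without a duplicate-keyword clash. A round-trip sketch, under the same private-import caveat as above:

from pytorch_lightning.trainer.connectors.logger_connector.result import _Metadata, _Sync

meta = _Metadata(fx="training_step", name="loss")
meta.sync = _Sync(should=True, fn=print)  # any hard-to-pickle callable

state = meta.__getstate__()
assert 'fn' not in state['_sync']         # the callable never enters the state

restored = _Metadata._reconstruct(state, sync_fn=print)
assert restored.sync.fn is print          # the caller re-supplies a live `fn`
assert restored.name == 'loss'
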


class ResultMetric(Metric, DeviceDtypeModuleMixin):
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
@@ -201,6 +240,24 @@ def __repr__(self) -> str:
state += f", cumulated_batch_size={self.cumulated_batch_size}"
return f"{self.__class__.__name__}({state})"

def __getstate__(self) -> dict:
d = super().__getstate__()
d['meta'] = d['meta'].__getstate__()
d['_class'] = self.__class__.__name__
return d

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
d = {**state, 'meta': _Metadata._reconstruct(state['meta'], sync_fn=sync_fn)}
super().__setstate__(d)

@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'ResultMetric':
# `meta` needs to be reconstructed twice: `__init__` requires it up front, and `__setstate__` rebuilds it with `sync_fn`
meta = _Metadata._reconstruct(state['meta'])
result_metric = cls(meta, state['is_tensor'])
result_metric.__setstate__(state, sync_fn=sync_fn)
return result_metric
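
`_reconstruct` builds the metadata twice on purpose: once up front because `ResultMetric.__init__` consults it to register its states, and once more inside `__setstate__` with the caller's `sync_fn` attached. A rough round-trip, assuming the torchmetrics `Metric` base provides the `__getstate__`/`__setstate__` hooks this diff relies on:

from pytorch_lightning.trainer.connectors.logger_connector.result import (
    ResultMetric, _Metadata, _Sync,
)

meta = _Metadata(fx="training_step", name="loss")
meta.sync = _Sync()                       # `__getstate__` expects `sync` to be populated
rm = ResultMetric(meta, True)             # second argument: `is_tensor`

state = rm.__getstate__()
assert state['_class'] == 'ResultMetric'  # tag used later to pick the class to rebuild

restored = ResultMetric._reconstruct(state)
assert restored.meta.name == 'loss'
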


class ResultMetricCollection(dict):
"""
@@ -215,6 +272,37 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
super().__init__(*args)
self.meta = metadata

def __getstate__(self) -> dict:

def getstate(item: ResultMetric) -> dict:
return item.__getstate__()

items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:

def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
# recurse through dictionaries to set the state. can't use `apply_to_collection`
# as it does not recurse items of the same type.
if not isinstance(item, dict):
return item
if item.get('_class') == ResultMetric.__name__:
return ResultMetric._reconstruct(item, sync_fn=sync_fn)
return {k: setstate(v) for k, v in item.items()}

items = setstate(state["items"])
self.update(items)

any_result_metric = next(iter(items.values()))
self.meta = any_result_metric.meta

@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'ResultMetricCollection':
rmc = cls()
rmc.__setstate__(state, sync_fn=sync_fn)
return rmc
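
The hand-rolled recursion in `setstate` exists because `apply_to_collection` maps over a collection's items but does not keep descending when an item is itself a `dict`, which is exactly the shape a tagged leaf state arrives in. A generic, self-contained sketch of the same pattern; the names here are placeholders, not from the diff:

from typing import Any

def rebuild(item: Any) -> Any:
    # recurse through plain dicts until a tagged leaf state is found
    if not isinstance(item, dict):
        return item
    if item.get('_class') == 'Leaf':    # stand-in for `ResultMetric.__name__`
        return ('LEAF', item['value'])  # stand-in for `ResultMetric._reconstruct(item)`
    return {k: rebuild(v) for k, v in item.items()}

nested = {'a': {'_class': 'Leaf', 'value': 1},
          'b': {'0': {'_class': 'Leaf', 'value': 2}}}  # e.g. keyed by dataloader_idx
assert rebuild(nested) == {'a': ('LEAF', 1), 'b': {'0': ('LEAF', 2)}}
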


class ResultCollection(dict):
"""
@@ -234,7 +322,7 @@ class ResultCollection(dict):

DATALOADER_SUFFIX = "/dataloader_idx_{}"

def __init__(self, training: bool, device: Optional[torch.device] = None) -> None:
def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
super().__init__()
self.training = training
self._minimize = None
@@ -324,15 +412,16 @@ def log(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=dataloader_idx,
sync=_Sync(
should=sync_dist,
fn=sync_dist_fn,
group=sync_dist_group,
)
)
meta.reduce_fx = reduce_fx
meta.sync = _Sync(
should=sync_dist,
fn=sync_dist_fn,
group=sync_dist_group,
)

if key not in self:
self.register_key(key, meta, value)
elif meta != self[key].meta:
@@ -397,7 +486,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str,
def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
metrics = {k: {} for k in MetricSource}

for key, result_metric in self.valid_items():
for _, result_metric in self.valid_items():

# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
@@ -501,7 +590,40 @@ def __str__(self) -> str:
def __getstate__(self) -> dict:
d = self.__dict__.copy()
# can't deepcopy tensors with grad_fn
minimize = d.get('_minimize')
minimize = d['_minimize']
if minimize is not None:
d['_minimize'] = minimize.detach()
return d
extra = self.get('_extra')
if extra is not None:
d['_extra'] = extra
# all the items should be either `ResultMetric`s or `ResultMetricCollection`s
items = {k: v.__getstate__() for k, v in self.items() if k != '_extra'}
return {**d, 'items': items}

def __setstate__(self, state: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
self.__dict__.update({k: v for k, v in state.items() if k != 'items'})

def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]:
if not isinstance(item, dict):
raise ValueError(f'Unexpected value: {item}')
cls = item['_class']
if cls == ResultMetric.__name__:
cls = ResultMetric
elif cls == ResultMetricCollection.__name__:
cls = ResultMetricCollection
else:
raise ValueError(f"Unexpected class name: {cls}")
sync_fn = self[k].meta.sync.fn if k in self else None
return cls._reconstruct(item, sync_fn=sync_fn)

items = {k: setstate(k, v) for k, v in state['items'].items()}
self.update(items)

device = map_location or self.device
self.to(device)

def state_dict(self) -> dict:
return self.__getstate__()

def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
self.__setstate__(state_dict, map_location=map_location)
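
With `state_dict` and `load_state_dict` as thin wrappers over the dunder pair, the fault-tolerant path gets a familiar torch-style interface. A sketch of how a caller might persist and restore logged results; the `log()` arguments are illustrative and the module path is private, so this mirrors the diff rather than documenting a public API:

import torch
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection

results = ResultCollection(training=True, device="cpu")
results.log("training_step", "loss", torch.tensor(0.25), on_step=True, on_epoch=True)

torch.save(results.state_dict(), "results.pt")

# e.g. after a failure, in a fresh process
restored = ResultCollection(training=True)
restored.load_state_dict(torch.load("results.pt"), map_location="cpu")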