diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a46427fe623c..ad24b5d0c7ad6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,6 +77,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
 
+- Fault-tolerant training
+    * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+
+
 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index cbc3dcfdefd98..f751916804ad1 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Generator
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, replace
 from functools import partial, wraps
 from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
 
@@ -28,7 +28,8 @@
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 
 # re-define the ones from pytorch_lightning.utilities.types without the `Number` type
-_METRIC = Union[Metric, torch.Tensor]
+# TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
+_METRIC = Any  # Union[Metric, torch.Tensor]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
 
 
@@ -40,11 +41,15 @@ class MetricSource(LightningEnum):
 
 @dataclass
 class _Sync:
-    fn: Callable
+    fn: Optional[Callable] = None
     should: bool = False
     op: Optional[str] = None
     group: Optional[Any] = None
 
+    def __post_init__(self) -> None:
+        if self.fn is None:
+            self.fn = self.no_op
+
     @property
     def __call__(self) -> Any:
         return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op
@@ -62,27 +67,42 @@ class _Metadata:
     logger: bool = True
     on_step: bool = False
     on_epoch: bool = True
-    reduce_fx: Union[str, Callable] = torch.mean
+    _reduce_fx: Callable = torch.mean
     enable_graph: bool = False
     dataloader_idx: Optional[int] = None
-    sync: _Sync = field(default_factory=_Sync)
+    _sync: Optional[_Sync] = None
 
-    def __post_init__(self) -> None:
+    @property
+    def reduce_fx(self) -> Callable:
+        return self._reduce_fx
+
+    @reduce_fx.setter
+    def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
         error = (
             'Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported.'
             ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`.'
-            f' Found: {self.reduce_fx}'
+            f' Found: {reduce_fx}'
         )
-        if isinstance(self.reduce_fx, str):
-            reduce_fx = self.reduce_fx.lower()
+        self._reduce_fx = reduce_fx
+        if isinstance(reduce_fx, str):
+            reduce_fx = reduce_fx.lower()
             if reduce_fx == 'avg':
                 reduce_fx = 'mean'
             if reduce_fx not in ('min', 'max', 'mean', 'sum'):
                 raise MisconfigurationException(error)
-            self.reduce_fx = getattr(torch, reduce_fx)
+            self._reduce_fx = getattr(torch, reduce_fx)
         elif self.is_custom_reduction:
             raise MisconfigurationException(error)
-        self.sync.op = self.reduce_fx.__name__
+
+    @property
+    def sync(self) -> Optional[_Sync]:
+        return self._sync
+
+    @sync.setter
+    def sync(self, sync: _Sync) -> None:
+        if sync.op is None:
+            sync.op = self.reduce_fx.__name__
+        self._sync = sync
 
     @property
     def forked(self) -> bool:
@@ -113,6 +133,25 @@ def is_min_reduction(self) -> bool:
     def is_custom_reduction(self) -> bool:
         return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction or self.is_sum_reduction)
 
+    def __getstate__(self) -> dict:
+        # drop the `sync.fn` to avoid potential pickle errors
+        # need to drop `fn` first otherwise `asdict` produces a `RecursionError`
+        copy = replace(self, _sync=replace(self.sync, fn=None))
+        d = asdict(copy)
+        # delete the `None` value so it does not override
+        del d['_sync']['fn']
+        return d
+
+    def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
+        d = {**state, '_sync': _Sync(**state['_sync'], fn=sync_fn)}
+        self.__dict__.update(d)
+
+    @classmethod
+    def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> '_Metadata':
+        meta = cls(state['fx'], state['name'])
+        meta.__setstate__(state, sync_fn=sync_fn)
+        return meta
+
 
 class ResultMetric(Metric, DeviceDtypeModuleMixin):
     """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
@@ -201,6 +240,24 @@ def __repr__(self) -> str:
             state += f", cumulated_batch_size={self.cumulated_batch_size}"
         return f"{self.__class__.__name__}({state})"
 
+    def __getstate__(self) -> dict:
+        d = super().__getstate__()
+        d['meta'] = d['meta'].__getstate__()
+        d['_class'] = self.__class__.__name__
+        return d
+
+    def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
+        d = {**state, 'meta': _Metadata._reconstruct(state['meta'], sync_fn=sync_fn)}
+        super().__setstate__(d)
+
+    @classmethod
+    def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'ResultMetric':
+        # need to reconstruct twice because `meta` is used in `__init__`
+        meta = _Metadata._reconstruct(state['meta'])
+        result_metric = cls(meta, state['is_tensor'])
+        result_metric.__setstate__(state, sync_fn=sync_fn)
+        return result_metric
+
 
 class ResultMetricCollection(dict):
     """
@@ -215,6 +272,37 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
         super().__init__(*args)
         self.meta = metadata
 
+    def __getstate__(self) -> dict:
+
+        def getstate(item: ResultMetric) -> dict:
+            return item.__getstate__()
+
+        items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate)
+        return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}
+
+    def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
+
+        def setstate(item: dict) -> Union[Dict[str, ResultMetric], ResultMetric, Any]:
+            # recurse through dictionaries to set the state. can't use `apply_to_collection`
+            # as it does not recurse items of the same type.
+            if not isinstance(item, dict):
+                return item
+            if item.get('_class') == ResultMetric.__name__:
+                return ResultMetric._reconstruct(item, sync_fn=sync_fn)
+            return {k: setstate(v) for k, v in item.items()}
+
+        items = setstate(state["items"])
+        self.update(items)
+
+        any_result_metric = next(iter(items.values()))
+        self.meta = any_result_metric.meta
+
+    @classmethod
+    def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'ResultMetricCollection':
+        rmc = cls()
+        rmc.__setstate__(state, sync_fn=sync_fn)
+        return rmc
+
 
 class ResultCollection(dict):
     """
@@ -234,7 +322,7 @@ class ResultCollection(dict):
 
     DATALOADER_SUFFIX = "/dataloader_idx_{}"
 
-    def __init__(self, training: bool, device: Optional[torch.device] = None) -> None:
+    def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
         super().__init__()
         self.training = training
         self._minimize = None
@@ -324,15 +412,16 @@ def log(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
-            reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             dataloader_idx=dataloader_idx,
-            sync=_Sync(
-                should=sync_dist,
-                fn=sync_dist_fn,
-                group=sync_dist_group,
-            )
         )
+        meta.reduce_fx = reduce_fx
+        meta.sync = _Sync(
+            should=sync_dist,
+            fn=sync_dist_fn,
+            group=sync_dist_group,
+        )
+
         if key not in self:
             self.register_key(key, meta, value)
         elif meta != self[key].meta:
@@ -397,7 +486,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str,
     def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
         metrics = {k: {} for k in MetricSource}
 
-        for key, result_metric in self.valid_items():
+        for _, result_metric in self.valid_items():
 
             # extract forward_cache or computed from the ResultMetric. ignore when the output is None
             value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
@@ -501,7 +590,40 @@ def __str__(self) -> str:
     def __getstate__(self) -> dict:
         d = self.__dict__.copy()
         # can't deepcopy tensors with grad_fn
-        minimize = d.get('_minimize')
+        minimize = d['_minimize']
         if minimize is not None:
             d['_minimize'] = minimize.detach()
-        return d
+        extra = self.get('_extra')
+        if extra is not None:
+            d['_extra'] = extra
+        # all the items should be either `ResultMetric`s or `ResultMetricCollection`s
+        items = {k: v.__getstate__() for k, v in self.items() if k != '_extra'}
+        return {**d, 'items': items}
+
+    def __setstate__(self, state: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
+        self.__dict__.update({k: v for k, v in state.items() if k != 'items'})
+
+        def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]:
+            if not isinstance(item, dict):
+                raise ValueError(f'Unexpected value: {item}')
+            cls = item['_class']
+            if cls == ResultMetric.__name__:
+                cls = ResultMetric
+            elif cls == ResultMetricCollection.__name__:
+                cls = ResultMetricCollection
+            else:
+                raise ValueError(f"Unexpected class name: {cls}")
+            sync_fn = self[k].meta.sync.fn if k in self else None
+            return cls._reconstruct(item, sync_fn=sync_fn)
+
+        items = {k: setstate(k, v) for k, v in state['items'].items()}
+        self.update(items)
+
+        device = map_location or self.device
+        self.to(device)
+
+    def state_dict(self) -> dict:
+        return self.__getstate__()
+
+    def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
+        self.__setstate__(state_dict, map_location=map_location)
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 2c4b35ad29118..6b7163c4aa643 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -11,13 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pickle
+from copy import deepcopy
+
+import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torchmetrics import Metric
 
 import tests.helpers.utils as tutils
-from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
+from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
@@ -175,3 +182,149 @@ def lightning_log(fx, *args, **kwargs):
     for k in ('d0.a', 'd1.a'):
         assert result[k].value == torch.tensor(3.) + epoch, k
         assert result[k].cumulated_batch_size == torch.tensor(1.), k
+
+
+def my_sync_dist(x):
+    return x
+
+
+def test_result_collection_restoration(tmpdir):
+    """
+    This test makes sure metrics are properly reloaded on failure.
+    """
+
+    result = ResultCollection(True, torch.device("cpu"))
+    metric_a = DummyMetric()
+    metric_b = DummyMetric()
+    metric_c = DummyMetric()
+    metric_d = DummyMetric()
+    current_fx_name = None
+    batch_idx = None
+
+    def lightning_log(fx, *args, **kwargs):
+        nonlocal current_fx_name
+        if current_fx_name != fx and batch_idx in (None, 0):
+            result.reset(metrics=False, fx=fx)
+        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
+        current_fx_name = fx
+
+    for _ in range(2):
+
+        cumulative_sum = 0
+
+        for i in range(3):
+
+            a = metric_a(i)
+            b = metric_b(i)
+            c = metric_c(i)
+            metric_d(i)
+
+            cumulative_sum += i
+
+            metric = metric_a if i < 1 else metric_d
+            lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True)
+            lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True)
+            lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False)
+            lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True)
+            lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True)
+            lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)
+
+            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
+            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
+            assert set(batch_log['c_1']) == {'1', '2'}
+
+            result_copy = deepcopy(result)
+            new_result = ResultCollection(True, torch.device("cpu"))
+            state_dict = result.state_dict()
+            # check the sync fn was dropped
+            assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync']
+            new_result.load_state_dict(state_dict)
+            # should match
+            assert result_copy == new_result
+            # the sync fn has been kept
+            assert result_copy['training_step.a'].meta.sync.fn == new_result['training_step.a'].meta.sync.fn
+
+        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
+        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
+        assert epoch_log == epoch_log_copy
+
+        lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
+        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
+        assert epoch_log == {
+            'a_1_epoch': 1,
+            'a_epoch': cumulative_sum,
+            'a': cumulative_sum,
+            'b': cumulative_sum,
+            'b_1': 1
+        }
+
+        # make sure it can be pickled
+        pickle.loads(pickle.dumps(result))
+        # make sure it can be saved and reloaded with torch
+        filepath = str(tmpdir / 'result')
+        torch.save(result, filepath)
+        torch.load(filepath)
+
+        # assert metric states are reset to their default values
+        result.reset()
+        assert metric_a.x == metric_a._defaults['x']
+        assert metric_b.x == metric_b._defaults['x']
+        assert metric_c.x == metric_c._defaults['x']
+
+        batch_idx = None
+
+
+@pytest.mark.parametrize('device', ('cpu', pytest.param('cuda', marks=RunIf(min_gpus=1))))
+def test_lightning_module_logging_result_collection(tmpdir, device):
+
+    class LoggingModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.metric = DummyMetric()
+
+        def validation_step(self, batch, batch_idx):
+            v = self.metric(batch_idx)
+            self.log_dict({"v": v, "m": self.metric})
+            return super().validation_step(batch, batch_idx)
+
+        def on_save_checkpoint(self, checkpoint) -> None:
+            results = self.trainer._results
+            state_dict = results.state_dict()
+
+            # check device
+            assert results['validation_step.v'].value.device.type == device
+            assert state_dict['items']['validation_step.v']['value'].device.type == device
+
+            # sync fn should be kept
+            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
+
+            # sync fn dropped from the state dict
+            assert 'fn' not in state_dict['items']['validation_step.v']['meta']['_sync']
+            results.load_state_dict(state_dict)
+
+            # check device after loading
+            assert results['validation_step.v'].value.device.type == device
+
+            # sync fn was preserved in the original result
+            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
+
+            # default sync fn
+            new_results = ResultCollection(False, device)
+            new_results.load_state_dict(state_dict, map_location='cpu')
+            assert new_results['validation_step.v'].meta.sync.fn == _Sync.no_op
+
+            # check map location
+            assert new_results['validation_step.v'].value.device.type == 'cpu'
+
+    model = LoggingModel()
+    ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        callbacks=[ckpt],
+        gpus=1 if device == 'cuda' else 0,
+    )
+    trainer.fit(model)