From 4cb7e89cbb7a1ece025d732bfad547d457200d84 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Jun 2021 20:59:18 +0100 Subject: [PATCH 01/90] add metric reload --- CHANGELOG.md | 1 - pytorch_lightning/core/lightning.py | 26 ++++++ .../connectors/logger_connector/result.py | 83 +++++++++++++++++- tests/core/test_metric_result_integration.py | 87 ++++++++++++++++++- 4 files changed, 192 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e63484ca8612..dba288cec1c70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) - - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 02633d3df16fa..b5978f6d8a6b1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,6 +111,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -271,6 +272,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, ) -> None: """ Log a key, value @@ -308,6 +310,8 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. + metric_attribute: The attribute name for the metric in the LightningModule. + Necessary to save/restore its state. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -360,6 +364,27 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name + for name, module in self.named_children() if isinstance(module, Metric) + } + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(id(value)) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." 
+ f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" + f" of {list(self._metric_attributes.values())}" + ) + results.log( self._current_fx_name, name, @@ -375,6 +400,7 @@ def log( sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, + metric_attribute=metric_attribute, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index cbc3dcfdefd98..589ff651e979a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -65,6 +65,7 @@ class _Metadata: reduce_fx: Union[str, Callable] = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None + metric_attribute: Optional[str] = None sync: _Sync = field(default_factory=_Sync) def __post_init__(self) -> None: @@ -202,6 +203,24 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" +class _ResultMetricSerializationHelper(dict): + """ + Since ``ResultCollection`` can hold ``ResultMetric`` values or dictionaries of them, we need + a class to differentiate between the cases after converting to state dict when saving its state. + """ + + +class _ResultMetricCollectionSerializationHelper(dict): + """ + Since ``ResultCollection`` can hold ``ResultMetricCollection`` values or dictionaries of them, we need + a class to differentiate between the cases after converting to state dict when saving its state. + """ + + def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: + super().__init__(*args) + self.meta = metadata + + class ResultMetricCollection(dict): """ Dict wrapper for easy access to metadata. @@ -300,6 +319,7 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, + metric_attribute: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -327,6 +347,7 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, + metric_attribute=metric_attribute, sync=_Sync( should=sync_dist, fn=sync_dist_fn, @@ -398,9 +419,12 @@ def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} for key, result_metric in self.valid_items(): + print(key, metrics) # extract forward_cache or computed from the ResultMetric. 
ignore when the output is None - value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) + value = apply_to_collection( + result_metric, ResultMetric, self._get_cache, on_step, include_none=False, wrong_dtype=ResultCollection + ) # check if the collection is empty has_tensor = False @@ -505,3 +529,60 @@ def __getstate__(self) -> dict: if minimize is not None: d['_minimize'] = minimize.detach() return d + + def state_dict(self): + + def to_state_dict( + item: Union[ResultMetric, ResultMetricCollection] + ) -> Union[_ResultMetricSerializationHelper, _ResultMetricCollectionSerializationHelper]: + if isinstance(item, ResultMetricCollection): + return _ResultMetricCollectionSerializationHelper( + apply_to_collection(item, ResultMetric, to_state_dict), metadata=item.meta + ) + return _ResultMetricSerializationHelper(**item.__getstate__()) + + return { + k: apply_to_collection(v, (ResultMetric, ResultMetricCollection), to_state_dict) + for k, v in self.items() + } + + def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> None: + + def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection: + result_metric_collection = ResultMetricCollection() + result_metric_collection.update(item) + + def _to_device(item: ResultMetric) -> ResultMetric: + return item.to(self.device) + + result_metric_collection = apply_to_collection(result_metric_collection, ResultMetric, _to_device) + result_metric_collection.meta = item.meta + return result_metric_collection + + def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric: + result_metric = ResultMetric(item["meta"], item["is_tensor"]) + result_metric.__dict__.update(item) + return result_metric.to(self.device) + + state_dict = { + k: apply_to_collection(v, _ResultMetricCollectionSerializationHelper, to_result_metric_collection) + for k, v in state_dict.items() + } + result_metric_collection = {k: v.meta for k, v in state_dict.items() if isinstance(v, ResultMetricCollection)} + state_dict = { + k: apply_to_collection(v, _ResultMetricSerializationHelper, to_result_metric) + for k, v in state_dict.items() + } + self.update(state_dict) + for k, meta in result_metric_collection.items(): + self[k].meta = meta + + if metrics: + + def re_assign_metric(item: ResultMetric) -> None: + # metric references are lost during serialization and need to be set back during loading + name = item.meta.metric_attribute + if isinstance(name, str) and name in metrics: + item.value = metrics[name] + + apply_to_collection(self, ResultMetric, re_assign_metric) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 2c4b35ad29118..abc3187de1f62 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -66,9 +68,9 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True) - result.log('h', 'b', metric_b, on_step=False, on_epoch=True) - result.log('h', 'c', metric_c, on_step=True, on_epoch=False) + result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") batch_log = result.metrics(True)[MetricSource.LOG] assert batch_log == {"a_step": i, "c": i} @@ -175,3 +177,82 @@ def lightning_log(fx, *args, **kwargs): for k in ('d0.a', 'd1.a'): assert result[k].value == torch.tensor(3.) + epoch, k assert result[k].cumulated_batch_size == torch.tensor(1.), k + + +def test_result_collection_restoration(): + result = ResultCollection(True, torch.device("cpu")) + _result = None + metric_a = DummyMetric() + metric_b = DummyMetric() + metric_c = DummyMetric() + current_fx_name = None + batch_idx = None + + def lightning_log(fx, *args, **kwargs): + nonlocal current_fx_name + if current_fx_name != fx and batch_idx in (None, 0): + result.reset(metrics=False, fx=fx) + result.log(fx, *args, **kwargs) + current_fx_name = fx + + for _ in range(2): + + cumulative_sum = 0 + + for i in range(3): + + a = metric_a(i) + b = metric_b(i) + c = metric_c(i) + + cumulative_sum += i + + lightning_log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") + lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) + lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) + lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) + + batch_log = result.metrics(on_step=True)[MetricSource.LOG] + assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"} + assert set(batch_log['c_1']) == {'1', '2'} + + _result = deepcopy(result) + state_dict = result.state_dict() + + result = ResultCollection(True, torch.device("cpu")) + result.load_from_state_dict( + state_dict, { + "metric_a": metric_a, + "metric_b": metric_b, + "metric_c": metric_c, + "metric_a_end": metric_a + } + ) + + assert _result.items() == result.items() + assert _result["training_step.c_1"].meta == result["training_step.c_1"].meta + + batch_idx = None + + epoch_log = result.metrics(on_step=False)[MetricSource.LOG] + _epoch_log = result.metrics(on_step=False)[MetricSource.LOG] + assert epoch_log == _epoch_log + + assert set(epoch_log) == {'a_1_epoch', 'a_epoch', 'b', 'b_1'} + for k in epoch_log: + if k in {'a_epoch', 'b'}: + assert epoch_log[k] == cumulative_sum + else: + assert epoch_log[k] == 1 + + lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end") + + result.reset() + _result.reset() + + # assert metric state reset to default values + assert metric_a.x == metric_a._defaults['x'] + assert metric_b.x == metric_b._defaults['x'] + assert metric_c.x == metric_c._defaults['x'] From 4176447329e9ce02b485456d0a5cc19c40800809 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Jun 2021 21:17:23 +0100 Subject: [PATCH 02/90] add tests 
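This patch makes the `metric_attribute` lookup in `LightningModule.log` walk nested submodules rather than only `named_children()`, so a `torchmetrics.Metric` registered on a child module (e.g. `self.custom.dummy_metric`) is still resolved to its dotted attribute name. Below is a minimal sketch of that lookup, using the public `named_modules()` API as a stand-in for the `_named_members` call in the diff; `DummyMetric` and the `custom.dummy_metric` layout are illustrative and mirror the test added further down.

import torch
from torchmetrics import Metric


class DummyMetric(Metric):
    """Bare-bones metric, matching the helper used in the tests of this series."""

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


class Wrapper(torch.nn.Module):
    """Child module holding the metric, so the metric is not a direct attribute of the parent."""

    def __init__(self):
        super().__init__()
        self.dummy_metric = DummyMetric()


parent = torch.nn.Module()
parent.custom = Wrapper()

# Roughly what the recursive lookup in the diff computes: map each Metric instance to the
# dotted attribute path under which it is registered.
metric_attributes = {id(m): name for name, m in parent.named_modules() if isinstance(m, Metric)}
assert metric_attributes[id(parent.custom.dummy_metric)] == "custom.dummy_metric"
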
--- pytorch_lightning/core/lightning.py | 3 +- .../connectors/logger_connector/result.py | 2 +- tests/core/test_metric_result_integration.py | 40 +++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b5978f6d8a6b1..5ec6fefd51a21 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -369,7 +369,8 @@ def log( # compute once self._metric_attributes = { id(module): name - for name, module in self.named_children() if isinstance(module, Metric) + for name, module in self._named_members(lambda module: module._modules.items(), recurse=True) + if isinstance(module, Metric) } if not self._metric_attributes: raise MisconfigurationException( diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 589ff651e979a..b33d09fa06282 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -212,7 +212,7 @@ class _ResultMetricSerializationHelper(dict): class _ResultMetricCollectionSerializationHelper(dict): """ - Since ``ResultCollection`` can hold ``ResultMetricCollection`` values or dictionaries of them, we need + Since several ``ResultCollection`` can hold inside a ``ResultMetricCollection``, we need a class to differentiate between the cases after converting to state dict when saving its state. """ diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index abc3187de1f62..730fd4a4a5b00 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy +from operator import attrgetter import torch import torch.distributed as dist @@ -19,7 +20,9 @@ from torchmetrics import Metric import tests.helpers.utils as tutils +from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection +from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -180,6 +183,10 @@ def lightning_log(fx, *args, **kwargs): def test_result_collection_restoration(): + """" + This test make sure metrics are properly reloaded on failure. 
+ """ + result = ResultCollection(True, torch.device("cpu")) _result = None metric_a = DummyMetric() @@ -256,3 +263,36 @@ def lightning_log(fx, *args, **kwargs): assert metric_a.x == metric_a._defaults['x'] assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] + + +def test_result_collection_attribute_name_nested(tmpdir): + """ + This test make sure metric_attribute is properly capture even when nested in children modules + """ + metric = DummyMetric() + + class CustomModule(torch.nn.Module): + + def __init__(self, metric): + super().__init__() + + self.dummy_metric = metric + + class TestModel(BoringModel): + + def __init__(self, metric): + super().__init__() + + self.custom = CustomModule(metric) + + def training_step(self, batch, batch_idx): + self.custom.dummy_metric(1) + self.log("dummy", self.custom.dummy_metric) + return super().training_step(batch, batch_idx) + + model = TestModel(metric) + trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, max_epochs=1) + trainer.fit(model) + metric_attribute = trainer.train_loop.results['training_step.dummy'].meta.metric_attribute + assert metric_attribute == 'custom.dummy_metric' + assert id(attrgetter(metric_attribute)(model)) == id(metric) From 95946533d9530cc7ec5bd38489b1ddfd84558305 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Jun 2021 21:19:41 +0100 Subject: [PATCH 03/90] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dba288cec1c70..c0600668a2402 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,6 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Each of the training loops now keeps its own results collection ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) + * Add `load_from_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From 0fa64ed5ff15264933724dcf5aa77cb8bb9f72c3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Jun 2021 21:21:26 +0100 Subject: [PATCH 04/90] udpate --- tests/trainer/logging_/test_logger_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d93054439082b..91779a740152f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc) - self.log(f"{stage}/ap", ap) + self.log(f"{stage}/accuracy", acc, metric_attribute="dummy") + self.log(f"{stage}/ap", ap, metric_attribute="dummy") return loss From 9828e72d8136fa53553774300afecb617dfc92cc Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 11 Jun 2021 21:21:48 +0100 Subject: [PATCH 05/90] remove print --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py 
b/pytorch_lightning/trainer/connectors/logger_connector/result.py index b33d09fa06282..faf35625c8a53 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -419,7 +419,6 @@ def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} for key, result_metric in self.valid_items(): - print(key, metrics) # extract forward_cache or computed from the ResultMetric. ignore when the output is None value = apply_to_collection( From f85d590e9d2a503a418efe9f4309aed9e2860407 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 08:05:28 +0100 Subject: [PATCH 06/90] remove attribute_name --- pytorch_lightning/core/lightning.py | 27 ------------ .../connectors/logger_connector/result.py | 13 ------ tests/core/test_metric_result_integration.py | 44 ++----------------- .../trainer/logging_/test_logger_connector.py | 4 +- 4 files changed, 6 insertions(+), 82 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5ec6fefd51a21..02633d3df16fa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,7 +111,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -272,7 +271,6 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, ) -> None: """ Log a key, value @@ -310,8 +308,6 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. - metric_attribute: The attribute name for the metric in the LightningModule. - Necessary to save/restore its state. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -364,28 +360,6 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) - if metric_attribute is None and isinstance(value, Metric): - if self._metric_attributes is None: - # compute once - self._metric_attributes = { - id(module): name - for name, module in self._named_members(lambda module: module._modules.items(), recurse=True) - if isinstance(module, Metric) - } - if not self._metric_attributes: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." - " You can fix this by setting an attribute for the metric in your `LightningModule`." - ) - # try to find the passed metric in the LightningModule - metric_attribute = self._metric_attributes.get(id(value)) - if metric_attribute is None: - raise MisconfigurationException( - "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." 
- f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" - f" of {list(self._metric_attributes.values())}" - ) - results.log( self._current_fx_name, name, @@ -401,7 +375,6 @@ def log( sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, - metric_attribute=metric_attribute, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index faf35625c8a53..3a268c4fdada6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -65,7 +65,6 @@ class _Metadata: reduce_fx: Union[str, Callable] = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None sync: _Sync = field(default_factory=_Sync) def __post_init__(self) -> None: @@ -319,7 +318,6 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -347,7 +345,6 @@ def log( reduce_fx=reduce_fx, enable_graph=enable_graph, dataloader_idx=dataloader_idx, - metric_attribute=metric_attribute, sync=_Sync( should=sync_dist, fn=sync_dist_fn, @@ -575,13 +572,3 @@ def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric: self.update(state_dict) for k, meta in result_metric_collection.items(): self[k].meta = meta - - if metrics: - - def re_assign_metric(item: ResultMetric) -> None: - # metric references are lost during serialization and need to be set back during loading - name = item.meta.metric_attribute - if isinstance(name, str) and name in metrics: - item.value = metrics[name] - - apply_to_collection(self, ResultMetric, re_assign_metric) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 730fd4a4a5b00..4b54e8677d320 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from copy import deepcopy -from operator import attrgetter import torch import torch.distributed as dist @@ -20,9 +19,7 @@ from torchmetrics import Metric import tests.helpers.utils as tutils -from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection -from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -214,9 +211,9 @@ def lightning_log(fx, *args, **kwargs): cumulative_sum += i - lightning_log('training_step', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") - lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") + lightning_log('training_step', 'a', metric_a, on_step=True, on_epoch=True) + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True) + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False) lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) @@ -254,7 +251,7 @@ def lightning_log(fx, *args, **kwargs): else: assert epoch_log[k] == 1 - lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True, metric_attribute="metric_a_end") + lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True) result.reset() _result.reset() @@ -263,36 +260,3 @@ def lightning_log(fx, *args, **kwargs): assert metric_a.x == metric_a._defaults['x'] assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] - - -def test_result_collection_attribute_name_nested(tmpdir): - """ - This test make sure metric_attribute is properly capture even when nested in children modules - """ - metric = DummyMetric() - - class CustomModule(torch.nn.Module): - - def __init__(self, metric): - super().__init__() - - self.dummy_metric = metric - - class TestModel(BoringModel): - - def __init__(self, metric): - super().__init__() - - self.custom = CustomModule(metric) - - def training_step(self, batch, batch_idx): - self.custom.dummy_metric(1) - self.log("dummy", self.custom.dummy_metric) - return super().training_step(batch, batch_idx) - - model = TestModel(metric) - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, max_epochs=1) - trainer.fit(model) - metric_attribute = trainer.train_loop.results['training_step.dummy'].meta.metric_attribute - assert metric_attribute == 'custom.dummy_metric' - assert id(attrgetter(metric_attribute)(model)) == id(metric) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 91779a740152f..d93054439082b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_attribute="dummy") - self.log(f"{stage}/ap", ap, metric_attribute="dummy") + self.log(f"{stage}/accuracy", acc) + self.log(f"{stage}/ap", ap) return loss From 31d390d107980543a03ca9735621a01643fee243 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 08:21:34 +0100 Subject: [PATCH 07/90] update --- .../connectors/logger_connector/result.py | 4 +-- tests/core/test_metric_result_integration.py | 27 
+++++++++---------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3a268c4fdada6..a74ca48d32e98 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -415,7 +415,7 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]: metrics = {k: {} for k in MetricSource} - for key, result_metric in self.valid_items(): + for _, result_metric in self.valid_items(): # extract forward_cache or computed from the ResultMetric. ignore when the output is None value = apply_to_collection( @@ -542,7 +542,7 @@ def to_state_dict( for k, v in self.items() } - def load_from_state_dict(self, state_dict: Dict[str, Any], metrics: Optional[Dict[str, Metric]] = None) -> None: + def load_from_state_dict(self, state_dict: Dict[str, Any]) -> None: def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection: result_metric_collection = ResultMetricCollection() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 4b54e8677d320..4442ac58bd4b2 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -25,9 +25,10 @@ class DummyMetric(Metric): - def __init__(self): + def __init__(self, name: str = None): super().__init__() self.add_state("x", torch.tensor(0), dist_reduce_fx="sum") + self.name = name def update(self, x): self.x += x @@ -35,6 +36,9 @@ def update(self, x): def compute(self): return self.x + def extra_repr(self) -> str: + return str(self.name) if self.name else '' + def _setup_ddp(rank, worldsize): import os @@ -186,9 +190,10 @@ def test_result_collection_restoration(): result = ResultCollection(True, torch.device("cpu")) _result = None - metric_a = DummyMetric() - metric_b = DummyMetric() - metric_c = DummyMetric() + metric_a = DummyMetric('a') + metric_b = DummyMetric('b') + metric_c = DummyMetric('c') + metric_d = DummyMetric('d') current_fx_name = None batch_idx = None @@ -208,10 +213,12 @@ def lightning_log(fx, *args, **kwargs): a = metric_a(i) b = metric_b(i) c = metric_c(i) + metric_d(i) cumulative_sum += i - lightning_log('training_step', 'a', metric_a, on_step=True, on_epoch=True) + metric = metric_a if i < 1 else metric_d + lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True) lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True) lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False) lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) @@ -221,19 +228,11 @@ def lightning_log(fx, *args, **kwargs): batch_log = result.metrics(on_step=True)[MetricSource.LOG] assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"} assert set(batch_log['c_1']) == {'1', '2'} - _result = deepcopy(result) state_dict = result.state_dict() result = ResultCollection(True, torch.device("cpu")) - result.load_from_state_dict( - state_dict, { - "metric_a": metric_a, - "metric_b": metric_b, - "metric_c": metric_c, - "metric_a_end": metric_a - } - ) + result.load_from_state_dict(state_dict) assert _result.items() == result.items() assert _result["training_step.c_1"].meta == result["training_step.c_1"].meta From e7644dea2d9d7a1e2156a119849ef40aff19b3d4 Mon Sep 17 
00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 09:14:08 +0100 Subject: [PATCH 08/90] update --- .../connectors/logger_connector/result.py | 7 +++++ tests/core/test_metric_result_integration.py | 31 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a74ca48d32e98..bb362b0c84bf0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -351,6 +351,11 @@ def log( group=sync_dist_group, ) ) + + # the reduce function was drop while saving a checkpoint. + if key in self and self[key].meta.sync.fn is None: + self[key].meta.sync.fn = meta.sync.fn + if key not in self: self.register_key(key, meta, value) elif meta != self[key].meta: @@ -535,6 +540,8 @@ def to_state_dict( return _ResultMetricCollectionSerializationHelper( apply_to_collection(item, ResultMetric, to_state_dict), metadata=item.meta ) + state = item.__getstate__() + state["meta"].sync.fn = None return _ResultMetricSerializationHelper(**item.__getstate__()) return { diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 4442ac58bd4b2..5d97ec291415f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -19,7 +19,10 @@ from torchmetrics import Metric import tests.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection +from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -259,3 +262,31 @@ def lightning_log(fx, *args, **kwargs): assert metric_a.x == metric_a._defaults['x'] assert metric_b.x == metric_b._defaults['x'] assert metric_c.x == metric_c._defaults['x'] + + +def test_lightning_module_logging_result_collection(tmpdir): + + class LoggingModel(BoringModel): + + def __init__(self): + super().__init__() + self.metric = DummyMetric() + + def training_step(self, batch, batch_idx): + v = self.metric(batch_idx) + self.log_dict({"v": v, "m": self.metric}) + return super().training_step(batch, batch_idx) + + def on_save_checkpoint(self, checkpoint) -> None: + state_dict = self.trainer.train_loop.results.state_dict() + checkpoint["result_collections"] = state_dict + self.trainer.train_loop.results.load_from_state_dict(state_dict) + assert self.trainer.train_loop.results['training_step.v'].meta.sync.fn is None + return super().on_save_checkpoint(checkpoint) + + model = LoggingModel() + ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=3, limit_train_batches=2, limit_val_batches=2, callbacks=[ckpt] + ) + trainer.fit(model) From 3a1019e18c60d60082891a08a3f35a1163051e86 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 08:40:13 +0100 Subject: [PATCH 09/90] updat --- .../connectors/checkpoint_connector.py | 24 +++- tests/checkpointing/test_model_checkpoint.py | 106 +++++++++++------- tmp.p | Bin 0 -> 241 bytes 3 files changed, 83 insertions(+), 47 deletions(-) create mode 100644 tmp.p diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6711ef3cb748e..b5b9b2df9faa1 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ 
b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,6 +21,7 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -295,11 +296,16 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'epoch': training epoch 'global_step': training global step 'pytorch-lightning_version': PyTorch Lightning's version - 'callbacks': "callback specific state"[] # if not weights_only - 'optimizer_states': "PT optim's state_dict"[] # if not weights_only - 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'result_collections': { + "train": PT TrainLoop ResultCollection state_dict + "validation": PT ValidationLoop ResultCollection state_dict + "test": PT TestLoop ResultCollection state_dict + } + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp 'state_dict': Model's state_dict (e.g. network weights) CHECKPOINT_HYPER_PARAMS_NAME: CHECKPOINT_HYPER_PARAMS_KEY: @@ -325,6 +331,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), + # "result_collections": self.get_result_collections_state_dict() } if not weights_only: @@ -365,6 +372,13 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def get_result_collections_state_dict(self) -> Dict[str, Dict[str, Any]]: + return { + RunningStage.TRAINING.value: self.trainer.train_loop.results.state_dict(), + RunningStage.VALIDATING.value: self.trainer.evaluation_loop._val_results.state_dict(), + RunningStage.TESTING.value: self.trainer.evaluation_loop._test_results.state_dict(), + } + def hpc_load(self, checkpoint_path: str, on_gpu: bool): """ Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. 
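Note on the helper added above: `get_result_collections_state_dict` only gathers each loop's `ResultCollection.state_dict()` per running stage, and the reverse direction is the `load_from_state_dict` call introduced earlier in this series. A self-contained sketch of that round trip follows; the hook name, key and logged value are illustrative, only the `ResultCollection` calls come from this series.

import torch

from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection

result = ResultCollection(True, torch.device("cpu"))
result.log('training_step', 'tracking', torch.tensor(0.25), on_step=True, on_epoch=True)

state = result.state_dict()                      # what the helper above stores for each stage
restored = ResultCollection(True, torch.device("cpu"))
restored.load_from_state_dict(state)             # how a resume path could rebuild the collection

assert 'training_step.tracking' in restored      # keys are '<hook_name>.<logged_name>'
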
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 62b9d8364b01c..d34ae13f80518 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -55,6 +55,61 @@ def validation_epoch_end(self, outputs): self.log('val_acc', outs) +class CustomBoringModelScoreAndCkpt(BoringModel): + + def __init__( + self, + max_epochs: int, + limit_train_batches: int, + limit_val_batches: int, + reduce_lr_on_plateau: bool, + monitor: str, + lr: float = 1e-1, + gamma: int = 2, + ): + super().__init__() + self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) + self.val_logs = torch.randn(max_epochs, limit_val_batches) + self.scores = [] + self.reduce_lr_on_plateau = reduce_lr_on_plateau + self.monitor = monitor + self.lr = lr + self.gamma = gamma + + def training_step(self, batch, batch_idx): + log_value = self.train_log_epochs[self.current_epoch, batch_idx] + self.log('train_log', log_value, on_epoch=True) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.current_epoch, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=self.lr) + + if self.reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': self.monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma) + + return [optimizer], [lr_scheduler] + + def on_train_epoch_end(self): + if 'train' in self.monitor: + self.scores.append(self.trainer.logged_metrics[self.monitor]) + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and 'val' in self.monitor: + self.scores.append(self.trainer.logged_metrics[self.monitor]) + + @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize( "validation_step_none,val_dataloaders_none,monitor", @@ -77,52 +132,19 @@ def test_model_checkpoint_score_and_ckpt( limit_val_batches = 7 lr, gamma = 1e-1, 2 - class CustomBoringModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) - self.val_logs = torch.randn(max_epochs, limit_val_batches) - self.scores = [] - - def training_step(self, batch, batch_idx): - log_value = self.train_log_epochs[self.current_epoch, batch_idx] - self.log('train_log', log_value, on_epoch=True) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - log_value = self.val_logs[self.current_epoch, batch_idx] - self.log('val_log', log_value) - self.log('epoch', self.current_epoch, on_epoch=True) - return super().validation_step(batch, batch_idx) - - def configure_optimizers(self): - optimizer = optim.SGD(self.parameters(), lr=lr) - - if reduce_lr_on_plateau: - lr_scheduler = { - 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), - 'monitor': monitor, - 'strict': True, - } - else: - lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) - - return [optimizer], [lr_scheduler] - - def on_train_epoch_end(self): - if 'train' in monitor: - self.scores.append(self.trainer.logged_metrics[monitor]) - - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking and 'val' in monitor: - 
self.scores.append(self.trainer.logged_metrics[monitor]) + model = CustomBoringModelScoreAndCkpt( + max_epochs=max_epochs, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + reduce_lr_on_plateau=reduce_lr_on_plateau, + monitor=monitor, + lr=lr, + gamma=gamma, + ) filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) - model = CustomBoringModel() - if validation_step_none: model.validation_step = None if val_dataloaders_none: diff --git a/tmp.p b/tmp.p new file mode 100644 index 0000000000000000000000000000000000000000..29086397fdb56e3b236de68ebcf742d2ebda228b GIT binary patch literal 241 zcmWIWW@cev;NW1u0Q?NX42ea_8JT6N`ems_#hLkeZch9RQK-O}E5Mtb Date: Mon, 14 Jun 2021 09:17:02 +0100 Subject: [PATCH 10/90] update --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b5b9b2df9faa1..64bcdf0226298 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -331,7 +331,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), - # "result_collections": self.get_result_collections_state_dict() + 'result_collections': self.get_result_collections_state_dict() } if not weights_only: From 659a25a3308c2286015c7ddd2de1fde6cf177cb3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 09:17:43 +0100 Subject: [PATCH 11/90] resolve test --- tests/core/test_metric_result_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 5d97ec291415f..39b70471330e5 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -75,9 +75,9 @@ def _ddp_test_fn(rank, worldsize): cumulative_sum += i - result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a") - result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") - result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") + result.log('h', 'a', metric_a, on_step=True, on_epoch=True) + result.log('h', 'b', metric_b, on_step=False, on_epoch=True) + result.log('h', 'c', metric_c, on_step=True, on_epoch=False) batch_log = result.metrics(True)[MetricSource.LOG] assert batch_log == {"a_step": i, "c": i} From 8ab34ce0fb2a5f927540f4016255fa3b1deae7d3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 08:40:13 +0100 Subject: [PATCH 12/90] updat --- .../connectors/checkpoint_connector.py | 24 +++- tests/checkpointing/test_model_checkpoint.py | 106 +++++++++++------- tmp.p | Bin 0 -> 241 bytes 3 files changed, 83 insertions(+), 47 deletions(-) create mode 100644 tmp.p diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6711ef3cb748e..b5b9b2df9faa1 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,6 +21,7 @@ import pytorch_lightning from 
pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -295,11 +296,16 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'epoch': training epoch 'global_step': training global step 'pytorch-lightning_version': PyTorch Lightning's version - 'callbacks': "callback specific state"[] # if not weights_only - 'optimizer_states': "PT optim's state_dict"[] # if not weights_only - 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'result_collections': { + "train": PT TrainLoop ResultCollection state_dict + "validation": PT ValidationLoop ResultCollection state_dict + "test": PT TestLoop ResultCollection state_dict + } + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp 'state_dict': Model's state_dict (e.g. network weights) CHECKPOINT_HYPER_PARAMS_NAME: CHECKPOINT_HYPER_PARAMS_KEY: @@ -325,6 +331,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), + # "result_collections": self.get_result_collections_state_dict() } if not weights_only: @@ -365,6 +372,13 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def get_result_collections_state_dict(self) -> Dict[str, Dict[str, Any]]: + return { + RunningStage.TRAINING.value: self.trainer.train_loop.results.state_dict(), + RunningStage.VALIDATING.value: self.trainer.evaluation_loop._val_results.state_dict(), + RunningStage.TESTING.value: self.trainer.evaluation_loop._test_results.state_dict(), + } + def hpc_load(self, checkpoint_path: str, on_gpu: bool): """ Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 62b9d8364b01c..d34ae13f80518 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -55,6 +55,61 @@ def validation_epoch_end(self, outputs): self.log('val_acc', outs) +class CustomBoringModelScoreAndCkpt(BoringModel): + + def __init__( + self, + max_epochs: int, + limit_train_batches: int, + limit_val_batches: int, + reduce_lr_on_plateau: bool, + monitor: str, + lr: float = 1e-1, + gamma: int = 2, + ): + super().__init__() + self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) + self.val_logs = torch.randn(max_epochs, limit_val_batches) + self.scores = [] + self.reduce_lr_on_plateau = reduce_lr_on_plateau + self.monitor = monitor + self.lr = lr + self.gamma = gamma + + def training_step(self, batch, batch_idx): + log_value = self.train_log_epochs[self.current_epoch, batch_idx] + self.log('train_log', log_value, on_epoch=True) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.current_epoch, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=self.lr) + + if self.reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': self.monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma) + + return [optimizer], [lr_scheduler] + + def on_train_epoch_end(self): + if 'train' in self.monitor: + self.scores.append(self.trainer.logged_metrics[self.monitor]) + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and 'val' in self.monitor: + self.scores.append(self.trainer.logged_metrics[self.monitor]) + + @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize( "validation_step_none,val_dataloaders_none,monitor", @@ -77,52 +132,19 @@ def test_model_checkpoint_score_and_ckpt( limit_val_batches = 7 lr, gamma = 1e-1, 2 - class CustomBoringModel(BoringModel): - - def __init__(self): - super().__init__() - self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) - self.val_logs = torch.randn(max_epochs, limit_val_batches) - self.scores = [] - - def training_step(self, batch, batch_idx): - log_value = self.train_log_epochs[self.current_epoch, batch_idx] - self.log('train_log', log_value, on_epoch=True) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - log_value = self.val_logs[self.current_epoch, batch_idx] - self.log('val_log', log_value) - self.log('epoch', self.current_epoch, on_epoch=True) - return super().validation_step(batch, batch_idx) - - def configure_optimizers(self): - optimizer = optim.SGD(self.parameters(), lr=lr) - - if reduce_lr_on_plateau: - lr_scheduler = { - 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), - 'monitor': monitor, - 'strict': True, - } - else: - lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) - - return [optimizer], [lr_scheduler] - - def on_train_epoch_end(self): - if 'train' in monitor: - self.scores.append(self.trainer.logged_metrics[monitor]) - - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking and 'val' in monitor: - 
self.scores.append(self.trainer.logged_metrics[monitor]) + model = CustomBoringModelScoreAndCkpt( + max_epochs=max_epochs, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + reduce_lr_on_plateau=reduce_lr_on_plateau, + monitor=monitor, + lr=lr, + gamma=gamma, + ) filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) - model = CustomBoringModel() - if validation_step_none: model.validation_step = None if val_dataloaders_none: diff --git a/tmp.p b/tmp.p new file mode 100644 index 0000000000000000000000000000000000000000..29086397fdb56e3b236de68ebcf742d2ebda228b GIT binary patch literal 241 zcmWIWW@cev;NW1u0Q?NX42ea_8JT6N`ems_#hLkeZch9RQK-O}E5Mtb Date: Mon, 14 Jun 2021 09:17:02 +0100 Subject: [PATCH 13/90] update --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b5b9b2df9faa1..64bcdf0226298 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -331,7 +331,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), - # "result_collections": self.get_result_collections_state_dict() + 'result_collections': self.get_result_collections_state_dict() } if not weights_only: From b774c3424cd1759689b62791882476fe552e5575 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 10:34:05 +0100 Subject: [PATCH 14/90] update --- CHANGELOG.md | 2 +- .../connectors/checkpoint_connector.py | 23 ++++++++ .../connectors/logger_connector/result.py | 3 +- tests/checkpointing/test_model_checkpoint.py | 54 +++++++++++++++++++ tests/core/test_metric_result_integration.py | 4 +- 5 files changed, 82 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0600668a2402..5d552bae14867 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Each of the training loops now keeps its own results collection ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) - * Add `load_from_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) + * Add `load_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 64bcdf0226298..3ee22ae84c329 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -160,6 +160,9 @@ def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: self.restore_optimizers_and_schedulers() + # restore logging values + self.restore_result_collections() + def restore_callbacks(self) -> None: """ Restores all callbacks from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -218,6 +221,26 @@ def restore_optimizers_and_schedulers(self) -> None: self.restore_optimizers() self.restore_lr_schedulers() + def restore_result_collections(self) -> None: + """ Restores the loop result collections used durint logging """ + if not self._loaded_checkpoint: + return + + state_dict = self._loaded_checkpoint.get('result_collections', None) + if state_dict: + # get current reduce function + sync_fn = self.trainer.training_type_plugin.reduce + + # get current result collections + train_results = self.trainer.train_loop.results + val_results = self.trainer.evaluation_loop._val_results + test_results = self.trainer.evaluation_loop._test_results + + # restore collection and provide sync_fn + train_results.load_state_dict(state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn) + val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn) + test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn) + def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. 
""" if not self._loaded_checkpoint: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index bb362b0c84bf0..c41bd6aa9973c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -549,7 +549,7 @@ def to_state_dict( for k, v in self.items() } - def load_from_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any], sync_fn: Optional[Callable] = None) -> None: def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection: result_metric_collection = ResultMetricCollection() @@ -565,6 +565,7 @@ def _to_device(item: ResultMetric) -> ResultMetric: def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric: result_metric = ResultMetric(item["meta"], item["is_tensor"]) result_metric.__dict__.update(item) + result_metric.meta.sync.fn = sync_fn return result_metric.to(self.device) state_dict = { diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d34ae13f80518..b35b0aac30252 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -18,6 +18,7 @@ import re import time from argparse import Namespace +from contextlib import suppress from datetime import timedelta from logging import INFO from pathlib import Path @@ -1339,3 +1340,56 @@ def test_trainer_checkpoint_callback_bool(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir) with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"): Trainer(checkpoint_callback=mc) + + +def test_result_collection_reload(tmpdir): + """ + This test validates that the checkpoint can be called when provided to callbacks list + """ + + class ExtendedBoringModel(BoringModel): + + global_step = 0 + + def on_train_start(self) -> None: + assert self.trainer.global_step == self.global_step + + def training_step(self, batch, batch_idx): + print(batch_idx) + if batch_idx == 2: + raise Exception + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + value = self.trainer.train_loop.results['training_step.tracking'].value + assert value == sum(range(batch_idx + 1)) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("val_loss", loss) + + model = ExtendedBoringModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=2, + ) + with suppress(Exception): + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'ckpt.pt') + trainer.save_checkpoint(checkpoint_path) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=2, + resume_from_checkpoint=checkpoint_path + ) + assert trainer.global_step == 0 + model.global_step = 3 + trainer.fit(model) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 39b70471330e5..4f61c1c9f01ab 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -235,7 +235,7 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() result = ResultCollection(True, torch.device("cpu")) - 
result.load_from_state_dict(state_dict) + result.load_state_dict(state_dict) assert _result.items() == result.items() assert _result["training_step.c_1"].meta == result["training_step.c_1"].meta @@ -280,7 +280,7 @@ def training_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint) -> None: state_dict = self.trainer.train_loop.results.state_dict() checkpoint["result_collections"] = state_dict - self.trainer.train_loop.results.load_from_state_dict(state_dict) + self.trainer.train_loop.results.load_state_dict(state_dict) assert self.trainer.train_loop.results['training_step.v'].meta.sync.fn is None return super().on_save_checkpoint(checkpoint) From c453994dabb33fc25baf3e8bfc6b1e85b1721e03 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 10:38:12 +0100 Subject: [PATCH 15/90] update on comments --- CHANGELOG.md | 2 +- .../trainer/connectors/logger_connector/result.py | 4 +++- tests/core/test_metric_result_integration.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0600668a2402..5d552bae14867 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Each of the training loops now keeps its own results collection ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) - * Add `load_from_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) + * Add `load_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index bb362b0c84bf0..baedb8c39dc29 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -549,7 +549,7 @@ def to_state_dict( for k, v in self.items() } - def load_from_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: Dict[str, Any], sync_fn: Optional[Callable] = None) -> None: def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection: result_metric_collection = ResultMetricCollection() @@ -560,11 +560,13 @@ def _to_device(item: ResultMetric) -> ResultMetric: result_metric_collection = apply_to_collection(result_metric_collection, ResultMetric, _to_device) result_metric_collection.meta = item.meta + result_metric_collection.meta.sync.fn = sync_fn return result_metric_collection def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric: result_metric = ResultMetric(item["meta"], item["is_tensor"]) result_metric.__dict__.update(item) + result_metric.meta.sync.fn = sync_fn return result_metric.to(self.device) state_dict = { diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 39b70471330e5..91a69e51384f0 100644 --- a/tests/core/test_metric_result_integration.py +++ 
b/tests/core/test_metric_result_integration.py @@ -235,7 +235,7 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() result = ResultCollection(True, torch.device("cpu")) - result.load_from_state_dict(state_dict) + result.load_state_dict(state_dict, sync_fn=_result['training_step.a'].meta.sync.fn) assert _result.items() == result.items() assert _result["training_step.c_1"].meta == result["training_step.c_1"].meta @@ -280,7 +280,7 @@ def training_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint) -> None: state_dict = self.trainer.train_loop.results.state_dict() checkpoint["result_collections"] = state_dict - self.trainer.train_loop.results.load_from_state_dict(state_dict) + self.trainer.train_loop.results.load_state_dict(state_dict) assert self.trainer.train_loop.results['training_step.v'].meta.sync.fn is None return super().on_save_checkpoint(checkpoint) From 9f46a99dc8e848e462de21588f69a2df70f1ed50 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 10:51:37 +0100 Subject: [PATCH 16/90] add test --- tests/checkpointing/test_model_checkpoint.py | 51 +++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b35b0aac30252..68259210babc2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1344,38 +1344,44 @@ def test_trainer_checkpoint_callback_bool(tmpdir): def test_result_collection_reload(tmpdir): """ - This test validates that the checkpoint can be called when provided to callbacks list + This is a temporary test to assert ResultCollection is properly reloaded. + The test is done over 2 epochs as Lightning doesn't support restarting middle of an epoch yet. + todo: (tchaton) Update this test when restart in middle of an epoch is supported. 
""" class ExtendedBoringModel(BoringModel): - global_step = 0 - - def on_train_start(self) -> None: - assert self.trainer.global_step == self.global_step + has_reloaded = False + breaking_batch_idx = 2 def training_step(self, batch, batch_idx): - print(batch_idx) - if batch_idx == 2: - raise Exception - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - value = self.trainer.train_loop.results['training_step.tracking'].value - assert value == sum(range(batch_idx + 1)) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log("val_loss", loss) + if self.has_reloaded: + if batch_idx >= self.breaking_batch_idx: + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + value = self.trainer.train_loop.results['training_step.tracking'].value + assert value == sum(range(self.breaking_batch_idx, batch_idx + 1)) + 1 + return super().training_step(batch, batch_idx) + else: + if self.trainer.current_epoch == 1: + return + if batch_idx == self.breaking_batch_idx: + raise Exception + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + value = self.trainer.train_loop.results['training_step.tracking'].value + assert value == sum(range(batch_idx + 1)) + return super().training_step(batch, batch_idx) + + def on_epoch_end(self) -> None: + if self.trainer.current_epoch: + assert self.trainer.train_loop.results['training_step.tracking'].value == sum(range(5)) model = ExtendedBoringModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, - limit_val_batches=2, + limit_val_batches=0, ) with suppress(Exception): trainer.fit(model) @@ -1385,11 +1391,10 @@ def validation_step(self, batch, batch_idx): trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, + max_epochs=2, limit_train_batches=5, - limit_val_batches=2, + limit_val_batches=0, resume_from_checkpoint=checkpoint_path ) - assert trainer.global_step == 0 - model.global_step = 3 + model.has_reloaded = True trainer.fit(model) From d80eb0079fec0212b9f4ebf1ec5c731c64e0cfe1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 10:58:16 +0100 Subject: [PATCH 17/90] remove tmp.p --- tmp.p | Bin 241 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tmp.p diff --git a/tmp.p b/tmp.p deleted file mode 100644 index 29086397fdb56e3b236de68ebcf742d2ebda228b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 241 zcmWIWW@cev;NW1u0Q?NX42ea_8JT6N`ems_#hLkeZch9RQK-O}E5Mtb Date: Mon, 14 Jun 2021 11:19:16 +0100 Subject: [PATCH 18/90] bypass typing bug --- .../trainer/connectors/logger_connector/result.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index baedb8c39dc29..559d384f6e58c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -28,7 +28,8 @@ from pytorch_lightning.utilities.metrics import metrics_to_scalars # re-define the ones from pytorch_lightning.utilities.types without the `Number` type -_METRIC = Union[Metric, torch.Tensor] +# todo (tchaton) Resolve this typing bug in python 3.6 +_METRIC = Any # Union[Metric, torch.Tensor] _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] From cc23140d730ac4e325640d79eb566e28c17f06d4 Mon Sep 
17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 11:31:31 +0100 Subject: [PATCH 19/90] add deepcopy to keep sync_fn target --- .../connectors/logger_connector/result.py | 16 +++++++++++++--- tests/trainer/logging_/test_logger_connector.py | 1 + 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 559d384f6e58c..b3addc344a717 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator +from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union @@ -547,10 +548,17 @@ def to_state_dict( return { k: apply_to_collection(v, (ResultMetric, ResultMetricCollection), to_state_dict) - for k, v in self.items() + for k, v in deepcopy(list(self.items())) } def load_state_dict(self, state_dict: Dict[str, Any], sync_fn: Optional[Callable] = None) -> None: + """ + This function is used to restore the ResultCollection state + + Args: + state_dict: Dict containing the serialized ResultCollection state. + sync_fn: Optional function used to reduce metric across processes. + """ def to_result_metric_collection(item: _ResultMetricCollectionSerializationHelper) -> ResultCollection: result_metric_collection = ResultMetricCollection() @@ -561,13 +569,15 @@ def _to_device(item: ResultMetric) -> ResultMetric: result_metric_collection = apply_to_collection(result_metric_collection, ResultMetric, _to_device) result_metric_collection.meta = item.meta - result_metric_collection.meta.sync.fn = sync_fn + if sync_fn: + result_metric_collection.meta.sync.fn = sync_fn return result_metric_collection def to_result_metric(item: _ResultMetricSerializationHelper) -> ResultMetric: result_metric = ResultMetric(item["meta"], item["is_tensor"]) result_metric.__dict__.update(item) - result_metric.meta.sync.fn = sync_fn + if sync_fn: + result_metric.meta.sync.fn = sync_fn return result_metric.to(self.device) state_dict = { diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d93054439082b..e2f2761bf752c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -387,6 +387,7 @@ def _assert_called(model, stage): max_epochs=1, progress_bar_refresh_rate=0, num_sanity_val_steps=2, + checkpoint_callback=False, ) trainer.fit(model) From d1529852df89555a25d46ddbaaa3a08fe9e96515 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 11:43:01 +0100 Subject: [PATCH 20/90] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d552bae14867..b910822007b84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Add `load_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) + * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From b9090124f7c62b9c0ac1a3d2deb90dc55ec7b046 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 11:45:55 +0100 Subject: [PATCH 21/90] remove test changes --- tests/checkpointing/test_model_checkpoint.py | 106 ++++++++----------- 1 file changed, 42 insertions(+), 64 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 68259210babc2..10893054b3358 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -56,61 +56,6 @@ def validation_epoch_end(self, outputs): self.log('val_acc', outs) -class CustomBoringModelScoreAndCkpt(BoringModel): - - def __init__( - self, - max_epochs: int, - limit_train_batches: int, - limit_val_batches: int, - reduce_lr_on_plateau: bool, - monitor: str, - lr: float = 1e-1, - gamma: int = 2, - ): - super().__init__() - self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) - self.val_logs = torch.randn(max_epochs, limit_val_batches) - self.scores = [] - self.reduce_lr_on_plateau = reduce_lr_on_plateau - self.monitor = monitor - self.lr = lr - self.gamma = gamma - - def training_step(self, batch, batch_idx): - log_value = self.train_log_epochs[self.current_epoch, batch_idx] - self.log('train_log', log_value, on_epoch=True) - return super().training_step(batch, batch_idx) - - def validation_step(self, batch, batch_idx): - log_value = self.val_logs[self.current_epoch, batch_idx] - self.log('val_log', log_value) - self.log('epoch', self.current_epoch, on_epoch=True) - return super().validation_step(batch, batch_idx) - - def configure_optimizers(self): - optimizer = optim.SGD(self.parameters(), lr=self.lr) - - if self.reduce_lr_on_plateau: - lr_scheduler = { - 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), - 'monitor': self.monitor, - 'strict': True, - } - else: - lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma) - - return [optimizer], [lr_scheduler] - - def on_train_epoch_end(self): - if 'train' in self.monitor: - self.scores.append(self.trainer.logged_metrics[self.monitor]) - - def on_validation_epoch_end(self): - if not self.trainer.sanity_checking and 'val' in self.monitor: - self.scores.append(self.trainer.logged_metrics[self.monitor]) - - @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize( "validation_step_none,val_dataloaders_none,monitor", @@ -133,19 +78,52 @@ def test_model_checkpoint_score_and_ckpt( limit_val_batches = 7 lr, gamma = 1e-1, 2 - model = CustomBoringModelScoreAndCkpt( - max_epochs=max_epochs, - limit_train_batches=limit_train_batches, - limit_val_batches=limit_val_batches, - reduce_lr_on_plateau=reduce_lr_on_plateau, - monitor=monitor, - lr=lr, - gamma=gamma, - ) + class CustomBoringModel(BoringModel): + + def __init__(self): 
+ super().__init__() + self.train_log_epochs = torch.randn(max_epochs, limit_train_batches) + self.val_logs = torch.randn(max_epochs, limit_val_batches) + self.scores = [] + + def training_step(self, batch, batch_idx): + log_value = self.train_log_epochs[self.current_epoch, batch_idx] + self.log('train_log', log_value, on_epoch=True) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.current_epoch, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=lr) + + if reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma) + + return [optimizer], [lr_scheduler] + + def on_train_epoch_end(self): + if 'train' in monitor: + self.scores.append(self.trainer.logged_metrics[monitor]) + + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and 'val' in monitor: + self.scores.append(self.trainer.logged_metrics[monitor]) filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) + model = CustomBoringModel() + if validation_step_none: model.validation_step = None if val_dataloaders_none: From 2c13e5ec9213aebcb9b05bdfd2f081147cc740d0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 11:50:14 +0100 Subject: [PATCH 22/90] typo --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3ee22ae84c329..5f3c9531fa02f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -222,7 +222,7 @@ def restore_optimizers_and_schedulers(self) -> None: self.restore_lr_schedulers() def restore_result_collections(self) -> None: - """ Restores the loop result collections used durint logging """ + """ Restores the loop result collections used for logging.""" if not self._loaded_checkpoint: return From 158069638bc93f1f212f30824392ddf76525fa94 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 17 Jun 2021 09:32:21 +0100 Subject: [PATCH 23/90] add result collection --- tests/models/test_hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7ab93e9ad2621..849cd4115272f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -711,7 +711,8 @@ def call(hook, fn, *args, **kwargs): 'lr_schedulers': ANY, 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, - 'state_dict': ANY + 'state_dict': ANY, + 'result_collections': ANY }, ) ), dict(name='teardown', kwargs=dict(stage='fit')), From 5331b8e6a097e9cb42b4e0d72e9adf5a62d7fa16 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 17 Jun 2021 09:54:54 +0100 Subject: [PATCH 24/90] Update pytorch_lightning/trainer/connectors/logger_connector/result.py Co-authored-by: Ethan Harris --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index dfc002ce9e427..0d7dab1f0a4a8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -622,7 +622,7 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: cls = ResultMetricCollection else: raise ValueError(f"Unexpected class name: {cls}") - _sync_fn = sync_fn if sync_fn else (self[k].meta.sync.fn if k in self else None) + _sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None) return cls._reconstruct(item, sync_fn=_sync_fn) items = {k: setstate(k, v) for k, v in state['items'].items()} From e966cfd677a8e180053bd60f4bdff41d9d3f4899 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 17 Jun 2021 10:08:50 +0100 Subject: [PATCH 25/90] update --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 626d733cd7537..5efd9e55b2d81 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -405,7 +405,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def get_result_collections_state_dict(self) -> Dict[str, Dict[str, Any]]: + def get_result_collections_state_dict(self): return { RunningStage.TRAINING.value: self.trainer.train_loop.results.state_dict(), RunningStage.VALIDATING.value: self.trainer.evaluation_loop._val_results.state_dict(), From 40ef8d82aec1f0eccb2902eea607f12d98b9ef4f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 09:26:51 -0400 Subject: [PATCH 26/90] add support for metric and reduction --- .../connectors/checkpoint_connector.py | 8 +- .../connectors/logger_connector/result.py | 16 ++- tests/checkpointing/test_model_checkpoint.py | 123 +++++++++++++----- 3 files changed, 110 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5efd9e55b2d81..4c1c691d00502 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -245,10 +245,12 @@ def restore_result_collections(self) -> None: val_results = self.trainer.evaluation_loop._val_results test_results = self.trainer.evaluation_loop._test_results + should_reset = not self.trainer.is_global_zero + # restore collection and provide sync_fn - train_results.load_state_dict(state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn) - val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn) - test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn) + train_results.load_state_dict(state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn, should_reset=should_reset) + val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn, should_reset=should_reset) + test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn, should_reset=should_reset) def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. 
""" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 0d7dab1f0a4a8..238ca445a4454 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -162,9 +162,9 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.meta = metadata self.has_reset = False if is_tensor: - self.add_state("value", torch.tensor(0, dtype=torch.float)) + self.add_state("value", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum) if self.meta.is_mean_reduction: - self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float)) + self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float), dist_reduce_fx=torch.sum) def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: if self.is_tensor: @@ -241,7 +241,9 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self) -> dict: - d = super().__getstate__() + sync_manager = self._apply_sync if self.is_tensor else self.value._apply_sync + with sync_manager(): + d = super().__getstate__() d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ return d @@ -465,7 +467,10 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten cache = result_metric._forward_cache elif not on_step and result_metric.meta.on_epoch: if not result_metric._computed: + should = result_metric.meta.sync.should + result_metric.meta.sync.should = True result_metric.compute() + result_metric.meta.sync.should = should cache = result_metric._computed if cache is not None and not result_metric.meta.enable_graph: return cache.detach() @@ -638,6 +643,9 @@ def load_state_dict( self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None, - sync_fn: Optional[Callable] = None + sync_fn: Optional[Callable] = None, + should_reset: bool = False, ) -> None: self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) + if should_reset: + self.reset() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 10893054b3358..8eaf0f5c67096 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -32,7 +32,7 @@ import yaml from omegaconf import Container, OmegaConf from torch import optim - +from torchmetrics import Metric import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer @@ -42,6 +42,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.runif import RunIf +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource + class LogInTwoMethods(BoringModel): @@ -1320,59 +1322,120 @@ def test_trainer_checkpoint_callback_bool(tmpdir): Trainer(checkpoint_callback=mc) -def test_result_collection_reload(tmpdir): - """ - This is a temporary test to assert ResultCollection is properly reloaded. - The test is done over 2 epochs as Lightning doesn't support restarting middle of an epoch yet. - todo: (tchaton) Update this test when restart in middle of an epoch is supported. 
- """ +def result_collection_reload(trainer_kwargs): + num_processes = trainer_kwargs.get("gpus", 1) + + class DummyMetric(Metric): + + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, increment): + self.sum += increment + self.count += 1 + + def compute(self): + return self.sum / self.count + + def __repr__(self): + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" + + + class CustomException(Exception): + pass class ExtendedBoringModel(BoringModel): - has_reloaded = False - breaking_batch_idx = 2 + def __init__(self): + super().__init__() + self.has_reloaded = False + self.breaking_batch_idx = 3 + self.has_validated_sum = False + self.dummy_metric = DummyMetric() def training_step(self, batch, batch_idx): + assert len(batch) == 1 if self.has_reloaded: if batch_idx >= self.breaking_batch_idx: self.log("tracking", batch_idx, on_step=True, on_epoch=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", batch_idx, on_step=True, on_epoch=True) + value = self.trainer.train_loop.results['training_step.tracking'].value - assert value == sum(range(self.breaking_batch_idx, batch_idx + 1)) + 1 - return super().training_step(batch, batch_idx) + shift = 1 + if num_processes == 2: + shift += 6 if self.trainer.is_global_zero else 0 + expected = sum(range(self.breaking_batch_idx, batch_idx)) + shift else: - if self.trainer.current_epoch == 1: + if self.trainer.current_epoch == 2: return if batch_idx == self.breaking_batch_idx: - raise Exception + raise CustomException + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", batch_idx, on_step=True, on_epoch=True) value = self.trainer.train_loop.results['training_step.tracking'].value assert value == sum(range(batch_idx + 1)) - return super().training_step(batch, batch_idx) + return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: if self.trainer.current_epoch: - assert self.trainer.train_loop.results['training_step.tracking'].value == sum(range(5)) + total = sum(range(5)) * num_processes + metrics = self.trainer.train_loop.results.metrics(on_step=False) + assert self.trainer.train_loop.results['training_step.tracking'].value == total + assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + self.has_validated_sum = True model = ExtendedBoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=5, - limit_val_batches=0, - ) - with suppress(Exception): + trainer = Trainer(**trainer_kwargs) + + with suppress(CustomException): trainer.fit(model) - checkpoint_path = os.path.join(tmpdir, 'ckpt.pt') + checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], 'ckpt.pt')) trainer.save_checkpoint(checkpoint_path) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=5, - limit_val_batches=0, - resume_from_checkpoint=checkpoint_path - ) + trainer.accelerator.barrier() + + checkpoint = torch.load(checkpoint_path) + items = checkpoint["result_collections"]["train"]["items"] + assert items["training_step.tracking_metric"]["value"] == 6 + assert items["training_step.tracking"]["value"] == 6 + + trainer_kwargs["resume_from_checkpoint"] = checkpoint_path + trainer_kwargs["max_epochs"] = 2 + + trainer = Trainer(**trainer_kwargs) model.has_reloaded = True 
trainer.fit(model) + assert model.has_validated_sum + +def test_result_collection_reload(tmpdir): + + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + } + result_collection_reload(trainer_kwargs) + + +@RunIf(min_gpus=2, special=True) +def test_result_collection_reload_2_gpus(tmpdir): + + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 2, + } + result_collection_reload(trainer_kwargs) \ No newline at end of file From b90916ed6c6ed2c7f012975bdb4724e2b49cbf3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 13:28:16 +0000 Subject: [PATCH 27/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/checkpoint_connector.py | 12 +++++++++--- tests/checkpointing/test_model_checkpoint.py | 10 +++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4c1c691d00502..0097c08d5d1b1 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -248,9 +248,15 @@ def restore_result_collections(self) -> None: should_reset = not self.trainer.is_global_zero # restore collection and provide sync_fn - train_results.load_state_dict(state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn, should_reset=should_reset) - val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn, should_reset=should_reset) - test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn, should_reset=should_reset) + train_results.load_state_dict( + state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn, should_reset=should_reset + ) + val_results.load_state_dict( + state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn, should_reset=should_reset + ) + test_results.load_state_dict( + state_dict[RunningStage.TESTING.value], sync_fn=sync_fn, should_reset=should_reset + ) def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. 
""" diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 8eaf0f5c67096..63b2f1504f260 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -33,17 +33,17 @@ from omegaconf import Container, OmegaConf from torch import optim from torchmetrics import Metric + import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.runif import RunIf -from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource - class LogInTwoMethods(BoringModel): @@ -1342,7 +1342,6 @@ def compute(self): def __repr__(self): return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" - class CustomException(Exception): pass @@ -1374,7 +1373,7 @@ def training_step(self, batch, batch_idx): return if batch_idx == self.breaking_batch_idx: raise CustomException - + self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.dummy_metric(batch_idx) @@ -1416,6 +1415,7 @@ def on_epoch_end(self) -> None: trainer.fit(model) assert model.has_validated_sum + def test_result_collection_reload(tmpdir): trainer_kwargs = { @@ -1438,4 +1438,4 @@ def test_result_collection_reload_2_gpus(tmpdir): "accelerator": "ddp", "gpus": 2, } - result_collection_reload(trainer_kwargs) \ No newline at end of file + result_collection_reload(trainer_kwargs) From 5f3e4b372bf269e2bfb7ebc9fd15f5c0a89b0743 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 10:58:34 -0400 Subject: [PATCH 28/90] wip --- .../connectors/checkpoint_connector.py | 7 ++- .../connectors/logger_connector/result.py | 7 ++- tests/checkpointing/test_model_checkpoint.py | 46 ++++++++++++------- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4c1c691d00502..9fbdfaa0ebc43 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional, Union import torch - +from torchmetrics import Metric import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import RunningStage @@ -252,6 +252,11 @@ def restore_result_collections(self) -> None: val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn, should_reset=should_reset) test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn, should_reset=should_reset) + if not self.trainer.is_global_zero: + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + module.reset() + def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. 
""" if not self._loaded_checkpoint: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 238ca445a4454..24c14f7795f30 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -18,7 +18,6 @@ import torch from torchmetrics import Metric - from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections @@ -605,7 +604,7 @@ def __getstate__(self) -> dict: if extra is not None: d['_extra'] = extra # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__() for k, v in self.items() if k != '_extra'} + items = {k: v.__getstate__() for k, v in self.items() if k not in ('_extra', 'fx_validator')} return {**d, 'items': items} def __setstate__( @@ -614,6 +613,7 @@ def __setstate__( map_location: Optional[Union[str, torch.device]] = None, sync_fn: Optional[Callable] = None ) -> None: + self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: @@ -646,6 +646,9 @@ def load_state_dict( sync_fn: Optional[Callable] = None, should_reset: bool = False, ) -> None: + + self.fx_validator = FxValidator() + self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) if should_reset: self.reset() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 8eaf0f5c67096..27c72fba4a637 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1322,26 +1322,26 @@ def test_trainer_checkpoint_callback_bool(tmpdir): Trainer(checkpoint_callback=mc) -def result_collection_reload(trainer_kwargs): - num_processes = trainer_kwargs.get("gpus", 1) +class DummyMetric(Metric): - class DummyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) - def __init__(self): - super().__init__() - self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) - self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + def update(self, increment): + self.sum += increment + self.count += 1 - def update(self, increment): - self.sum += increment - self.count += 1 + def compute(self): + return self.sum / self.count - def compute(self): - return self.sum / self.count + def __repr__(self): + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" - def __repr__(self): - return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" +def result_collection_reload(trainer_kwargs): + num_processes = trainer_kwargs.get("gpus", 1) class CustomException(Exception): pass @@ -1354,15 +1354,23 @@ def __init__(self): self.breaking_batch_idx = 3 self.has_validated_sum = False self.dummy_metric = DummyMetric() + self.dummy_metric_dynamic = DummyMetric() def training_step(self, batch, batch_idx): + print() + print(self.trainer.global_rank, self.dummy_metric) + print() assert len(batch) == 1 if self.has_reloaded: if batch_idx >= self.breaking_batch_idx: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.dummy_metric(batch_idx) - self.log("tracking_metric", 
batch_idx, on_step=True, on_epoch=True) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + if self.trainer.accelerator_connector.is_distributed: + self.dummy_metric_dynamic(batch_idx + int(self.trainer.global_rank)) + self.log("tracking_metric_2", self.dummy_metric_dynamic, on_step=True, on_epoch=True) value = self.trainer.train_loop.results['training_step.tracking'].value shift = 1 @@ -1378,13 +1386,19 @@ def training_step(self, batch, batch_idx): self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.dummy_metric(batch_idx) - self.log("tracking_metric", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + if self.trainer.accelerator_connector.is_distributed: + self.dummy_metric_dynamic(batch_idx + int(self.trainer.global_rank)) + self.log("tracking_metric_2", self.dummy_metric_dynamic, on_step=True, on_epoch=True) + value = self.trainer.train_loop.results['training_step.tracking'].value assert value == sum(range(batch_idx + 1)) return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: if self.trainer.current_epoch: + print(self.dummy_metric) total = sum(range(5)) * num_processes metrics = self.trainer.train_loop.results.metrics(on_step=False) assert self.trainer.train_loop.results['training_step.tracking'].value == total From f843d297ce5bf993b3d7ada6708bbcea0af4b9c3 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 14:42:09 -0400 Subject: [PATCH 29/90] update --- pytorch_lightning/core/lightning.py | 12 ++++ .../training_type/training_type_plugin.py | 1 - .../connectors/checkpoint_connector.py | 72 +++++++++++++------ .../connectors/logger_connector/result.py | 36 +++++++--- tests/checkpointing/test_model_checkpoint.py | 8 +-- 5 files changed, 96 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a1b2ce3a5e8f3..b81c688747e7b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,6 +111,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._map_id_to_metrics_name: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -360,6 +361,16 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) + attribute_name = None + + if isinstance(value, Metric): + + gen = self._named_members(lambda module: module._modules.items()) + for module_name, module in gen: + if isinstance(module, Metric): + if value.__getstate__() == module.__getstate__(): + attribute_name = module_name + results.log( self._current_fx_name, name, @@ -375,6 +386,7 @@ def log( sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, + attribute_name=attribute_name, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 4c825a93da290..2d9717a41c1d9 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -21,7 +21,6 @@ from torch.nn import Module from 
torch.optim import Optimizer from torch.utils.data import DataLoader - import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 9fbdfaa0ebc43..1b6fde87fd4a9 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -169,8 +169,8 @@ def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: self.restore_optimizers_and_schedulers() - # restore logging values - self.restore_result_collections() + # restore loops + self.restore_loops() def restore_callbacks(self) -> None: """ Restores all callbacks from the pre-loaded checkpoint. """ @@ -230,32 +230,55 @@ def restore_optimizers_and_schedulers(self) -> None: self.restore_optimizers() self.restore_lr_schedulers() + def restore_loops(self) -> None: + """ Restores the loops state_dicts""" + if not self._loaded_checkpoint: + return + + self.restore_result_collections() + def restore_result_collections(self) -> None: """ Restores the loop result collections used for logging.""" if not self._loaded_checkpoint: return - state_dict = self._loaded_checkpoint.get('result_collections', None) - if state_dict: - # get current reduce function - sync_fn = self.trainer.training_type_plugin.reduce + state_dict = self._loaded_checkpoint["loops_state_dict"].get('result_collections', None) + + if not state_dict: + return - # get current result collections - train_results = self.trainer.train_loop.results - val_results = self.trainer.evaluation_loop._val_results - test_results = self.trainer.evaluation_loop._test_results + # get current reduce function + sync_fn = self.trainer.training_type_plugin.reduce - should_reset = not self.trainer.is_global_zero + # get current result collections + train_results = self.trainer.train_loop.results + val_results = self.trainer.evaluation_loop._val_results + test_results = self.trainer.evaluation_loop._test_results - # restore collection and provide sync_fn - train_results.load_state_dict(state_dict[RunningStage.TRAINING.value], sync_fn=sync_fn, should_reset=should_reset) - val_results.load_state_dict(state_dict[RunningStage.VALIDATING.value], sync_fn=sync_fn, should_reset=should_reset) - test_results.load_state_dict(state_dict[RunningStage.TESTING.value], sync_fn=sync_fn, should_reset=should_reset) + metrics = {} + for module_name, module in self.trainer.lightning_module._named_members(lambda module: module._modules.items()): + if isinstance(module, Metric): + metrics[module_name] = module - if not self.trainer.is_global_zero: - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - module.reset() + # restore collection and provide sync_fn + self._restore_restore_collection(train_results, state_dict[RunningStage.TRAINING.value], sync_fn, metrics) + self._restore_restore_collection(val_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) + self._restore_restore_collection(train_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) + + # restore metrics + if not self.trainer.is_global_zero: + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + module.reset() + + def _restore_restore_collection(self, results, state_dict, sync_fn, metrics): + results.load_state_dict( + 
state_dict, + sync_fn=sync_fn, + metrics=metrics + ) + if not self.trainer.is_global_zero: + results.reset() def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. """ @@ -366,12 +389,16 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model = self.trainer.lightning_module + for _, module in model.named_modules(): + if isinstance(module, Metric): + module.persistent(True) + checkpoint = { 'epoch': current_epoch, 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), - 'result_collections': self.get_result_collections_state_dict() + 'loops_state_dict': self.get_loops_state_dict() } if not weights_only: @@ -412,6 +439,11 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def get_loops_state_dict(self) -> Dict[str, Dict[str, Any]]: + return { + "result_collections": self.get_result_collections_state_dict() + } + def get_result_collections_state_dict(self): return { RunningStage.TRAINING.value: self.trainer.train_loop.results.state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 24c14f7795f30..e5886380e1735 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,8 +14,8 @@ from collections.abc import Generator from dataclasses import asdict, dataclass, replace from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union - +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union, List +from copy import deepcopy import torch from torchmetrics import Metric from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator @@ -69,6 +69,7 @@ class _Metadata: _reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None + attribute_name: Optional[str] = None _sync: Optional[_Sync] = None @property @@ -240,9 +241,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self) -> dict: - sync_manager = self._apply_sync if self.is_tensor else self.value._apply_sync - with sync_manager(): - d = super().__getstate__() + with self._apply_sync(): + d = deepcopy(super().__getstate__()) + if not self.is_tensor: + del d["value"] d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ return d @@ -331,6 +333,18 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = self.device: Optional[Union[str, torch.device]] = device self.fx_validator = FxValidator() + @property + def result_metrics(self) -> List[ResultMetric]: + o = [] + for v in self.values(): + if isinstance(v, ResultMetric): + o.append(v) + elif isinstance(v, ResultCollection): + for _v in v.items(): + if isinstance(v, ResultMetric): + o.append(_v) + return o + @property def batch_size(self) -> torch.Tensor: # performance: cache the `batch_size` tensor instead of re-creating it @@ -389,6 +403,7 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, + attribute_name: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -415,6 +430,7 @@ def log( on_epoch=on_epoch, 
enable_graph=enable_graph, dataloader_idx=dataloader_idx, + attribute_name=attribute_name, ) meta.reduce_fx = reduce_fx meta.sync = _Sync( @@ -644,11 +660,15 @@ def load_state_dict( state_dict: dict, map_location: Optional[Union[str, torch.device]] = None, sync_fn: Optional[Callable] = None, - should_reset: bool = False, + metrics: Optional[Dict[str, Metric]] = None, ) -> None: self.fx_validator = FxValidator() self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) - if should_reset: - self.reset() + + if metrics: + for attribute_name, metric in metrics.items(): + for result_metric in self.result_metrics: + if result_metric.meta.attribute_name == attribute_name: + result_metric.value = metric diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 27c72fba4a637..293a7e1925246 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1417,10 +1417,10 @@ def on_epoch_end(self) -> None: trainer.accelerator.barrier() - checkpoint = torch.load(checkpoint_path) - items = checkpoint["result_collections"]["train"]["items"] - assert items["training_step.tracking_metric"]["value"] == 6 - assert items["training_step.tracking"]["value"] == 6 + if trainer.is_global_zero: + checkpoint = torch.load(checkpoint_path) + assert checkpoint["state_dict"]['dummy_metric.sum'] == 6 + assert checkpoint["state_dict"]['dummy_metric_dynamic.sum'] == 9 trainer_kwargs["resume_from_checkpoint"] = checkpoint_path trainer_kwargs["max_epochs"] = 2 From f6be5d7065dd55f1c8f60ed4e33cd441c2afa873 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jun 2021 18:45:00 +0000 Subject: [PATCH 30/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/training_type_plugin.py | 1 + .../trainer/connectors/checkpoint_connector.py | 15 +++++---------- .../trainer/connectors/logger_connector/result.py | 10 ++++++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 2d9717a41c1d9..4c825a93da290 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -21,6 +21,7 @@ from torch.nn import Module from torch.optim import Optimizer from torch.utils.data import DataLoader + import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1b6fde87fd4a9..b4e0b7ce5bd2e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -19,6 +19,7 @@ import torch from torchmetrics import Metric + import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.states import RunningStage @@ -245,7 +246,7 @@ def restore_result_collections(self) -> None: state_dict = self._loaded_checkpoint["loops_state_dict"].get('result_collections', None) if not state_dict: - return + return # get current reduce function sync_fn = self.trainer.training_type_plugin.reduce @@ -272,11 +273,7 @@ def 
restore_result_collections(self) -> None: module.reset() def _restore_restore_collection(self, results, state_dict, sync_fn, metrics): - results.load_state_dict( - state_dict, - sync_fn=sync_fn, - metrics=metrics - ) + results.load_state_dict(state_dict, sync_fn=sync_fn, metrics=metrics) if not self.trainer.is_global_zero: results.reset() @@ -391,7 +388,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: for _, module in model.named_modules(): if isinstance(module, Metric): - module.persistent(True) + module.persistent(True) checkpoint = { 'epoch': current_epoch, @@ -440,9 +437,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint def get_loops_state_dict(self) -> Dict[str, Dict[str, Any]]: - return { - "result_collections": self.get_result_collections_state_dict() - } + return {"result_collections": self.get_result_collections_state_dict()} def get_result_collections_state_dict(self): return { diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e5886380e1735..7a8a28f313145 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator +from copy import deepcopy from dataclasses import asdict, dataclass, replace from functools import partial, wraps -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union, List -from copy import deepcopy +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union + import torch from torchmetrics import Metric + from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections @@ -629,7 +631,7 @@ def __setstate__( map_location: Optional[Union[str, torch.device]] = None, sync_fn: Optional[Callable] = None ) -> None: - + self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: @@ -662,7 +664,7 @@ def load_state_dict( sync_fn: Optional[Callable] = None, metrics: Optional[Dict[str, Metric]] = None, ) -> None: - + self.fx_validator = FxValidator() self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) From a117e41f81af4e292fb109fcaf7f57afb96bce9c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Thu, 17 Jun 2021 14:56:28 -0400 Subject: [PATCH 31/90] update on comments --- CHANGELOG.md | 4 +--- .../trainer/connectors/logger_connector/result.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1918361029c1..1e5024030127a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault-tolerant training * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - + * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734)) @@ -147,8 +147,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Each of the training loops now keeps its own results collection ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891)) * Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) - * Add `load_state_dict` to ResultCollection ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e5886380e1735..3af93e3d329cf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -443,7 +443,7 @@ def log( if key not in self: self.register_key(key, meta, value) - # check the stored metadata and provided one matches + # check the stored metadata and the current one match elif meta != self[key].meta: raise MisconfigurationException( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. 
This is not allowed' @@ -633,7 +633,6 @@ def __setstate__( self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: - nonlocal sync_fn if not isinstance(item, dict): raise ValueError(f'Unexpected value: {item}') cls = item['_class'] From c455c69136195f616370354996fe2b4e371b1989 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 06:25:14 -0400 Subject: [PATCH 32/90] improve test --- .../connectors/checkpoint_connector.py | 20 ++++++------ .../connectors/logger_connector/result.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 32 +++++++++---------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b4e0b7ce5bd2e..f94cc8fdabc63 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -46,6 +46,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = # used to validate checkpointing logic self.has_trained = False self._loaded_checkpoint = dict() + self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -257,20 +258,15 @@ def restore_result_collections(self) -> None: test_results = self.trainer.evaluation_loop._test_results metrics = {} - for module_name, module in self.trainer.lightning_module._named_members(lambda module: module._modules.items()): + model_ref = self.trainer.lightning_module + for module_name, module in model_ref._named_members(lambda module: module._modules.items()): if isinstance(module, Metric): metrics[module_name] = module # restore collection and provide sync_fn self._restore_restore_collection(train_results, state_dict[RunningStage.TRAINING.value], sync_fn, metrics) self._restore_restore_collection(val_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) - self._restore_restore_collection(train_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) - - # restore metrics - if not self.trainer.is_global_zero: - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - module.reset() + self._restore_restore_collection(test_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) def _restore_restore_collection(self, results, state_dict, sync_fn, metrics): results.load_state_dict(state_dict, sync_fn=sync_fn, metrics=metrics) @@ -386,9 +382,11 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model = self.trainer.lightning_module - for _, module in model.named_modules(): - if isinstance(module, Metric): - module.persistent(True) + if not self._persistent_metrics: + for _, module in model.named_modules(): + if isinstance(module, Metric): + module.persistent(True) + self._persistent_metrics = True checkpoint = { 'epoch': current_epoch, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7f5d8df1fd094..120689399f8be 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -243,7 +243,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self) -> dict: - with self._apply_sync(): + with self.sync_context(): d = deepcopy(super().__getstate__()) if not 
self.is_tensor: del d["value"] diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index afe4e36b93196..e8d2e36f91e30 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1357,26 +1357,24 @@ def __init__(self): self.dummy_metric_dynamic = DummyMetric() def training_step(self, batch, batch_idx): - print() - print(self.trainer.global_rank, self.dummy_metric) - print() assert len(batch) == 1 if self.has_reloaded: if batch_idx >= self.breaking_batch_idx: self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - if self.trainer.accelerator_connector.is_distributed: - self.dummy_metric_dynamic(batch_idx + int(self.trainer.global_rank)) - self.log("tracking_metric_2", self.dummy_metric_dynamic, on_step=True, on_epoch=True) - value = self.trainer.train_loop.results['training_step.tracking'].value - shift = 1 + shift = 0 if num_processes == 2: - shift += 6 if self.trainer.is_global_zero else 0 - expected = sum(range(self.breaking_batch_idx, batch_idx)) + shift + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value + + value = self.trainer.train_loop.results['training_step.tracking_2'] + assert expected == value else: if self.trainer.current_epoch == 2: return @@ -1384,16 +1382,17 @@ def training_step(self, batch, batch_idx): raise CustomException self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - if self.trainer.accelerator_connector.is_distributed: - self.dummy_metric_dynamic(batch_idx + int(self.trainer.global_rank)) - self.log("tracking_metric_2", self.dummy_metric_dynamic, on_step=True, on_epoch=True) - value = self.trainer.train_loop.results['training_step.tracking'].value assert value == sum(range(batch_idx + 1)) + + value = self.trainer.train_loop.results['training_step.tracking_2'] + assert value == sum(range(batch_idx + 1)) + return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: @@ -1403,6 +1402,8 @@ def on_epoch_end(self) -> None: metrics = self.trainer.train_loop.results.metrics(on_step=False) assert self.trainer.train_loop.results['training_step.tracking'].value == total assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + assert self.trainer.train_loop.results['training_step.tracking_2'].value == total + assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 self.has_validated_sum = True model = ExtendedBoringModel() @@ -1419,8 +1420,7 @@ def on_epoch_end(self) -> None: if trainer.is_global_zero: checkpoint = torch.load(checkpoint_path) - assert checkpoint["state_dict"]['dummy_metric.sum'] == 6 - assert checkpoint["state_dict"]['dummy_metric_dynamic.sum'] == 9 + assert checkpoint["state_dict"]['dummy_metric.sum'] == 3 * num_processes trainer_kwargs["resume_from_checkpoint"] = checkpoint_path trainer_kwargs["max_epochs"] = 2 From 6da2da3dcfd80d2ea662015b53f9203f940745a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 10:26:31 +0000 Subject: [PATCH 33/90] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index f94cc8fdabc63..378604f384cc4 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -46,7 +46,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = # used to validate checkpointing logic self.has_trained = False self._loaded_checkpoint = dict() - self._persistent_metrics = False + self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: From 37e531037b43c97f5d2aa343cbf6c634722c2731 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 06:47:30 -0400 Subject: [PATCH 34/90] resolve test --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- .../trainer/connectors/logger_connector/result.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index f94cc8fdabc63..de58f7cf81d33 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -434,7 +434,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def get_loops_state_dict(self) -> Dict[str, Dict[str, Any]]: + def get_loops_state_dict(self): return {"result_collections": self.get_result_collections_state_dict()} def get_result_collections_state_dict(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 120689399f8be..ab0df85b3983f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -614,6 +614,7 @@ def __str__(self) -> str: def __getstate__(self) -> dict: d = self.__dict__.copy() + d["fx_validator"] = None # can't deepcopy tensors with grad_fn minimize = d['_minimize'] if minimize is not None: @@ -650,6 +651,8 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: items = {k: setstate(k, v) for k, v in state['items'].items()} self.update(items) + self.fx_validator = FxValidator() + device = map_location or self.device self.to(device) From 49b2647742071993493e35e3f9709047dfcf8cea Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 06:56:13 -0400 Subject: [PATCH 35/90] test with torchmetrics --- requirements.txt | 2 +- setup.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index b564e13551a54..1e629edc600a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' 
-torchmetrics>=0.3.2 pyDeprecate==0.3.1 packaging typing-extensions # TypedDict support for python<3.8 + diff --git a/setup.py b/setup.py index beebd807c7107..4b2d38b71836d 100755 --- a/setup.py +++ b/setup.py @@ -62,6 +62,9 @@ def _load_py_module(fname, pkg="pytorch_lightning"): version=about.__version__, ) +install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ + ["torchmetrics @ git+https://github.com/PyTorchLightning/torchmetrics.git@apply_sync_fn"] + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -84,7 +87,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=setup_tools._load_requirements(_PATH_ROOT), + install_requires=install_requirements, extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", From e22578d84f2d0eaf2864ce177cd0d2486d52b558 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 10:57:38 +0000 Subject: [PATCH 36/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1e629edc600a8..7f2f615f3d033 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,3 @@ tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file in pyDeprecate==0.3.1 packaging typing-extensions # TypedDict support for python<3.8 - From c31282883c7cc06df4e5b68ab0ee7612cf93ffad Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:08:48 -0400 Subject: [PATCH 37/90] update on comments --- pytorch_lightning/core/lightning.py | 20 +++++++++++++------ .../connectors/logger_connector/result.py | 8 +++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b81c688747e7b..6f1b51e94bb71 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -215,6 +215,11 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None + def state_dict(self, destination=None, prefix='', keep_vars=False): + # drop the map id to metrics to avoid saving it. + self._map_id_to_metrics_name = None + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: @@ -364,12 +369,15 @@ def log( attribute_name = None if isinstance(value, Metric): - - gen = self._named_members(lambda module: module._modules.items()) - for module_name, module in gen: - if isinstance(module, Metric): - if value.__getstate__() == module.__getstate__(): - attribute_name = module_name + # this is used to effiently find the attribute prefix path of metric objects + # this will enable Lightning to re-attach metric reference when reloading states. 
+ if self._map_id_to_metrics_name is None: + self._map_id_to_metrics_name = { + id(module): module_name + for module_name, module in self._named_members(lambda module: module._modules.items()) + if isinstance(module, Metric) + } + attribute_name = self._map_id_to_metrics_name[id(value)] results.log( self._current_fx_name, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ab0df85b3983f..cca4544994066 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -245,6 +245,8 @@ def __repr__(self) -> str: def __getstate__(self) -> dict: with self.sync_context(): d = deepcopy(super().__getstate__()) + # metric are being dropped, so they won't be serialized + # this would prevent pickling error if their API change. if not self.is_tensor: del d["value"] d['meta'] = d['meta'].__getstate__() @@ -342,9 +344,9 @@ def result_metrics(self) -> List[ResultMetric]: if isinstance(v, ResultMetric): o.append(v) elif isinstance(v, ResultCollection): - for _v in v.items(): - if isinstance(v, ResultMetric): - o.append(_v) + def append_fn(v: ResultMetric) -> None: + o.append(v) + apply_to_collection(v, ResultMetric, append_fn) return o @property From 7c2d63c66cf7a61d457a7ce20fbf4fcb918c0272 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:09:57 -0400 Subject: [PATCH 38/90] update torchmetrics path --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b2d38b71836d..9f1220f5ef6bb 100755 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): ) install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ - ["torchmetrics @ git+https://github.com/PyTorchLightning/torchmetrics.git@apply_sync_fn"] + ["torchmetrics @ git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious From 15746fbeb6335196d2d91657c4bd01ba87b863d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 11:11:25 +0000 Subject: [PATCH 39/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/lightning.py | 4 ++-- .../trainer/connectors/logger_connector/result.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6f1b51e94bb71..1771b87556194 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -369,8 +369,8 @@ def log( attribute_name = None if isinstance(value, Metric): - # this is used to effiently find the attribute prefix path of metric objects - # this will enable Lightning to re-attach metric reference when reloading states. + # this is used to effiently find the attribute prefix path of metric objects + # this will enable Lightning to re-attach metric reference when reloading states. 
if self._map_id_to_metrics_name is None: self._map_id_to_metrics_name = { id(module): module_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index cca4544994066..639eb0d10804b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -344,8 +344,10 @@ def result_metrics(self) -> List[ResultMetric]: if isinstance(v, ResultMetric): o.append(v) elif isinstance(v, ResultCollection): + def append_fn(v: ResultMetric) -> None: o.append(v) + apply_to_collection(v, ResultMetric, append_fn) return o From 4c866369e828909b92836e37badc11e5efa7f042 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:19:45 -0400 Subject: [PATCH 40/90] update --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9f1220f5ef6bb..c7ff3e5c17c09 100755 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): ) install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ - ["torchmetrics @ git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] + ["metrics @ git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious From e560406850c87f3e1a0ccbb19b70191119455010 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 07:26:07 -0400 Subject: [PATCH 41/90] update --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c7ff3e5c17c09..5f51dfccf8f75 100755 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): ) install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ - ["metrics @ git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] + ["metrics@git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious From d0b012f0b900567f557ea83d3c9f5e5529c73206 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 08:06:36 -0400 Subject: [PATCH 42/90] update setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5f51dfccf8f75..3be9c59ee391c 100755 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): ) install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ - ["metrics@git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] + ["git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... 
it's not obvious From 09f05f6a2641fd03dd3d1a9ff012156dcb86d8b6 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 09:17:12 -0400 Subject: [PATCH 43/90] add directly in CI --- .github/workflows/ci_test-full.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 1064e603bee1f..26e5e90f13262 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -122,6 +122,7 @@ jobs: pip --version # python -m pip install --upgrade --user pip flag=$(python -c "print('--pre' if '${{matrix.release}}' == 'pre' else '')" 2>&1) + pip install git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade $flag # adjust versions according installed Torch version python ./requirements/adjust_versions.py requirements/extra.txt From d0a6cf9c6f48b49adc55897ed9a71b215f0da3c4 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 09:19:51 -0400 Subject: [PATCH 44/90] update --- setup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 3be9c59ee391c..beebd807c7107 100755 --- a/setup.py +++ b/setup.py @@ -62,9 +62,6 @@ def _load_py_module(fname, pkg="pytorch_lightning"): version=about.__version__, ) -install_requirements = setup_tools._load_requirements(_PATH_ROOT) + \ - ["git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn"] - # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -87,7 +84,7 @@ def _load_py_module(fname, pkg="pytorch_lightning"): keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=install_requirements, + install_requires=setup_tools._load_requirements(_PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", From a5ae2c4052011a6849b5db4a0466750b1508e026 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 15:51:07 +0200 Subject: [PATCH 45/90] Whitespace --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e5024030127a..36e7413748073 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -148,6 +148,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Remove `EpochResultStore` and `HookResultStore` in favor of `ResultCollection` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) * Remove `MetricsHolder` ([#7909](https://github.com/PyTorchLightning/pytorch-lightning/pull/7909)) + - Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/)) From f7eafd7716eae7372e6fddbeb0d405b07b096c94 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 18 Jun 2021 12:15:35 -0400 Subject: [PATCH 46/90] resolve bug --- pytorch_lightning/core/lightning.py | 7 ++++--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 ++++- pytorch_lightning/plugins/training_type/tpu_spawn.py | 5 ++++- .../trainer/connectors/logger_connector/result.py | 6 +++--- tests/core/test_metric_result_integration.py | 10 +++++----- tests/metrics/test_remove_1-5_metrics.py | 2 +- tests/models/test_hooks.py | 5 +++-- tests/trainer/logging_/test_logger_connector.py | 6 +++--- 8 files changed, 27 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1771b87556194..9cfde84d701bf 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -277,6 +277,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, + attribute_name: Optional[str] = None, ) -> None: """ Log a key, value @@ -314,6 +315,8 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. + attribute_name: This would be the attribute name to find a given metric. + If None, this will be automatically inferred by Lightning. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -366,9 +369,7 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) - attribute_name = None - - if isinstance(value, Metric): + if attribute_name is None and isinstance(value, Metric): # this is used to effiently find the attribute prefix path of metric objects # this will enable Lightning to re-attach metric reference when reloading states. 
if self._map_id_to_metrics_name is None: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8d2cc217835fb..fd9e8ca07e742 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -274,6 +274,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + # requires to compute the state_dict on all processes in case Metric are presents + state_dict = self.lightning_module.state_dict() + if self.global_rank == 0 and self.mp_queue is not None: rank_zero_warn("cleaning up ddp environment...") @@ -284,7 +287,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) + atomic_save(self.on_save(state_dict), last_path) # todo, pass complete checkpoint as state dictionary self.mp_queue.put(best_model_path) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9921fadd2cfc1..f6477b75641c3 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -185,6 +185,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + # requires to compute the state_dict on all processes in case Metric are presents + state_dict = self.lightning_module.state_dict() + if self.mp_queue is not None: rank_zero_warn("cleaning up tpu spawn environment...") @@ -195,7 +198,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): and len(best_model_path) > 0 ): last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(self.lightning_module.state_dict(), last_path) + self.save(state_dict, last_path) if self.local_rank == 0: # todo, pass complete checkpoint as state dictionary diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 639eb0d10804b..af02ce70bb76c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -242,12 +242,12 @@ def __repr__(self) -> str: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" - def __getstate__(self) -> dict: + def __getstate__(self, drop_value: bool = False) -> dict: with self.sync_context(): d = deepcopy(super().__getstate__()) # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change. 
- if not self.is_tensor: + if drop_value and self.is_tensor: del d["value"] d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ @@ -282,7 +282,7 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: def __getstate__(self) -> dict: def getstate(item: ResultMetric) -> dict: - return item.__getstate__() + return item.__getstate__(drop_value=True) items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate) return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__} diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 6b7163c4aa643..b29e76b778740 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -184,7 +184,7 @@ def lightning_log(fx, *args, **kwargs): assert result[k].cumulated_batch_size == torch.tensor(1.), k -def my_sync_dist(x): +def my_sync_dist(x, *_, **__): return x @@ -222,9 +222,9 @@ def lightning_log(fx, *args, **kwargs): cumulative_sum += i metric = metric_a if i < 1 else metric_d - lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True) - lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True) - lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False) + lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, attribute_name="metric") + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, attribute_name="metric_b") + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, attribute_name="metric_c") lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) @@ -238,7 +238,7 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() # check the sync fn was dropped assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync'] - new_result.load_state_dict(state_dict) + new_result.load_state_dict(state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}) # should match assert result_copy == new_result # the sync fn has been kept diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index d3703bf3691c9..aa7d4977d1133 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -215,7 +215,7 @@ def test_v1_5_metric_classif_mix(): preds = torch.tensor([0, 1, 0, 0]) confusion_matrix._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) + assert torch.equal(confusion_matrix(preds, target, num_classes=2).float(), torch.tensor([[2., 0.], [1., 1.]])) target = torch.tensor([0, 1, 2, 0, 1, 2]) preds = torch.tensor([0, 2, 1, 0, 0, 1]) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 849cd4115272f..998ea4969bfad 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -15,7 +15,6 @@ from inspect import getmembers, isfunction from unittest import mock from unittest.mock import ANY, PropertyMock - import pytest import torch from torch.utils.data import DataLoader @@ -712,7 +711,9 @@ def call(hook, fn, *args, **kwargs): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 
'state_dict': ANY, - 'result_collections': ANY + 'loops_state_dict': { + "result_collections": ANY + } }, ) ), dict(name='teardown', kwargs=dict(stage='fit')), diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 27b0a252054a2..d61aa3f0da0c7 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -309,7 +309,7 @@ def _step(self, stage, batch): logits = self.forward(batch) loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss) + self.log(f"loss/{stage}", loss, attribute_name="dummy") acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc) - self.log(f"{stage}/ap", ap) + self.log(f"{stage}/accuracy", acc, attribute_name="dummy") + self.log(f"{stage}/ap", ap, attribute_name="dummy") return loss From 5f3b33c74dfb3e0de6456642360e69aa471d4150 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Jun 2021 16:17:01 +0000 Subject: [PATCH 47/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_metric_result_integration.py | 8 +++++++- tests/models/test_hooks.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index b29e76b778740..43ee522b260f7 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -238,7 +238,13 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() # check the sync fn was dropped assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync'] - new_result.load_state_dict(state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}) + new_result.load_state_dict( + state_dict, metrics={ + "metric": metric, + "metric_b": metric_b, + "metric_c": metric_c + } + ) # should match assert result_copy == new_result # the sync fn has been kept diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 998ea4969bfad..f6e02fe416953 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -15,6 +15,7 @@ from inspect import getmembers, isfunction from unittest import mock from unittest.mock import ANY, PropertyMock + import pytest import torch from torch.utils.data import DataLoader From ecbd5e6e6bc8c89c66cdbe57aa9d7dcb5cf36178 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 21 Jun 2021 05:47:01 -0400 Subject: [PATCH 48/90] update --- .azure-pipelines/gpu-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 5333bfd867da0..55bf0401788a1 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -55,6 +55,7 @@ jobs: displayName: 'Image info & NVIDIA' - bash: | + pip install git+https://github.com/PyTorchLightning/metrics.git@apply_sync_fn python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" pip install fairscale>=0.3.4 pip install deepspeed>=0.4.0 -U From f3a3996189421ff54bb4d0dc56e5729ac23857a1 Mon Sep 17 00:00:00 2001 From: 
Thomas Chaton Date: Mon, 21 Jun 2021 06:11:34 -0400 Subject: [PATCH 49/90] update on comments --- pytorch_lightning/core/lightning.py | 21 +++++++++++-------- .../connectors/checkpoint_connector.py | 11 +++++++--- .../connectors/logger_connector/result.py | 21 +++++++------------ tests/core/test_metric_result_integration.py | 6 +++--- .../trainer/logging_/test_logger_connector.py | 6 +++--- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9cfde84d701bf..7343f1511f104 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -215,10 +215,10 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, *args, **kwargs) -> Dict[str, Any]: # drop the map id to metrics to avoid saving it. self._map_id_to_metrics_name = None - return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + return super().state_dict(*args, **kwargs) def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None @@ -277,7 +277,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - attribute_name: Optional[str] = None, + metric_prefix_name: Optional[str] = None, ) -> None: """ Log a key, value @@ -315,8 +315,11 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. - attribute_name: This would be the attribute name to find a given metric. - If None, this will be automatically inferred by Lightning. + metric_prefix_name: To enable ``Fault Tolerant Logging``, Lightning requires a way to restore TorchMetric Metric + instance references on-reload. When the logged Metric are LightningModule attributes, + metric_prefix_name should be None. However, when this is not, metric_prefix_name should be provided as + Lightning won't be able to find your nn.Metric reference. + """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -369,8 +372,8 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) - if attribute_name is None and isinstance(value, Metric): - # this is used to effiently find the attribute prefix path of metric objects + if metric_prefix_name is None and isinstance(value, Metric): + # this is used to efficiently find the attribute prefix path of metric objects # this will enable Lightning to re-attach metric reference when reloading states. 
if self._map_id_to_metrics_name is None: self._map_id_to_metrics_name = { @@ -378,7 +381,7 @@ def log( for module_name, module in self._named_members(lambda module: module._modules.items()) if isinstance(module, Metric) } - attribute_name = self._map_id_to_metrics_name[id(value)] + metric_prefix_name = self._map_id_to_metrics_name[id(value)] results.log( self._current_fx_name, @@ -395,7 +398,7 @@ def log( sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, - attribute_name=attribute_name, + metric_prefix_name=metric_prefix_name, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1951b1e736123..5f92c4d8d22d0 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -244,7 +244,12 @@ def restore_result_collections(self) -> None: if not self._loaded_checkpoint: return - state_dict = self._loaded_checkpoint["loops_state_dict"].get('result_collections', None) + loops = self._loaded_checkpoint.get("loops", None) + + if not loops: + return + + state_dict = loops.get('result_collections', None) if not state_dict: return @@ -393,7 +398,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'global_step': global_step, 'pytorch-lightning_version': pytorch_lightning.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), - 'loops_state_dict': self.get_loops_state_dict() + 'loops': self.get_loops_state_dict() } if not weights_only: @@ -439,7 +444,7 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { - RunningStage.TRAINING.value: self.trainer.train_loop.results.state_dict(), + RunningStage.TRAINING.value: self.trainer.fit_loop.results.state_dict(), RunningStage.VALIDATING.value: self.trainer.evaluation_loop._val_results.state_dict(), RunningStage.TESTING.value: self.trainer.evaluation_loop._test_results.state_dict(), } diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index af02ce70bb76c..f6f0a238ce3db 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -71,7 +71,7 @@ class _Metadata: _reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None - attribute_name: Optional[str] = None + metric_prefix_name: Optional[str] = None _sync: Optional[_Sync] = None @property @@ -340,15 +340,10 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = @property def result_metrics(self) -> List[ResultMetric]: o = [] - for v in self.values(): - if isinstance(v, ResultMetric): - o.append(v) - elif isinstance(v, ResultCollection): + def append_fn(v: ResultMetric) -> None: + o.append(v) - def append_fn(v: ResultMetric) -> None: - o.append(v) - - apply_to_collection(v, ResultMetric, append_fn) + apply_to_collection(self.values(), ResultMetric, append_fn) return o @property @@ -409,7 +404,7 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - attribute_name: Optional[str] = None, + metric_prefix_name: Optional[str] = None, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" 
# no metrics should be logged with graphs @@ -436,7 +431,7 @@ def log( on_epoch=on_epoch, enable_graph=enable_graph, dataloader_idx=dataloader_idx, - attribute_name=attribute_name, + metric_prefix_name=metric_prefix_name, ) meta.reduce_fx = reduce_fx meta.sync = _Sync( @@ -676,7 +671,7 @@ def load_state_dict( self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) if metrics: - for attribute_name, metric in metrics.items(): + for metric_prefix_name, metric in metrics.items(): for result_metric in self.result_metrics: - if result_metric.meta.attribute_name == attribute_name: + if result_metric.meta.metric_prefix_name == metric_prefix_name: result_metric.value = metric diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 43ee522b260f7..8785c70c28f8f 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -222,9 +222,9 @@ def lightning_log(fx, *args, **kwargs): cumulative_sum += i metric = metric_a if i < 1 else metric_d - lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, attribute_name="metric") - lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, attribute_name="metric_b") - lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, attribute_name="metric_c") + lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, metric_prefix_name="metric") + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_prefix_name="metric_b") + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_prefix_name="metric_c") lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index d61aa3f0da0c7..0f54b7c16fbe2 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -309,7 +309,7 @@ def _step(self, stage, batch): logits = self.forward(batch) loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss, attribute_name="dummy") + self.log(f"loss/{stage}", loss, metric_prefix_name="dummy") acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, attribute_name="dummy") - self.log(f"{stage}/ap", ap, attribute_name="dummy") + self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy") + self.log(f"{stage}/ap", ap, metric_prefix_name="dummy") return loss From 2289712e2405c35a54193642817be66d27819c49 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Jun 2021 10:13:38 +0000 Subject: [PATCH 50/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f6f0a238ce3db..fa346f68be2af 100644 --- 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -340,6 +340,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = @property def result_metrics(self) -> List[ResultMetric]: o = [] + def append_fn(v: ResultMetric) -> None: o.append(v) From 0415498d5e4dd01388e22d22dbe9a053e35bde6d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Jun 2021 18:38:49 +0100 Subject: [PATCH 51/90] update torchmetrics --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b564e13551a54..4d33b6972e3d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.3.2 +torchmetrics>=0.4.0rc0 pyDeprecate==0.3.1 packaging typing-extensions # TypedDict support for python<3.8 From 6865a36dab08fd2dc64fea9f6bb3d35d42a10e51 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Jun 2021 18:58:49 +0100 Subject: [PATCH 52/90] resolve tests --- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/models/test_hooks.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4d4864c491f85..5fb5d19ef1b88 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1338,7 +1338,7 @@ def update(self, increment): self.count += 1 def compute(self): - return self.sum / self.count + return self.sum // self.count def __repr__(self): return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5e73e36c894fc..bc4ccf0f72c04 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -400,6 +400,7 @@ def test_trainer_model_hook_system_fit(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, + 'loops': ANY, } expected = [ dict(name='Callback.on_init_start', args=(trainer, )), @@ -757,7 +758,7 @@ def call(hook, fn, *args, **kwargs): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops_state_dict': { + 'loops': { "result_collections": ANY } }, ) From fd6cf34c8bf312c1196adf7565e03317bb541c59 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 08:54:07 +0100 Subject: [PATCH 53/90] get duration --- tests/special_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 9fca3b62bad40..f1f5066be1ac7 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,7 +17,7 @@ set -e # this environment variable allows special tests to run export PL_RUNNING_SPECIAL_TESTS=1 # python arguments -defaults='-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no' +defaults='-m coverage run --source pytorch_lightning --append -m pytest --durations=0 --capture=no' # find tests marked as `@RunIf(special=True)` grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True') From fc66c10eb7f3d38d3bce99a9f4c01c2a7e57c53f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 09:00:59 +0100 Subject: [PATCH 54/90] resolve issues --- 
pytorch_lightning/trainer/connectors/checkpoint_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 9fff168563bbc..a7787c04eac9b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -451,8 +451,9 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { RunningStage.TRAINING.value: self.trainer.fit_loop.results.state_dict(), - RunningStage.VALIDATING.value: self.trainer.evaluation_loop._val_results.state_dict(), - RunningStage.TESTING.value: self.trainer.evaluation_loop._test_results.state_dict(), + RunningStage.SANITY_CHECKING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), + RunningStage.VALIDATING.value: self.trainer.validation_loop.results.state_dict(), + RunningStage.TESTING.value: self.trainer.evaluation_loop.results.state_dict(), } def hpc_load(self, checkpoint_path: str) -> None: From 979bc23009eb8fd6289924e422d5f215f7dab085 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 09:48:36 +0100 Subject: [PATCH 55/90] resolve bug --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a7787c04eac9b..6302b6bbc0df9 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -450,7 +450,7 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { - RunningStage.TRAINING.value: self.trainer.fit_loop.results.state_dict(), + RunningStage.TRAINING.value: self.trainer.fit_loop.training_loop.results.state_dict(), RunningStage.SANITY_CHECKING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), RunningStage.VALIDATING.value: self.trainer.validation_loop.results.state_dict(), RunningStage.TESTING.value: self.trainer.evaluation_loop.results.state_dict(), From effff31a5dd854f985f5442bc5d2a24aa8f4f064 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 10:18:47 +0100 Subject: [PATCH 56/90] update --- .../connectors/checkpoint_connector.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6302b6bbc0df9..5d8cc916a4257 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -266,9 +266,10 @@ def restore_result_collections(self) -> None: sync_fn = self.trainer.training_type_plugin.reduce # get current result collections - train_results = self.trainer.train_loop.results - val_results = self.trainer.evaluation_loop._val_results - test_results = self.trainer.evaluation_loop._test_results + train_results = self.trainer.fit_loop.training_loop.results + validation_results = self.trainer.fit_loop.validation_loop.results + validate_results = self.trainer.validation_loop.results + test_results = self.trainer.test_loop.results metrics = {} model_ref = self.trainer.lightning_module @@ -277,8 +278,13 @@ def restore_result_collections(self) -> None: metrics[module_name] = module # restore collection and provide sync_fn - 
self._restore_restore_collection(train_results, state_dict[RunningStage.TRAINING.value], sync_fn, metrics) - self._restore_restore_collection(val_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) + self._restore_restore_collection( + train_results, state_dict[RunningStage.TRAINING.value][RunningStage.TRAINING.value], sync_fn, metrics + ) + self._restore_restore_collection( + validation_results, state_dict[RunningStage.TRAINING.value][RunningStage.VALIDATING.value], sync_fn, metrics + ) + self._restore_restore_collection(validate_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) self._restore_restore_collection(test_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) def _restore_restore_collection(self, results, state_dict, sync_fn, metrics): @@ -450,8 +456,10 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { - RunningStage.TRAINING.value: self.trainer.fit_loop.training_loop.results.state_dict(), - RunningStage.SANITY_CHECKING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), + RunningStage.TRAINING.value: { + RunningStage.TRAINING.value: self.trainer.fit_loop.training_loop.results.state_dict(), + RunningStage.VALIDATING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), + }, RunningStage.VALIDATING.value: self.trainer.validation_loop.results.state_dict(), RunningStage.TESTING.value: self.trainer.evaluation_loop.results.state_dict(), } From b38efe1390d9e8974e27c2eaf82d3de64c5d6c81 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 10:39:07 +0100 Subject: [PATCH 57/90] resolve tests --- tests/models/test_hooks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 80c4e6415e0c5..6e7d8a9f1a15f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -513,6 +513,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, + 'loops': ANY, } expected = [ dict(name='Callback.on_init_start', args=(trainer, )), @@ -531,7 +532,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'lr_schedulers': ANY, 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, - 'state_dict': ANY + 'state_dict': ANY, + 'loops': ANY, }, ) ), dict(name='configure_sharded_model'), From 2b70bfbd802bd9044e35ab72e7570320c1defe86 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 10:47:25 +0100 Subject: [PATCH 58/90] update names --- .../connectors/checkpoint_connector.py | 8 +++---- tests/models/test_hooks.py | 22 +++++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5d8cc916a4257..12608b9177010 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -22,7 +22,7 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import ( _OMEGACONF_AVAILABLE, DeviceType, @@ -279,10 +279,10 @@ def restore_result_collections(self) -> None: # restore collection and provide sync_fn self._restore_restore_collection( - train_results, 
state_dict[RunningStage.TRAINING.value][RunningStage.TRAINING.value], sync_fn, metrics + train_results, state_dict[TrainerFn.FITTING.value][RunningStage.TRAINING.value], sync_fn, metrics ) self._restore_restore_collection( - validation_results, state_dict[RunningStage.TRAINING.value][RunningStage.VALIDATING.value], sync_fn, metrics + validation_results, state_dict[TrainerFn.FITTING.value][RunningStage.VALIDATING.value], sync_fn, metrics ) self._restore_restore_collection(validate_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) self._restore_restore_collection(test_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) @@ -456,7 +456,7 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { - RunningStage.TRAINING.value: { + TrainerFn.FITTING.value: { RunningStage.TRAINING.value: self.trainer.fit_loop.training_loop.results.state_dict(), RunningStage.VALIDATING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), }, diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6e7d8a9f1a15f..ded39924479de 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -513,7 +513,16 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': ANY, + 'loops': { + "result_collections": { + "fit": { + "train": ANY, + "validate": ANY, + }, + "validate": ANY, + "test": ANY + } + }, } expected = [ dict(name='Callback.on_init_start', args=(trainer, )), @@ -533,7 +542,16 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': ANY, + 'loops': { + "result_collections": { + "fit": { + "train": ANY, + "validate": ANY, + }, + "validate": ANY, + "test": ANY + } + }, }, ) ), dict(name='configure_sharded_model'), From 61d46bb4f04138a9c35bb2d5e1a4f3ec0e72f81a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 11:40:31 -0400 Subject: [PATCH 59/90] resolve bug --- docs/source/advanced/multi_gpu.rst | 17 ++++++++++++ pytorch_lightning/core/lightning.py | 4 +++ .../connectors/logger_connector/result.py | 27 +++++++++++++++---- .../test_checkpoint_callback_frequency.py | 5 ++-- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 1c465ae314e4f..c21e16b92244e 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -106,6 +106,23 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric # Add sync_dist=True to sync logging across all GPU workers self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True) +It is possible to perform computation manually and log the value on rank 0. + +.. testcode:: + + def test_step(self, batch, batch_idx): + x, y = batch + tensors = self(x) + return tensors + + def test_epoch_end(self, outputs): + mean = torch.mean(self.all_gather(outputs)) + + # when logging only rank 0, don't forget to add + # ``is_global_zero`` in self.log to avoid deadlock. 
+ if self.trainer.is_global_zero: + self.log("my_reduced_metric", mean, is_global_zero=True) + Make models pickleable ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 0e7e90a1fdd28..a8a130c1589a1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -280,6 +280,7 @@ def log( add_dataloader_idx: bool = True, batch_size: Optional[int] = None, metric_prefix_name: Optional[str] = None, + is_global_zero: Optional[bool] = None, ) -> None: """ Log a key, value @@ -322,6 +323,8 @@ def log( instance references on-reload. When the logged Metric are LightningModule attributes, metric_prefix_name should be None. However, when this is not, metric_prefix_name should be provided as Lightning won't be able to find your nn.Metric reference. + is_global_zero: Whether the value will be logged only on rank 0. This will prevent synchornization across processes + and avoid a deadlock. """ if tbptt_reduce_fx is not None: @@ -402,6 +405,7 @@ def log( sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, metric_prefix_name=metric_prefix_name, + is_global_zero=is_global_zero, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3f45867dab596..a750c6a120138 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -16,7 +16,8 @@ from dataclasses import asdict, dataclass, replace from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union - +from datetime import timedelta +import pytorch_lightning as pl import torch from torchmetrics import Metric @@ -47,6 +48,7 @@ class MetricSource(LightningEnum): class _Sync: fn: Optional[Callable] = None should: bool = False + is_global_zero: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -54,9 +56,13 @@ def __post_init__(self) -> None: if self.fn is None: self.fn = self.no_op + @property + def should_sync(self) -> bool: + return self.should and not self.is_global_zero + @property def __call__(self) -> Any: - return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op + return partial(self.fn, reduce_op=self.op, group=self.group) if self.should_sync else self.no_op @staticmethod def no_op(value: Any, *_, **__) -> Any: @@ -193,6 +199,7 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: def compute(self) -> torch.Tensor: if self.is_tensor: + print(self.meta.name, "sync", self.meta.sync.is_global_zero) value = self.meta.sync(self.value) if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) @@ -246,8 +253,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - with self.sync_context(): + print(self.meta.name, "__getstate__", self.meta.sync.is_global_zero) + with self.sync_context(should_sync=not self.meta.sync.is_global_zero): d = deepcopy(super().__getstate__()) + print("SYNCED") # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change. 
if drop_value and self.is_tensor: @@ -345,9 +354,10 @@ def result_metrics(self) -> List[ResultMetric]: o = [] def append_fn(v: ResultMetric) -> None: + nonlocal o o.append(v) - apply_to_collection(self.values(), ResultMetric, append_fn) + apply_to_collection(list(self.values()), ResultMetric, append_fn) return o @property @@ -416,6 +426,7 @@ def log( dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, metric_prefix_name: Optional[str] = None, + is_global_zero: bool = False, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -449,6 +460,7 @@ def log( should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, + is_global_zero=is_global_zero, ) # register logged value if it doesn't exist @@ -457,6 +469,7 @@ def log( # check the stored metadata and the current one match elif meta != self[key].meta: + import pdb; pdb.set_trace() raise MisconfigurationException( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' ) @@ -632,6 +645,10 @@ def __getstate__(self) -> dict: extra = self.get('_extra') if extra is not None: d['_extra'] = extra + + for result_metric in self.result_metrics: + result_metric.meta.name + # all the items should be either `ResultMetric`s or `ResultMetricCollection`s items = {k: v.__getstate__() for k, v in self.items() if k not in ('_extra', 'fx_validator')} return {**d, 'items': items} @@ -640,7 +657,7 @@ def __setstate__( self, state: dict, map_location: Optional[Union[str, torch.device]] = None, - sync_fn: Optional[Callable] = None + sync_fn: Optional[Callable] = None, ) -> None: self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 9fdd69dba7a9a..5b3fa15216fe1 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -107,7 +107,8 @@ def training_step(self, batch, batch_idx): @mock.patch('torch.save') @RunIf(special=True, min_gpus=2) -@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +#@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(2, 2, 0.3, 5)]) def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): class TestModel(BoringModel): @@ -120,7 +121,7 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outputs) -> None: local_rank = int(os.getenv("LOCAL_RANK")) if self.trainer.is_global_zero: - self.log('my_loss_2', (1 + local_rank), on_epoch=True) + self.log('my_loss_2', (1 + local_rank), on_epoch=True, is_global_zero=True) data = str(self.global_rank) obj = [[data], (data, ), set(data)] out = self.trainer.training_type_plugin.broadcast(obj) From 67ce6913ba4d72232c73864ea17062e8cfdd8e5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Jun 2021 15:41:46 +0000 Subject: [PATCH 60/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/advanced/multi_gpu.rst | 2 +- pytorch_lightning/core/lightning.py | 4 ++-- .../trainer/connectors/logger_connector/result.py | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git 
a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index c21e16b92244e..13bd860ca5504 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -118,7 +118,7 @@ It is possible to perform computation manually and log the value on rank 0. def test_epoch_end(self, outputs): mean = torch.mean(self.all_gather(outputs)) - # when logging only rank 0, don't forget to add + # when logging only rank 0, don't forget to add # ``is_global_zero`` in self.log to avoid deadlock. if self.trainer.is_global_zero: self.log("my_reduced_metric", mean, is_global_zero=True) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a8a130c1589a1..305cc91cc19bc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -323,8 +323,8 @@ def log( instance references on-reload. When the logged Metric are LightningModule attributes, metric_prefix_name should be None. However, when this is not, metric_prefix_name should be provided as Lightning won't be able to find your nn.Metric reference. - is_global_zero: Whether the value will be logged only on rank 0. This will prevent synchornization across processes - and avoid a deadlock. + is_global_zero: Whether the value will be logged only on rank 0. This will prevent synchornization across processes + and avoid a deadlock. """ if tbptt_reduce_fx is not None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index a750c6a120138..7288448e40849 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,13 +14,14 @@ from collections.abc import Generator from copy import deepcopy from dataclasses import asdict, dataclass, replace +from datetime import timedelta from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union -from datetime import timedelta -import pytorch_lightning as pl + import torch from torchmetrics import Metric +import pytorch_lightning as pl from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections @@ -469,7 +470,8 @@ def log( # check the stored metadata and the current one match elif meta != self[key].meta: - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() raise MisconfigurationException( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' ) From 87a9d677e73f795c110c0c230f61f16bd25139de Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 11:44:48 -0400 Subject: [PATCH 61/90] doc update --- docs/source/advanced/multi_gpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index c21e16b92244e..662b7b0d3057f 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -106,7 +106,7 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric # Add sync_dist=True to sync logging across all GPU workers self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True) -It is possible to perform computation manually and log the value on rank 0. 
+It is possible to perform some computation manually and log the reduced result on rank 0 as follow: .. testcode:: From 05332007af2f6a98b0dc093943a3d7d72ee607b4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Jun 2021 16:47:04 +0100 Subject: [PATCH 62/90] update flake8 --- pytorch_lightning/core/lightning.py | 4 ++-- .../trainer/connectors/logger_connector/result.py | 2 -- tests/checkpointing/test_checkpoint_callback_frequency.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 305cc91cc19bc..edd25eaddfc6b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -323,8 +323,8 @@ def log( instance references on-reload. When the logged Metric are LightningModule attributes, metric_prefix_name should be None. However, when this is not, metric_prefix_name should be provided as Lightning won't be able to find your nn.Metric reference. - is_global_zero: Whether the value will be logged only on rank 0. This will prevent synchornization across processes - and avoid a deadlock. + is_global_zero: Whether the value will be logged only on rank 0. This will prevent + synchronization across processes and avoid a deadlock. """ if tbptt_reduce_fx is not None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7288448e40849..8fe323366ef65 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -14,14 +14,12 @@ from collections.abc import Generator from copy import deepcopy from dataclasses import asdict, dataclass, replace -from datetime import timedelta from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import torch from torchmetrics import Metric -import pytorch_lightning as pl from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 5b3fa15216fe1..342f3d9809581 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -107,8 +107,7 @@ def training_step(self, batch, batch_idx): @mock.patch('torch.save') @RunIf(special=True, min_gpus=2) -#@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) -@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(2, 2, 0.3, 5)]) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): class TestModel(BoringModel): From 6f2b046bf81065ae08b453fe5909144c350f6713 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 11:58:24 -0400 Subject: [PATCH 63/90] remove pdb --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8fe323366ef65..b22dce37102ec 100644 --- 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -468,8 +468,6 @@ def log( # check the stored metadata and the current one match elif meta != self[key].meta: - import pdb - pdb.set_trace() raise MisconfigurationException( f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed' ) From d831d4b28893f98ef80312f954ae95c94615c21d Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 12:19:35 -0400 Subject: [PATCH 64/90] update on comments --- CHANGELOG.md | 2 +- docs/source/advanced/multi_gpu.rst | 4 ++-- pytorch_lightning/core/lightning.py | 6 +++--- .../trainer/connectors/logger_connector/result.py | 13 +++++-------- tests/checkpointing/test_model_checkpoint.py | 1 - tests/trainer/logging_/test_logger_connector.py | 1 - 6 files changed, 11 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76fa3ca297165..a9634bf4cebc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,7 +85,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) + * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` and add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734)) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 45c2f745c066a..870df6d0dc858 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -106,7 +106,7 @@ Note if you use any built in metrics or custom metrics that use the :doc:`Metric # Add sync_dist=True to sync logging across all GPU workers self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True) -It is possible to perform some computation manually and log the reduced result on rank 0 as follow: +It is possible to perform some computation manually and log the reduced result on rank 0 as follows: .. testcode:: @@ -121,7 +121,7 @@ It is possible to perform some computation manually and log the reduced result o # when logging only rank 0, don't forget to add # ``is_global_zero`` in self.log to avoid deadlock. if self.trainer.is_global_zero: - self.log("my_reduced_metric", mean, is_global_zero=True) + self.log("my_reduced_metric", mean, rank_zero_only=True) Make models pickleable diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index edd25eaddfc6b..b4a7f47f8cf27 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -280,7 +280,7 @@ def log( add_dataloader_idx: bool = True, batch_size: Optional[int] = None, metric_prefix_name: Optional[str] = None, - is_global_zero: Optional[bool] = None, + rank_zero_only: Optional[bool] = None, ) -> None: """ Log a key, value @@ -323,7 +323,7 @@ def log( instance references on-reload. When the logged Metric are LightningModule attributes, metric_prefix_name should be None. 
However, when this is not, metric_prefix_name should be provided as Lightning won't be able to find your nn.Metric reference. - is_global_zero: Whether the value will be logged only on rank 0. This will prevent + rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization across processes and avoid a deadlock. """ @@ -405,7 +405,7 @@ def log( sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, metric_prefix_name=metric_prefix_name, - is_global_zero=is_global_zero, + rank_zero_only=rank_zero_only, ) self.trainer.logger_connector._current_fx = self._current_fx_name diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index b22dce37102ec..3e3947647056b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -47,7 +47,7 @@ class MetricSource(LightningEnum): class _Sync: fn: Optional[Callable] = None should: bool = False - is_global_zero: bool = False + rank_zero_only: bool = False op: Optional[str] = None group: Optional[Any] = None @@ -57,7 +57,7 @@ def __post_init__(self) -> None: @property def should_sync(self) -> bool: - return self.should and not self.is_global_zero + return self.should and not self.rank_zero_only @property def __call__(self) -> Any: @@ -198,7 +198,6 @@ def update(self, value: _METRIC, batch_size: torch.Tensor) -> None: def compute(self) -> torch.Tensor: if self.is_tensor: - print(self.meta.name, "sync", self.meta.sync.is_global_zero) value = self.meta.sync(self.value) if self.meta.is_mean_reduction: cumulated_batch_size = self.meta.sync(self.cumulated_batch_size) @@ -252,10 +251,8 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - print(self.meta.name, "__getstate__", self.meta.sync.is_global_zero) - with self.sync_context(should_sync=not self.meta.sync.is_global_zero): + with self.sync_context(should_sync=not self.meta.sync.rank_zero_only): d = deepcopy(super().__getstate__()) - print("SYNCED") # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change. 
if drop_value and self.is_tensor: @@ -425,7 +422,7 @@ def log( dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, metric_prefix_name: Optional[str] = None, - is_global_zero: bool = False, + rank_zero_only: bool = False, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" # no metrics should be logged with graphs @@ -459,7 +456,7 @@ def log( should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, - is_global_zero=is_global_zero, + rank_zero_only=rank_zero_only, ) # register logged value if it doesn't exist diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a461ec61e4ab2..4974e0ac0f7fd 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1401,7 +1401,6 @@ def training_step(self, batch, batch_idx): def on_epoch_end(self) -> None: if self.trainer.current_epoch: - print(self.dummy_metric) total = sum(range(5)) * num_processes metrics = self.trainer.train_loop.results.metrics(on_step=False) assert self.trainer.train_loop.results['training_step.tracking'].value == total diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 0f54b7c16fbe2..b58a75dac2263 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -387,7 +387,6 @@ def _assert_called(model, stage): max_epochs=1, progress_bar_refresh_rate=0, num_sanity_val_steps=2, - checkpoint_callback=False, ) trainer.fit(model) From eec17fb4cc790334044a07ebd08e1710e710cf3e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 12:42:34 -0400 Subject: [PATCH 65/90] update --- .../trainer/connectors/logger_connector/result.py | 11 ++++------- tests/checkpointing/test_model_checkpoint.py | 3 --- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3e3947647056b..c759fb616ef6d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -255,7 +255,7 @@ def __getstate__(self, drop_value: bool = False) -> dict: d = deepcopy(super().__getstate__()) # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change. 
- if drop_value and self.is_tensor: + if drop_value and not self.is_tensor: del d["value"] d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ @@ -287,10 +287,10 @@ def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None: super().__init__(*args) self.meta = metadata - def __getstate__(self) -> dict: + def __getstate__(self, drop_value: bool = False) -> dict: def getstate(item: ResultMetric) -> dict: - return item.__getstate__(drop_value=True) + return item.__getstate__(drop_value=drop_value) items = apply_to_collection(dict(self), (ResultMetric, ResultMetricCollection), getstate) return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__} @@ -641,11 +641,8 @@ def __getstate__(self) -> dict: if extra is not None: d['_extra'] = extra - for result_metric in self.result_metrics: - result_metric.meta.name - # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__() for k, v in self.items() if k not in ('_extra', 'fx_validator')} + items = {k: v.__getstate__(drop_value=True) for k, v in self.items() if k not in ('_extra', 'fx_validator')} return {**d, 'items': items} def __setstate__( diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4974e0ac0f7fd..3fedc62d8929b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1340,9 +1340,6 @@ def update(self, increment): def compute(self): return self.sum // self.count - def __repr__(self): - return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" - def result_collection_reload(trainer_kwargs): num_processes = trainer_kwargs.get("gpus", 1) From d5db9d5b30ec8e22cec29334c8f98ae49172cc83 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 12:46:30 -0400 Subject: [PATCH 66/90] resolve test --- tests/checkpointing/test_checkpoint_callback_frequency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 342f3d9809581..0073676a77eec 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -120,7 +120,7 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outputs) -> None: local_rank = int(os.getenv("LOCAL_RANK")) if self.trainer.is_global_zero: - self.log('my_loss_2', (1 + local_rank), on_epoch=True, is_global_zero=True) + self.log('my_loss_2', (1 + local_rank), on_epoch=True, rank_zero_only=True) data = str(self.global_rank) obj = [[data], (data, ), set(data)] out = self.trainer.training_type_plugin.broadcast(obj) From 1ed5eb888cd39b6baefc2cb1128a073d33dd2e3a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 23 Jun 2021 18:50:08 +0200 Subject: [PATCH 67/90] format --- .../trainer/connectors/logger_connector/result.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c759fb616ef6d..35709062538ce 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -690,8 +690,9 @@ def load_state_dict( self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) - if metrics: - for 
metric_prefix_name, metric in metrics.items(): - for result_metric in self.result_metrics: - if result_metric.meta.metric_prefix_name == metric_prefix_name: - result_metric.value = metric + if not metrics: + return + for metric_prefix_name, metric in metrics.items(): + for result_metric in self.result_metrics: + if result_metric.meta.metric_prefix_name == metric_prefix_name: + result_metric.value = metric From f7f19922f3bd18b37e406280703adf6499b6744f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 13:01:50 -0400 Subject: [PATCH 68/90] resolve tests --- .../trainer/connectors/logger_connector/result.py | 10 +++++----- tests/core/test_metric_result_integration.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c759fb616ef6d..1496942623f61 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -255,7 +255,7 @@ def __getstate__(self, drop_value: bool = False) -> dict: d = deepcopy(super().__getstate__()) # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change. - if drop_value and not self.is_tensor: + if drop_value and not self.is_tensor and "value" in d: del d["value"] d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ @@ -630,7 +630,7 @@ def cpu(self) -> 'ResultCollection': def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' - def __getstate__(self) -> dict: + def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() d["fx_validator"] = None # can't deepcopy tensors with grad_fn @@ -642,7 +642,7 @@ def __getstate__(self) -> dict: d['_extra'] = extra # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__(drop_value=True) for k, v in self.items() if k not in ('_extra', 'fx_validator')} + items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items() if k not in ('_extra', 'fx_validator')} return {**d, 'items': items} def __setstate__( @@ -675,8 +675,8 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: device = map_location or self.device self.to(device) - def state_dict(self) -> dict: - return self.__getstate__() + def state_dict(self, drop_value: bool = True) -> dict: + return self.__getstate__(drop_value) def load_state_dict( self, diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 8785c70c28f8f..4b6aafb0a63d5 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -296,7 +296,7 @@ def validation_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint) -> None: results = self.trainer._results - state_dict = results.state_dict() + state_dict = results.state_dict(drop_value=False) # check device assert results['validation_step.v'].value.device.type == device From a722c61ec7863d8f44b33cf11dc5b43b15fcf8e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Jun 2021 17:03:16 +0000 Subject: [PATCH 69/90] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/result.py | 5 ++++- 1 file changed, 4 
insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 1fa191e504b72..8105ee3f25caf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -642,7 +642,10 @@ def __getstate__(self, drop_value: bool = True) -> dict: d['_extra'] = extra # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items() if k not in ('_extra', 'fx_validator')} + items = { + k: v.__getstate__(drop_value=drop_value) + for k, v in self.items() if k not in ('_extra', 'fx_validator') + } return {**d, 'items': items} def __setstate__( From 45aa7935ad63b6a4fe73b6d425a14e397a561c80 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 13:06:42 -0400 Subject: [PATCH 70/90] update --- tests/core/test_metric_result_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 4b6aafb0a63d5..ebefb46603a22 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -296,6 +296,7 @@ def validation_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint) -> None: results = self.trainer._results + # simplify logic state_dict = results.state_dict(drop_value=False) # check device From ac8df1ca53aabac6022bb7571c9e4464aab55005 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 13:32:32 -0400 Subject: [PATCH 71/90] update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c24ea2b7c937b..e6b373036675e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' 
-torchmetrics>=0.4.0rc0 +torchmetrics>=0.4.0rc1 pyDeprecate==0.3.1 packaging>=17.0 typing-extensions # TypedDict support for python<3.8 From 4c9c0c1ae554c8ee4bd24f5f771abbc381c2cf32 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 13:56:03 -0400 Subject: [PATCH 72/90] update --- tests/trainer/logging_/test_logger_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b58a75dac2263..391ddb43ab5b4 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -309,7 +309,7 @@ def _step(self, stage, batch): logits = self.forward(batch) loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss, metric_prefix_name="dummy") + self.log(f"loss/{stage}", loss, metric_prefix_name="dummy", rank_zero_only=True) acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy") - self.log(f"{stage}/ap", ap, metric_prefix_name="dummy") + self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy", rank_zero_only=True) + self.log(f"{stage}/ap", ap, metric_prefix_name="dummy", rank_zero_only=True) return loss From 66ad312e0efb913bd1e3c804a8e73c0b43f2bf82 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Wed, 23 Jun 2021 14:10:20 -0400 Subject: [PATCH 73/90] update --- tests/trainer/logging_/test_logger_connector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 391ddb43ab5b4..0f54b7c16fbe2 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -309,7 +309,7 @@ def _step(self, stage, batch): logits = self.forward(batch) loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss, metric_prefix_name="dummy", rank_zero_only=True) + self.log(f"loss/{stage}", loss, metric_prefix_name="dummy") acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy", rank_zero_only=True) - self.log(f"{stage}/ap", ap, metric_prefix_name="dummy", rank_zero_only=True) + self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy") + self.log(f"{stage}/ap", ap, metric_prefix_name="dummy") return loss @@ -387,6 +387,7 @@ def _assert_called(model, stage): max_epochs=1, progress_bar_refresh_rate=0, num_sanity_val_steps=2, + checkpoint_callback=False, ) trainer.fit(model) From ced63b3c9a94d7331cc9ce6901203578164f9e94 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 25 Jun 2021 09:48:15 +0100 Subject: [PATCH 74/90] resolve conflicts --- .../trainer/connectors/checkpoint_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 12608b9177010..7310f027d137f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ 
b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -266,8 +266,8 @@ def restore_result_collections(self) -> None: sync_fn = self.trainer.training_type_plugin.reduce # get current result collections - train_results = self.trainer.fit_loop.training_loop.results - validation_results = self.trainer.fit_loop.validation_loop.results + train_results = self.trainer.fit_loop.epoch_loop.results + validation_results = self.trainer.fit_loop.val_loop.results validate_results = self.trainer.validation_loop.results test_results = self.trainer.test_loop.results @@ -457,8 +457,8 @@ def get_loops_state_dict(self): def get_result_collections_state_dict(self): return { TrainerFn.FITTING.value: { - RunningStage.TRAINING.value: self.trainer.fit_loop.training_loop.results.state_dict(), - RunningStage.VALIDATING.value: self.trainer.fit_loop.validation_loop.results.state_dict(), + RunningStage.TRAINING.value: self.trainer.fit_loop.epoch_loop.results.state_dict(), + RunningStage.VALIDATING.value: self.trainer.fit_loop.val_loop.results.state_dict(), }, RunningStage.VALIDATING.value: self.trainer.validation_loop.results.state_dict(), RunningStage.TESTING.value: self.trainer.evaluation_loop.results.state_dict(), From 6c503329be6912972b4b8632252dcde7f1c5d72e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 16:48:25 +0200 Subject: [PATCH 75/90] Update CHANGELOG --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e819ef2b82be4..40a7cf54676b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,7 +85,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) - * Add `result_collections` to checkpoint and `restore_result_collections` to `CheckpointConnector` and add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) + * Checkpoint the loop results ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) + + +- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) + - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734)) From 766ef710820355a633a9eadf004fc45f98236d42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 16:48:32 +0200 Subject: [PATCH 76/90] Docs --- docs/source/advanced/multi_gpu.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 870df6d0dc858..699be201f95b8 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -118,8 +118,8 @@ It is possible to perform some computation manually and log the reduced result o def test_epoch_end(self, outputs): mean = torch.mean(self.all_gather(outputs)) - # when logging only rank 0, don't forget to add - # ``is_global_zero`` in self.log to avoid deadlock. + # When logging only on rank 0, don't forget to add + # ``rank_zero_only=True`` to avoid deadlocks on synchronization. 
if self.trainer.is_global_zero: self.log("my_reduced_metric", mean, rank_zero_only=True) From d8980166edd4c2081e832ed2cf8090cb2cb05327 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 17:40:40 +0200 Subject: [PATCH 77/90] Rename metric prefix name --- pytorch_lightning/core/lightning.py | 48 +++++++++++-------- .../connectors/logger_connector/result.py | 10 ++-- tests/core/test_metric_result_integration.py | 6 +-- .../trainer/logging_/test_logger_connector.py | 6 +-- 4 files changed, 38 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index be3627445ed2f..dc41044d041ba 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -112,7 +112,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._map_id_to_metrics_name: Optional[Dict[int, str]] = None + self._metric_attributes: Optional[Dict[int, str]] = None def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -219,7 +219,7 @@ def logger(self): def state_dict(self, *args, **kwargs) -> Dict[str, Any]: # drop the map id to metrics to avoid saving it. - self._map_id_to_metrics_name = None + self._metric_attributes = None return super().state_dict(*args, **kwargs) def _apply_batch_transfer_handler( @@ -279,7 +279,7 @@ def log( sync_dist_group: Optional[Any] = None, add_dataloader_idx: bool = True, batch_size: Optional[int] = None, - metric_prefix_name: Optional[str] = None, + metric_attribute: Optional[str] = None, rank_zero_only: Optional[bool] = None, ) -> None: """ @@ -318,14 +318,10 @@ def log( each dataloader to not mix values batch_size: Current batch_size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it. - metric_prefix_name: To enable ``Fault Tolerant Logging``, Lightning requires - a way to restore TorchMetric Metric - instance references on-reload. When the logged Metric are LightningModule attributes, - metric_prefix_name should be None. However, when this is not, metric_prefix_name should be provided as - Lightning won't be able to find your nn.Metric reference. - rank_zero_only: Whether the value will be logged only on rank 0. This will prevent - synchronization across processes and avoid a deadlock. - + metric_attribute: To restore the metric state, Lightning requires the reference of the + :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute. + rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which + would produce a deadlock as not all processes would perform this log call. """ if tbptt_reduce_fx is not None: rank_zero_deprecation( @@ -378,16 +374,26 @@ def log( # reset any tensors for the new hook name results.reset(metrics=False, fx=self._current_fx_name) - if metric_prefix_name is None and isinstance(value, Metric): - # this is used to efficiently find the attribute prefix path of metric objects - # this will enable Lightning to re-attach metric reference when reloading states. 
- if self._map_id_to_metrics_name is None: - self._map_id_to_metrics_name = { - id(module): module_name - for module_name, module in self._named_members(lambda module: module._modules.items()) - if isinstance(module, Metric) + if metric_attribute is None and isinstance(value, Metric): + if self._metric_attributes is None: + # compute once + self._metric_attributes = { + id(module): name + for name, module in self.named_children() if isinstance(module, Metric) } - metric_prefix_name = self._map_id_to_metrics_name[id(value)] + if not self._metric_attributes: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + " You can fix this by setting an attribute for the metric in your `LightningModule`." + ) + # try to find the passed metric in the LightningModule + metric_attribute = self._metric_attributes.get(id(value)) + if metric_attribute is None: + raise MisconfigurationException( + "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged." + f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one" + f" of {list(self._metric_attributes.values())}" + ) results.log( self._current_fx_name, @@ -404,7 +410,7 @@ def log( sync_dist=sync_dist, sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, sync_dist_group=sync_dist_group, - metric_prefix_name=metric_prefix_name, + metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, ) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 8105ee3f25caf..7594a05c5930c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -79,7 +79,7 @@ class _Metadata: _reduce_fx: Callable = torch.mean enable_graph: bool = False dataloader_idx: Optional[int] = None - metric_prefix_name: Optional[str] = None + metric_attribute: Optional[str] = None _sync: Optional[_Sync] = None @property @@ -421,7 +421,7 @@ def log( sync_dist_group: Optional[Any] = None, dataloader_idx: Optional[int] = None, batch_size: Optional[int] = None, - metric_prefix_name: Optional[str] = None, + metric_attribute: Optional[str] = None, rank_zero_only: bool = False, ) -> None: """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" @@ -449,7 +449,7 @@ def log( on_epoch=on_epoch, enable_graph=enable_graph, dataloader_idx=dataloader_idx, - metric_prefix_name=metric_prefix_name, + metric_attribute=metric_attribute, ) meta.reduce_fx = reduce_fx meta.sync = _Sync( @@ -695,7 +695,7 @@ def load_state_dict( if not metrics: return - for metric_prefix_name, metric in metrics.items(): + for metric_attribute, metric in metrics.items(): for result_metric in self.result_metrics: - if result_metric.meta.metric_prefix_name == metric_prefix_name: + if result_metric.meta.metric_attribute == metric_attribute: result_metric.value = metric diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ebefb46603a22..989ea58efddb0 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -222,9 +222,9 @@ def lightning_log(fx, *args, **kwargs): cumulative_sum += i metric = metric_a if i < 1 else metric_d - lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, metric_prefix_name="metric") - lightning_log('training_step', 'b', 
metric_b, on_step=False, on_epoch=True, metric_prefix_name="metric_b") - lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_prefix_name="metric_c") + lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True, metric_attribute="metric") + lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b") + lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c") lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True) lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True) lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 0f54b7c16fbe2..4be96d857a56f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -309,7 +309,7 @@ def _step(self, stage, batch): logits = self.forward(batch) loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss, metric_prefix_name="dummy") + self.log(f"loss/{stage}", loss, metric_attribute="dummy") acc = self._modules[f"acc_{stage}"] ap = self._modules[f"ap_{stage}"] @@ -322,8 +322,8 @@ def _step(self, stage, batch): acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_prefix_name="dummy") - self.log(f"{stage}/ap", ap, metric_prefix_name="dummy") + self.log(f"{stage}/accuracy", acc, metric_attribute="dummy") + self.log(f"{stage}/ap", ap, metric_attribute="dummy") return loss From 209298c43abe5f0a2fdfbe18a2b7ca0fae271449 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 18:32:07 +0200 Subject: [PATCH 78/90] Refactor metric reset test --- pytorch_lightning/trainer/trainer.py | 2 + .../trainer/logging_/test_logger_connector.py | 102 ++++++++---------- 2 files changed, 49 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c5ee90cd126ce..9d581dc560f23 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1126,6 +1126,8 @@ def _call_teardown_hook(self, model: LightningModule) -> None: model._current_fx_name = None model._current_dataloader_idx = None + # these could have become stale if metrics are defined in `setup` + model._metric_attributes = None def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 4be96d857a56f..592fde1569344 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -23,7 +23,6 @@ from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection -from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -293,48 +292,63 @@ def __init__(self): super().__init__() 
self.layer = torch.nn.Linear(32, 1) - for stage in ['train', 'val', 'test']: - acc = Accuracy() - acc.reset = mock.Mock(side_effect=acc.reset) - ap = AveragePrecision(num_classes=1, pos_label=1) - ap.reset = mock.Mock(side_effect=ap.reset) - self.add_module(f"acc_{stage}", acc) - self.add_module(f"ap_{stage}", ap) + def _create_metrics(self): + acc = Accuracy() + acc.reset = mock.Mock(side_effect=acc.reset) + ap = AveragePrecision(num_classes=1, pos_label=1) + ap.reset = mock.Mock(side_effect=ap.reset) + return acc, ap + + def setup(self, stage): + fn = stage + if fn == 'fit': + for stage in ('train', 'validate'): + acc, ap = self._create_metrics() + self.add_module(f"acc_{fn}_{stage}", acc) + self.add_module(f"ap_{fn}_{stage}", ap) + else: + acc, ap = self._create_metrics() + stage = self.trainer.state.stage + self.add_module(f"acc_{fn}_{stage}", acc) + self.add_module(f"ap_{fn}_{stage}", ap) def forward(self, x): return self.layer(x) - def _step(self, stage, batch): - labels = (batch.detach().sum(1) > 0).float() # Fake some targets - logits = self.forward(batch) - loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) - probs = torch.sigmoid(logits.detach()) - self.log(f"loss/{stage}", loss, metric_attribute="dummy") + def _step(self, batch): + fn, stage = self.trainer.state.fn, self.trainer.state.stage + + logits = self(batch) + loss = logits.sum() + self.log(f"loss/{fn}_{stage}", loss) - acc = self._modules[f"acc_{stage}"] - ap = self._modules[f"ap_{stage}"] + acc = self._modules[f"acc_{fn}_{stage}"] + ap = self._modules[f"ap_{fn}_{stage}"] - labels_int = labels.to(torch.long) - acc(probs.flatten(), labels_int) - ap(probs.flatten(), labels_int) + preds = torch.rand(len(batch)) # Fake preds + labels = torch.randint(0, 1, [len(batch)]) # Fake targets + acc(preds, labels) + ap(preds, labels) # Metric.forward calls reset so reset the mocks here acc.reset.reset_mock() ap.reset.reset_mock() - self.log(f"{stage}/accuracy", acc, metric_attribute="dummy") - self.log(f"{stage}/ap", ap, metric_attribute="dummy") + self.log(f"acc/{fn}_{stage}", acc) + self.log(f"ap/{fn}_{stage}", ap) return loss def training_step(self, batch, batch_idx, *args, **kwargs): - return self._step('train', batch) + return self._step(batch) def validation_step(self, batch, batch_idx, *args, **kwargs): - return self._step('val', batch) + if self.trainer.sanity_checking: + return + return self._step(batch) def test_step(self, batch, batch_idx, *args, **kwargs): - return self._step('test', batch) + return self._step(batch) def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -350,33 +364,11 @@ def val_dataloader(self): def test_dataloader(self): return DataLoader(RandomDataset(32, 64)) - def _assert_epoch_end(self, stage): - acc = self._modules[f"acc_{stage}"] - ap = self._modules[f"ap_{stage}"] - - acc.reset.assert_called_once() - ap.reset.assert_called_once() - - def teardown(self, stage): - if stage == TrainerFn.FITTING: - self._assert_epoch_end('train') - self._assert_epoch_end('val') - - elif stage == TrainerFn.VALIDATING: - self._assert_epoch_end('val') - - elif stage == TrainerFn.TESTING: - self._assert_epoch_end('test') - - def _assert_called(model, stage): - acc = model._modules[f"acc_{stage}"] - ap = model._modules[f"ap_{stage}"] - - assert acc.reset.call_count == 1 - acc.reset.reset_mock() - - assert ap.reset.call_count == 1 - ap.reset.reset_mock() + def _assert_called(model, fn, stage): + acc = model._modules[f"acc_{fn}_{stage}"] + ap = 
model._modules[f"ap_{fn}_{stage}"] + acc.reset.assert_called_once() + ap.reset.assert_called_once() model = TestModel() trainer = Trainer( @@ -391,14 +383,14 @@ def _assert_called(model, stage): ) trainer.fit(model) - _assert_called(model, 'train') - _assert_called(model, 'val') + _assert_called(model, 'fit', 'train') + _assert_called(model, 'fit', 'validate') trainer.validate(model) - _assert_called(model, 'val') + _assert_called(model, 'validate', 'validate') trainer.test(model) - _assert_called(model, 'test') + _assert_called(model, 'test', 'test') def test_result_collection_on_tensor_with_mean_reduction(): From 08f3c79cfc337092a0a59486ceee0fadc8dcaad9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 18:34:07 +0200 Subject: [PATCH 79/90] Typos --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index de29c475b3e99..b71fc10609cdc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -276,7 +276,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - # requires to compute the state_dict on all processes in case Metric are presents + # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() if self.global_rank == 0 and self.mp_queue is not None: diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 84c1eed6e0793..68e189f6f60cd 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -185,7 +185,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): checkpoint_callback = self.lightning_module.trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - # requires to compute the state_dict on all processes in case Metric are presents + # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() if self.mp_queue is not None: From 358cbd3c9e054d779e572abae99a3046511ed71e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 18:35:21 +0200 Subject: [PATCH 80/90] No need for should sync property --- .../trainer/connectors/logger_connector/result.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7594a05c5930c..c2ee0da8eeb41 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -55,13 +55,12 @@ def __post_init__(self) -> None: if self.fn is None: self.fn = self.no_op - @property - def should_sync(self) -> bool: - return self.should and not self.rank_zero_only - @property def __call__(self) -> Any: - return partial(self.fn, reduce_op=self.op, group=self.group) if self.should_sync else self.no_op + return ( + partial(self.fn, reduce_op=self.op, group=self.group) + if self.should and not 
self.rank_zero_only else self.no_op + ) @staticmethod def no_op(value: Any, *_, **__) -> Any: From 14d6f413c709c4a9349ced85665c983e9e3fc186 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 18:48:13 +0200 Subject: [PATCH 81/90] Decouple distributed_available --- pytorch_lightning/core/lightning.py | 6 +++--- .../trainer/connectors/logger_connector/result.py | 7 ++++++- pytorch_lightning/utilities/distributed.py | 8 ++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dc41044d041ba..41f9cc34d084b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -43,7 +43,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.utilities.distributed import sync_ddp_if_available +from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -407,8 +407,8 @@ def log( enable_graph=enable_graph, dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None), batch_size=batch_size, - sync_dist=sync_dist, - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp_if_available, + sync_dist=sync_dist and distributed_available(), + sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c2ee0da8eeb41..531d13bb15b9a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -250,7 +251,11 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - with self.sync_context(should_sync=not self.meta.sync.rank_zero_only): + with self.sync_context( + should_sync=not self.meta.sync.rank_zero_only, + process_group=self.meta.sync.group, + distributed_available=distributed_available + ): d = deepcopy(super().__getstate__()) # metric are being dropped, so they won't be serialized # this would prevent pickling error if their API change.
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ae977bd03bac8..5094f55ba59f8 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -135,6 +135,10 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) return gathered_result +def distributed_available() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() + + def sync_ddp_if_available( result: Union[torch.Tensor], group: Optional[Any] = None, @@ -151,7 +155,7 @@ def sync_ddp_if_available( Return: reduced value """ - if torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed(): + if distributed_available(): return sync_ddp(result, group=group, reduce_op=reduce_op) return result @@ -230,7 +234,7 @@ def all_gather_ddp_if_available( A tensor of shape (world_size, batch, ...) """ group = group if group is not None else torch.distributed.group.WORLD - if torch.distributed.is_available() and torch.distributed.is_initialized(): + if distributed_available(): if sync_grads: return AllGatherGrad.apply(tensor, group) else: From 9471db60fe18dbafa8b6fd7b6403af58f9a5b1c4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 18:59:04 +0200 Subject: [PATCH 82/90] Avoid deepcopy and dropping value --- .../trainer/connectors/logger_connector/result.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 531d13bb15b9a..41d820f517b45 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Generator -from copy import deepcopy from dataclasses import asdict, dataclass, replace from functools import partial, wraps from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union @@ -251,16 +250,16 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: + skip = ['update', 'compute', '_update_signature'] + if not self.is_tensor and drop_value: + # Avoid serializing ResultMetrics which are passed Metrics + skip.append('value') with self.sync_context( should_sync=not self.meta.sync.rank_zero_only, process_group=self.meta.sync.group, distributed_available=distributed_available ): - d = deepcopy(super().__getstate__()) - # metric are being dropped, so they won't be serialized - # this would prevent pickling error if their API change. 
- if drop_value and not self.is_tensor and "value" in d: - del d["value"] + d = {k: v for k, v in self.__dict__.items() if k not in skip} d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ return d From 032c1f8376abc99c86d080184cd552459da1fca5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 19:05:07 +0200 Subject: [PATCH 83/90] Remove fx validator in getstate --- .../trainer/connectors/logger_connector/result.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 41d820f517b45..6c443b1bf9a54 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -634,12 +634,13 @@ def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' def __getstate__(self, drop_value: bool = True) -> dict: - d = self.__dict__.copy() - d["fx_validator"] = None + d = {k: v for k, v in self.__dict__.items() if k != 'fx_validator'} + # can't deepcopy tensors with grad_fn minimize = d['_minimize'] if minimize is not None: d['_minimize'] = minimize.detach() + extra = self.get('_extra') if extra is not None: d['_extra'] = extra @@ -657,7 +658,6 @@ def __setstate__( map_location: Optional[Union[str, torch.device]] = None, sync_fn: Optional[Callable] = None, ) -> None: - self.__dict__.update({k: v for k, v in state.items() if k != 'items'}) def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: @@ -676,8 +676,6 @@ def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: items = {k: setstate(k, v) for k, v in state['items'].items()} self.update(items) - self.fx_validator = FxValidator() - device = map_location or self.device self.to(device) @@ -691,9 +689,6 @@ def load_state_dict( sync_fn: Optional[Callable] = None, metrics: Optional[Dict[str, Metric]] = None, ) -> None: - - self.fx_validator = FxValidator() - self.__setstate__(state_dict, map_location=map_location, sync_fn=sync_fn) if not metrics: From ee379cf308f79f778d28da1362d957e77ea71045 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 19:06:16 +0200 Subject: [PATCH 84/90] fx_validator shouldn't be in self.items() --- .../trainer/connectors/logger_connector/result.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 6c443b1bf9a54..c9c9afebc940b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -646,10 +646,7 @@ def __getstate__(self, drop_value: bool = True) -> dict: d['_extra'] = extra # all the items should be either `ResultMetric`s or `ResultMetricCollection`s - items = { - k: v.__getstate__(drop_value=drop_value) - for k, v in self.items() if k not in ('_extra', 'fx_validator') - } + items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items() if k != '_extra'} return {**d, 'items': items} def __setstate__( From 710819050c616c5e9443d3a0479f1642294ba947 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 19:08:31 +0200 Subject: [PATCH 85/90] Add reduce comment --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index c9c9afebc940b..52d6c937bf33c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -505,6 +505,7 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten cache = result_metric._forward_cache elif not on_step and result_metric.meta.on_epoch: if not result_metric._computed: + # always reduce on epoch end should = result_metric.meta.sync.should result_metric.meta.sync.should = True result_metric.compute() From eebd4d4fcefb657769cf1e2561e95d0f27e43af3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 19:29:17 +0200 Subject: [PATCH 86/90] Improve result metrics property --- .../trainer/connectors/logger_connector/result.py | 5 +++-- tests/core/test_metric_result_integration.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 52d6c937bf33c..3dd12d4316ef7 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -244,7 +244,7 @@ def __setattr__(self, key: str, value: Any) -> None: object.__setattr__(self, key, value) def __repr__(self) -> str: - state = f"value={self.value}" + state = f"{repr(self.meta.name)}, value={self.value}" if self.is_tensor and self.meta.is_mean_reduction: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" @@ -691,7 +691,8 @@ def load_state_dict( if not metrics: return + result_metrics = self.result_metrics for metric_attribute, metric in metrics.items(): - for result_metric in self.result_metrics: + for result_metric in result_metrics: if result_metric.meta.metric_attribute == metric_attribute: result_metric.value = metric diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 989ea58efddb0..7471914886a27 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -135,9 +135,9 @@ def test_result_metric_integration(): assert str(result) == ( "ResultCollection(True, cpu, {" - "'h.a': ResultMetric(value=DummyMetric()), " - "'h.b': ResultMetric(value=DummyMetric()), " - "'h.c': ResultMetric(value=DummyMetric())" + "'h.a': ResultMetric('a', value=DummyMetric()), " + "'h.b': ResultMetric('b', value=DummyMetric()), " + "'h.c': ResultMetric('c', value=DummyMetric())" "})" ) @@ -208,7 +208,7 @@ def lightning_log(fx, *args, **kwargs): result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist) current_fx_name = fx - for _ in range(2): + for epoch in range(2): cumulative_sum = 0 @@ -238,6 +238,10 @@ def lightning_log(fx, *args, **kwargs): state_dict = result.state_dict() # check the sync fn was dropped assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync'] + + assert not new_result.result_metrics + assert len(result.result_metrics) == 7 + epoch > 0 + new_result.load_state_dict( state_dict, metrics={ "metric": metric, From e7bef8c7182fb6312bf54093913f2ba65d47c54c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 19:37:34 +0200 Subject: [PATCH 87/90] State dict wouldnt save metric attributes --- pytorch_lightning/core/lightning.py | 5 ----- 1 file 
changed, 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 41f9cc34d084b..10716c3c3d6f1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -217,11 +217,6 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None - def state_dict(self, *args, **kwargs) -> Dict[str, Any]: - # drop the map id to metrics to avoid saving it. - self._metric_attributes = None - return super().state_dict(*args, **kwargs) - def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: From 4e49c983cd9d1da3540fe8f01518178e98f22808 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 20:14:34 +0200 Subject: [PATCH 88/90] Resolve circular imports to import the fx_validator --- pytorch_lightning/core/lightning.py | 3 ++- pytorch_lightning/loggers/base.py | 6 ++--- pytorch_lightning/loggers/comet.py | 4 +-- pytorch_lightning/loggers/tensorboard.py | 4 +-- pytorch_lightning/loggers/test_tube.py | 4 +-- pytorch_lightning/overrides/fairscale.py | 4 +-- .../plugins/precision/apex_amp.py | 7 +++-- pytorch_lightning/plugins/precision/double.py | 6 ++--- .../plugins/training_type/ipu.py | 6 ++--- .../plugins/training_type/sharded.py | 4 +-- .../plugins/training_type/sharded_spawn.py | 4 +-- pytorch_lightning/trainer/callback_hook.py | 10 +++---- .../trainer/connectors/callback_connector.py | 4 +-- .../connectors/checkpoint_connector.py | 26 +++++++++---------- .../connectors/logger_connector/result.py | 4 +-- pytorch_lightning/trainer/data_loading.py | 10 +++---- pytorch_lightning/trainer/model_hooks.py | 6 ++--- pytorch_lightning/trainer/optimizers.py | 4 +-- pytorch_lightning/trainer/properties.py | 6 ++--- pytorch_lightning/trainer/trainer.py | 20 +++++++------- pytorch_lightning/trainer/training_tricks.py | 4 +-- pytorch_lightning/utilities/model_helpers.py | 13 +++++----- 22 files changed, 77 insertions(+), 82 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 10716c3c3d6f1..bf05b1f0772f0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -39,6 +39,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -353,7 +354,7 @@ def log( results = self.trainer._results assert results is not None assert self._current_fx_name is not None - results.fx_validator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) + FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders if "/dataloader_idx_" in name: diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 7736ed24baefe..803d08eb3e645 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -25,8 +25,8 @@ import numpy as np import torch +import pytorch_lightning as pl from 
pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only @@ -300,7 +300,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): kwargs: Optional keywoard arguments, depends on the specific logger being used """ - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: """ Record model graph @@ -396,7 +396,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: for logger in self._logger_iterable: logger.log_graph(model, input_array) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 148e512f5e439..498a16a9daa29 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -24,7 +24,7 @@ import torch from torch import is_tensor -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -318,6 +318,6 @@ def __getstate__(self): state["_experiment"] = None return state - def log_graph(self, model: LightningModule, input_array=None) -> None: + def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index b69f31ae53b32..d59830bd98ae4 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn @@ -223,7 +223,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> raise ValueError(m) from ex @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = model.example_input_array diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 1107a0bcb2c4c..1650ab8f4ba49 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -18,7 +18,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import _module_available, rank_zero_warn from pytorch_lightning.utilities.distributed import rank_zero_only @@ -153,7 +153,7 @@ def 
log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log(metrics, global_step=step) @rank_zero_only - def log_graph(self, model: LightningModule, input_array=None): + def log_graph(self, model: 'pl.LightningModule', input_array=None): if self._log_graph: if input_array is None: input_array = model.example_input_array diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index f7c3b8d5fd575..e531db6de77f3 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE @@ -23,7 +23,7 @@ class LightningShardedDataParallel(_LightningModuleWrapperBase): # Just do this for later docstrings pass - def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: + def unwrap_lightning_module_sharded(wrapped_model) -> 'pl.LightningModule': model = wrapped_model if isinstance(model, ShardedDataParallel): model = model.module diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 21253ea9ab4a0..b2565e7dd34b4 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -18,7 +18,6 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.types import _PARAMETERS @@ -50,7 +49,7 @@ def dispatch(self, trainer: 'pl.Trainer') -> None: def backward( self, - model: LightningModule, + model: 'pl.LightningModule', closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, @@ -76,7 +75,7 @@ def backward( # do backward pass # TODO: not entirely sure, why we need this - if model is not None and isinstance(model, LightningModule): + if model is not None and isinstance(model, pl.LightningModule): model.backward(closure_loss, optimizer, opt_idx, **kwargs) # TODO: avoid dev_debugger and track these calls with mock @@ -118,7 +117,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq def pre_optimizer_step( self, - pl_module: LightningModule, + pl_module: 'pl.LightningModule', optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index e0ecddf322250..86177c5500e2f 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -18,7 +18,7 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -33,7 +33,7 @@ class 
LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase): pl_module: the model to wrap """ - def __init__(self, pl_module: LightningModule): + def __init__(self, pl_module: 'pl.LightningModule'): super().__init__(pl_module) @staticmethod @@ -96,7 +96,7 @@ def connect( incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or `lr_schedulers`. """ - model = cast(LightningModule, model.to(dtype=torch.float64)) + model = cast(pl.LightningModule, model.to(dtype=torch.float64)) model = LightningDoublePrecisionModule(model) return super().connect(model, optimizers, lr_schedulers) diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 4e75358b67fae..b3a22ad1ad3b2 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -19,8 +19,8 @@ import torch from torch.utils.data import DataLoader +import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -37,7 +37,7 @@ class LightningIPUModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: LightningModule, precision: Union[str, int]): + def __init__(self, pl_module: 'pl.LightningModule', precision: Union[str, int]): super().__init__(pl_module) self.precision = precision @@ -184,7 +184,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None: opts.Training.set(gradient_accumulation=1) @property - def lightning_module(self) -> Optional[LightningModule]: + def lightning_module(self) -> Optional['pl.LightningModule']: return self.model.module if isinstance(self.model, LightningIPUModule) else self.model def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index fceafddd66ec0..7e5796d5b5668 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -16,7 +16,7 @@ import torch from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -86,7 +86,7 @@ def _optim_state_dict(self, optimizer): return optimizer.state_dict() @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPShardedPlugin` requires `fairscale` to be installed." 
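Every file touched by this patch follows the same recipe for breaking the import cycle: drop the eager `from pytorch_lightning.core.lightning import LightningModule`, import the top-level package as `pl`, and reference the class through a string annotation that is only resolved lazily. A hedged sketch of the pattern; the `unwrap` helper and its `.module` convention are hypothetical, loosely modelled on the `unwrap_lightning_module_sharded` hunk above:

    import pytorch_lightning as pl  # package namespace only; no submodule import at load time

    def unwrap(wrapped_model) -> 'pl.LightningModule':
        # the string annotation defers resolution of `pl.LightningModule`,
        # while the isinstance check resolves it at call time
        if isinstance(wrapped_model, pl.LightningModule):
            return wrapped_model
        # hypothetical wrapper convention: the original module lives under `.module`
        return getattr(wrapped_model, "module", wrapped_model)

This is also what lets `lightning.py` call `FxValidator.check_logging` on the class after importing the validator at module level: the trainer-side modules no longer import `LightningModule` eagerly, so no cycle is recreated.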
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 5daf4e5be3735..c583ac756cd0f 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -16,7 +16,7 @@ import torch from torch.optim import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -71,7 +71,7 @@ def _optim_state_dict(self, optimizer): return optimizer.state_dict() @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPSpawnShardedPlugin` requires `fairscale` to be installed." diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 1f17308df73b3..4f4e44e57d3a3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -17,8 +17,8 @@ from inspect import signature from typing import Any, Callable, Dict, List, Optional, Type +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -32,19 +32,19 @@ class TrainerCallbackHookMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class callbacks: List[Callback] = [] - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' - def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: + def on_before_accelerator_backend_setup(self, model: 'pl.LightningModule') -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def configure_sharded_model(self, model: LightningModule) -> None: + def configure_sharded_model(self, model: 'pl.LightningModule') -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.on_configure_sharded_model(self, model) - def setup(self, model: LightningModule, stage: Optional[str]) -> None: + def setup(self, model: 'pl.LightningModule', stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: callback.setup(self, model, stage=stage) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 5652a65ee6df0..75cd74b307852 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,9 +15,9 @@ from datetime import timedelta from typing import Dict, List, Optional, Union +import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback, ModelCheckpoint, 
ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -137,7 +137,7 @@ def attach_model_logging_functions(self, model): callback.log_dict = model.log_dict @staticmethod - def _attach_model_callbacks(model: LightningModule, trainer) -> None: + def _attach_model_callbacks(model: 'pl.LightningModule', trainer) -> None: """ Attaches the callbacks defined in the model. If a callback returned by the model's configure_callback method has the same type as one or several diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 7310f027d137f..1c06508f90605 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -20,8 +20,7 @@ import torch from torchmetrics import Metric -import pytorch_lightning -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import ( _OMEGACONF_AVAILABLE, @@ -174,14 +173,13 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) - # restore progress (loops etc.) + # restore progress + # FIXME self.restore_progress() + self.restore_loops() self.restore_optimizers_and_schedulers() - # restore loops - self.restore_loops() - def restore_callbacks(self) -> None: """ Restores all callbacks from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -241,7 +239,7 @@ def restore_optimizers_and_schedulers(self) -> None: self.restore_lr_schedulers() def restore_loops(self) -> None: - """ Restores the loops state_dicts""" + """ Restores the loops states """ if not self._loaded_checkpoint: return @@ -350,8 +348,8 @@ def hpc_save(self, folderpath: str, logger): try: atomic_save(checkpoint, filepath) except AttributeError as err: - if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] + if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] rank_zero_warn( 'warning, `hyper_parameters` dropped from checkpoint.' 
f' An attribute is not picklable {err}' @@ -408,7 +406,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint = { 'epoch': current_epoch, 'global_step': global_step, - 'pytorch-lightning_version': pytorch_lightning.__version__, + 'pytorch-lightning_version': pl.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), 'loops': self.get_loops_state_dict() } @@ -436,13 +434,13 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump hyper-parameters if model.hparams: if hasattr(model, '_hparams_name'): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name # dump arguments if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) else: - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) # give the model a chance to dump a few things model.on_save_checkpoint(checkpoint) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 3dd12d4316ef7..cbed7368a4372 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -19,7 +19,6 @@ import torch from torchmetrics import Metric -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -346,7 +345,6 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = self._minimize = None self._batch_size = torch.tensor(1, device=device) self.device: Optional[Union[str, torch.device]] = device - self.fx_validator = FxValidator() @property def result_metrics(self) -> List[ResultMetric]: @@ -635,7 +633,7 @@ def __str__(self) -> str: return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})' def __getstate__(self, drop_value: bool = True) -> dict: - d = {k: v for k, v in self.__dict__.items() if k != 'fx_validator'} + d = self.__dict__.copy() # can't deepcopy tensors with grad_fn minimize = d['_minimize'] diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index c9b8a6f29652b..ce6caa4e2f330 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -22,8 +22,8 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.core import LightningModule from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import 
RunningStage @@ -226,7 +226,7 @@ def _get_distributed_sampler( sampler = cls(dataloader.dataset, **kwargs) return sampler - def reset_train_dataloader(self, model: LightningModule) -> None: + def reset_train_dataloader(self, model: 'pl.LightningModule') -> None: """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.). @@ -312,7 +312,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: def _reset_eval_dataloader( self, - model: LightningModule, + model: 'pl.LightningModule', mode: str, ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. @@ -412,7 +412,7 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def reset_val_dataloader(self, model: LightningModule) -> None: + def reset_val_dataloader(self, model: 'pl.LightningModule') -> None: """Resets the validation dataloader and determines the number of batches. Args: @@ -457,7 +457,7 @@ def reset_train_val_dataloaders(self, model) -> None: if self.val_dataloaders is None: self.reset_val_dataloader(model) - def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: + def request_dataloader(self, model: 'pl.LightningModule', stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index cbf331913e597..2336379fc3d49 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -15,7 +15,7 @@ from abc import ABC from typing import Optional -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -27,9 +27,9 @@ class TrainerModelHooksMixin(ABC): Use the utilities from ``pytorch_lightning.utilities.signature_utils`` instead. """ - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' - def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool: + def is_function_implemented(self, f_name: str, model: Optional['pl.LightningModule'] = None) -> bool: rank_zero_deprecation( "Internal: TrainerModelHooksMixin.is_function_implemented is deprecated in v1.4" " and will be removed in v1.6." 
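The mixin above now only raises a deprecation; the hook checks it used to provide live in standalone utilities such as `is_overridden`, whose signature is also reworked in the `model_helpers.py` hunk later in this patch. A short, hedged usage sketch follows; `MyModel` is a hypothetical module used only for illustration:

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.model_helpers import is_overridden

    class MyModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            # overriding the parent hook is what `is_overridden` detects
            return None

    model = MyModel()
    assert is_overridden("validation_step", model)   # overridden -> True
    assert not is_overridden("predict_step", model)  # default implementation kept -> False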
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index b5afe7bf75168..80ec5857de287 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -19,7 +19,7 @@ from torch import optim from torch.optim.optimizer import Optimizer -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -29,7 +29,7 @@ class TrainerOptimizersMixin(ABC): _lightning_optimizers: Optional[List[LightningOptimizer]] - def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, model: 'pl.LightningModule') -> Tuple[List, List, List]: self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b77b1b8268b9a..d9620112479f2 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -21,11 +21,11 @@ import torch from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers.tensorboard import TensorBoardLogger @@ -146,7 +146,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: return self.accelerator_connector.parallel_device_ids @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> 'pl.LightningModule': return self.accelerator.lightning_module @property @@ -277,7 +277,7 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. 
""" ref_model = self.lightning_module - ref_model = cast(LightningModule, ref_model) + ref_model = cast(pl.LightningModule, ref_model) standard_metrics = ref_model.get_progress_bar_dict() pbar_metrics = self.progress_bar_metrics diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9d581dc560f23..4d097a2de2763 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -21,10 +21,10 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop @@ -470,7 +470,7 @@ def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamod def fit( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -526,7 +526,7 @@ def fit( def validate( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, @@ -602,7 +602,7 @@ def validate( def test( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, @@ -677,7 +677,7 @@ def test( def predict( self, - model: Optional[LightningModule] = None, + model: Optional['pl.LightningModule'] = None, dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, @@ -747,7 +747,7 @@ def predict( def tune( self, - model: LightningModule, + model: 'pl.LightningModule', train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, @@ -807,7 +807,7 @@ def tune( return result - def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + def _run(self, model: 'pl.LightningModule') -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -1090,7 +1090,7 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: self.checkpoint_connector.restore_model_weights(ckpt_path) return ckpt_path - def _call_setup_hook(self, model: LightningModule) -> None: + def _call_setup_hook(self, model: 'pl.LightningModule') -> None: fn = self.state.fn._setup_fn self.accelerator.barrier("pre_setup") @@ -1102,7 +1102,7 @@ def _call_setup_hook(self, model: LightningModule) -> None: self.accelerator.barrier("post_setup") - def _call_configure_sharded_model(self, model: LightningModule) -> None: + def _call_configure_sharded_model(self, model: 'pl.LightningModule') -> None: # Call configure sharded model hook if accelerator requests. 
In some cases # we will not call the hook; the hook has initialized the sharded model for example. @@ -1115,7 +1115,7 @@ def _call_configure_sharded_model(self, model: LightningModule) -> None: model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False - def _call_teardown_hook(self, model: LightningModule) -> None: + def _call_teardown_hook(self, model: 'pl.LightningModule') -> None: fn = self.state.fn._setup_fn if self.datamodule is not None: diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index a45c9436dbdb7..beecc5e2a764d 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -18,7 +18,7 @@ import torch from torch import Tensor -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients @@ -34,7 +34,7 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - lightning_module: LightningModule + lightning_module: 'pl.LightningModule' def print_nan_gradients(self) -> None: rank_zero_deprecation( diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index b7c3c09aff60b..e52f8efa2689f 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -15,8 +15,7 @@ from typing import Optional, Type, Union from unittest.mock import Mock -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation @@ -24,7 +23,7 @@ def is_overridden( method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None, - model: Optional[Union[LightningModule, LightningDataModule]] = None, + model: Optional[Union['pl.LightningModule', 'pl.LightningDataModule']] = None, ) -> bool: if model is not None and instance is None: rank_zero_deprecation( @@ -38,10 +37,10 @@ def is_overridden( return False if parent is None: - if isinstance(instance, LightningModule): - parent = LightningModule - elif isinstance(instance, LightningDataModule): - parent = LightningDataModule + if isinstance(instance, pl.LightningModule): + parent = pl.LightningModule + elif isinstance(instance, pl.LightningDataModule): + parent = pl.LightningDataModule if parent is None: raise ValueError("Expected a parent") From 2c08db6ddde1a413dbd829fb61ba3b1ceb361227 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 20:42:54 +0200 Subject: [PATCH 89/90] Minor changes --- tests/checkpointing/test_model_checkpoint.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3fedc62d8929b..038c53d3c7b7c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1380,6 +1380,7 @@ def training_step(self, batch, batch_idx): if self.trainer.current_epoch == 2: return if batch_idx == self.breaking_batch_idx: + # simulate failure mid epoch raise CustomException self.log("tracking", batch_idx, on_step=True, on_epoch=True) @@ -1432,25 
+1433,21 @@ def on_epoch_end(self) -> None: def test_result_collection_reload(tmpdir): - - trainer_kwargs = { + result_collection_reload({ "default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, - } - result_collection_reload(trainer_kwargs) + }) @RunIf(min_gpus=2, special=True) def test_result_collection_reload_2_gpus(tmpdir): - - trainer_kwargs = { + result_collection_reload({ "default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": "ddp", "gpus": 2, - } - result_collection_reload(trainer_kwargs) + }) From 189f0ad152cb8bb94e7c64a920ced512861fd0ec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 25 Jun 2021 20:51:23 +0200 Subject: [PATCH 90/90] Revert checkpoint changes --- .../connectors/checkpoint_connector.py | 94 +------------ tests/checkpointing/test_model_checkpoint.py | 130 ------------------ tests/models/test_hooks.py | 24 ---- 3 files changed, 6 insertions(+), 242 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1c06508f90605..f1620c10bbd45 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -18,10 +18,8 @@ from typing import Optional, Union import torch -from torchmetrics import Metric import pytorch_lightning as pl -from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import ( _OMEGACONF_AVAILABLE, DeviceType, @@ -45,7 +43,6 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = # used to validate checkpointing logic self.has_trained = False self._loaded_checkpoint = dict() - self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -173,10 +170,8 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) - # restore progress - # FIXME + # restore progress (loops etc.) 
self.restore_progress() - self.restore_loops() self.restore_optimizers_and_schedulers() @@ -238,58 +233,6 @@ def restore_optimizers_and_schedulers(self) -> None: self.restore_optimizers() self.restore_lr_schedulers() - def restore_loops(self) -> None: - """ Restores the loops states """ - if not self._loaded_checkpoint: - return - - self.restore_result_collections() - - def restore_result_collections(self) -> None: - """ Restores the loop result collections used for logging.""" - if not self._loaded_checkpoint: - return - - loops = self._loaded_checkpoint.get("loops", None) - - if not loops: - return - - state_dict = loops.get('result_collections', None) - - if not state_dict: - return - - # get current reduce function - sync_fn = self.trainer.training_type_plugin.reduce - - # get current result collections - train_results = self.trainer.fit_loop.epoch_loop.results - validation_results = self.trainer.fit_loop.val_loop.results - validate_results = self.trainer.validation_loop.results - test_results = self.trainer.test_loop.results - - metrics = {} - model_ref = self.trainer.lightning_module - for module_name, module in model_ref._named_members(lambda module: module._modules.items()): - if isinstance(module, Metric): - metrics[module_name] = module - - # restore collection and provide sync_fn - self._restore_restore_collection( - train_results, state_dict[TrainerFn.FITTING.value][RunningStage.TRAINING.value], sync_fn, metrics - ) - self._restore_restore_collection( - validation_results, state_dict[TrainerFn.FITTING.value][RunningStage.VALIDATING.value], sync_fn, metrics - ) - self._restore_restore_collection(validate_results, state_dict[RunningStage.VALIDATING.value], sync_fn, metrics) - self._restore_restore_collection(test_results, state_dict[RunningStage.TESTING.value], sync_fn, metrics) - - def _restore_restore_collection(self, results, state_dict, sync_fn, metrics): - results.load_state_dict(state_dict, sync_fn=sync_fn, metrics=metrics) - if not self.trainer.is_global_zero: - results.reset() - def restore_optimizers(self) -> None: """ Restores the optimizer states from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -367,16 +310,11 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'epoch': training epoch 'global_step': training global step 'pytorch-lightning_version': PyTorch Lightning's version - 'callbacks': "callback specific state"[] # if not weights_only - 'optimizer_states': "PT optim's state_dict"[] # if not weights_only - 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'result_collections': { - "train": PT TrainLoop ResultCollection state_dict - "validation": PT ValidationLoop ResultCollection state_dict - "test": PT TestLoop ResultCollection state_dict - } - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp + 'callbacks': "callback specific state"[] # if not weights_only + 'optimizer_states': "PT optim's state_dict"[] # if not weights_only + 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only + 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp + 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp 'state_dict': Model's state_dict (e.g. 
network weights) CHECKPOINT_HYPER_PARAMS_NAME: CHECKPOINT_HYPER_PARAMS_KEY: @@ -397,18 +335,11 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model = self.trainer.lightning_module - if not self._persistent_metrics: - for _, module in model.named_modules(): - if isinstance(module, Metric): - module.persistent(True) - self._persistent_metrics = True - checkpoint = { 'epoch': current_epoch, 'global_step': global_step, 'pytorch-lightning_version': pl.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), - 'loops': self.get_loops_state_dict() } if not weights_only: @@ -449,19 +380,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def get_loops_state_dict(self): - return {"result_collections": self.get_result_collections_state_dict()} - - def get_result_collections_state_dict(self): - return { - TrainerFn.FITTING.value: { - RunningStage.TRAINING.value: self.trainer.fit_loop.epoch_loop.results.state_dict(), - RunningStage.VALIDATING.value: self.trainer.fit_loop.val_loop.results.state_dict(), - }, - RunningStage.VALIDATING.value: self.trainer.validation_loop.results.state_dict(), - RunningStage.TESTING.value: self.trainer.evaluation_loop.results.state_dict(), - } - def hpc_load(self, checkpoint_path: str) -> None: """ Attempts to restore the full training and model state from a HPC checkpoint file. diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 038c53d3c7b7c..2a5a1d0d26d37 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -18,7 +18,6 @@ import re import time from argparse import Namespace -from contextlib import suppress from datetime import timedelta from logging import INFO from pathlib import Path @@ -32,14 +31,12 @@ import yaml from omegaconf import Container, OmegaConf from torch import optim -from torchmetrics import Metric import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -1324,130 +1321,3 @@ def test_trainer_checkpoint_callback_bool(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir) with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"): Trainer(checkpoint_callback=mc) - - -class DummyMetric(Metric): - - def __init__(self): - super().__init__() - self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) - self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) - - def update(self, increment): - self.sum += increment - self.count += 1 - - def compute(self): - return self.sum // self.count - - -def result_collection_reload(trainer_kwargs): - num_processes = trainer_kwargs.get("gpus", 1) - - class CustomException(Exception): - pass - - class ExtendedBoringModel(BoringModel): - - def __init__(self): - super().__init__() - self.has_reloaded = False - self.breaking_batch_idx = 3 - self.has_validated_sum = False - self.dummy_metric = DummyMetric() - self.dummy_metric_dynamic = DummyMetric() - - def training_step(self, batch, batch_idx): - assert len(batch) == 1 - if self.has_reloaded: 
- if batch_idx >= self.breaking_batch_idx: - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.train_loop.results['training_step.tracking'].value - shift = 0 - if num_processes == 2: - shift = 3 if self.trainer.is_global_zero else -3 - expected = sum(range(batch_idx + 1)) + shift - assert expected == value - - value = self.trainer.train_loop.results['training_step.tracking_2'] - assert expected == value - else: - if self.trainer.current_epoch == 2: - return - if batch_idx == self.breaking_batch_idx: - # simulate failure mid epoch - raise CustomException - - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.train_loop.results['training_step.tracking'].value - assert value == sum(range(batch_idx + 1)) - - value = self.trainer.train_loop.results['training_step.tracking_2'] - assert value == sum(range(batch_idx + 1)) - - return super().training_step(batch, batch_idx) - - def on_epoch_end(self) -> None: - if self.trainer.current_epoch: - total = sum(range(5)) * num_processes - metrics = self.trainer.train_loop.results.metrics(on_step=False) - assert self.trainer.train_loop.results['training_step.tracking'].value == total - assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 - assert self.trainer.train_loop.results['training_step.tracking_2'].value == total - assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 - self.has_validated_sum = True - - model = ExtendedBoringModel() - - trainer = Trainer(**trainer_kwargs) - - with suppress(CustomException): - trainer.fit(model) - - checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], 'ckpt.pt')) - trainer.save_checkpoint(checkpoint_path) - - trainer.accelerator.barrier() - - if trainer.is_global_zero: - checkpoint = torch.load(checkpoint_path) - assert checkpoint["state_dict"]['dummy_metric.sum'] == 3 * num_processes - - trainer_kwargs["resume_from_checkpoint"] = checkpoint_path - trainer_kwargs["max_epochs"] = 2 - - trainer = Trainer(**trainer_kwargs) - model.has_reloaded = True - trainer.fit(model) - assert model.has_validated_sum - - -def test_result_collection_reload(tmpdir): - result_collection_reload({ - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - }) - - -@RunIf(min_gpus=2, special=True) -def test_result_collection_reload_2_gpus(tmpdir): - result_collection_reload({ - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - "accelerator": "ddp", - "gpus": 2, - }) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index ded39924479de..9a689fe9d725a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -400,7 +400,6 @@ def test_trainer_model_hook_system_fit(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': ANY, } expected = [ dict(name='Callback.on_init_start', args=(trainer, )), @@ -513,16 +512,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'optimizer_states': ANY, 
'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': { - "result_collections": { - "fit": { - "train": ANY, - "validate": ANY, - }, - "validate": ANY, - "test": ANY - } - }, } expected = [ dict(name='Callback.on_init_start', args=(trainer, )), @@ -542,16 +531,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': { - "result_collections": { - "fit": { - "train": ANY, - "validate": ANY, - }, - "validate": ANY, - "test": ANY - } - }, }, ) ), dict(name='configure_sharded_model'), @@ -821,9 +800,6 @@ def call(hook, fn, *args, **kwargs): 'optimizer_states': ANY, 'pytorch-lightning_version': __version__, 'state_dict': ANY, - 'loops': { - "result_collections": ANY - } }, ) ), dict(name='teardown', kwargs=dict(stage='fit')),