From 0d085c10faadb395e4bac507c36040caee238464 Mon Sep 17 00:00:00 2001 From: Shabie Iqbal Date: Sun, 14 Nov 2021 04:26:55 +0100 Subject: [PATCH 01/11] log metrics for correct dataloader only --- .../logger_connector/logger_connector.py | 17 +++++++++++--- tests/loops/test_evaluation_loop.py | 22 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 37fcb06a1dc24..27f012b207513 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -13,6 +13,7 @@ # limitations under the License. from pprint import pprint from typing import Any, Dict, Iterable, List, Optional, Union +import re import torch @@ -154,6 +155,18 @@ def update_eval_step_metrics(self) -> None: # increment the step even if nothing was logged self._increment_eval_log_step() + @staticmethod + def _filter_metrics_for_dataloader(dl_idx, metrics, metric_prefix="dataloader_idx"): + result = {} + for k, v in metrics.items(): + if metric_prefix not in k: + continue + num_in_metric = int(re.search(r"\d+", k).group(0)) + if num_in_metric == dl_idx: + result[k] = v + break + return result + def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None: if self.trainer.sanity_checking: return @@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None: has_been_initialized = len(self.eval_loop_results) == num_dataloaders for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders): # remove callback metrics that don't belong to this dataloader - callback_metrics = { - k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k - } + callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics) if has_been_initialized: 
self.eval_loop_results[dl_idx].update(callback_metrics) else: diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index d6b2c15553fb9..e9374b8e21846 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -130,3 +130,25 @@ def on_advance_end(self): trainer.test_loop.connect(TestLoop()) trainer.test(model) assert did_assert + + +def test_log_metrics_only_include_metrics_from_concerned_dataloader(tmpdir): + class LessBoringModel(BoringModel): + def test_step(self, batch, batch_idx, dataloader_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('fake_test_acc', loss) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs[0]]).mean() + + num_dataloaders = 11 + test = RandomDataset(32, 128) + test_dataloaders = [DataLoader(test, batch_size=32)] * num_dataloaders + + model = LessBoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + output = trainer.test(model, dataloaders=test_dataloaders) + assert sum([len(x) for x in output]) == num_dataloaders From fce6e5376f2bfc06d0dc9d8095b720c6e1f31de0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Nov 2021 03:33:00 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- tests/loops/test_evaluation_loop.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 27f012b207513..63de1edeca0d4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,9 +11,9 @@ # 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re from pprint import pprint from typing import Any, Dict, Iterable, List, Optional, Union -import re import torch diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index e9374b8e21846..35aadff1dd992 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -137,9 +137,9 @@ class LessBoringModel(BoringModel): def test_step(self, batch, batch_idx, dataloader_idx): output = self.layer(batch) loss = self.loss(batch, output) - self.log('fake_test_acc', loss) + self.log("fake_test_acc", loss) return {"y": loss} - + def test_epoch_end(self, outputs) -> None: torch.stack([x["y"] for x in outputs[0]]).mean() @@ -151,4 +151,4 @@ def test_epoch_end(self, outputs) -> None: trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) output = trainer.test(model, dataloaders=test_dataloaders) - assert sum([len(x) for x in output]) == num_dataloaders + assert sum(len(x) for x in output) == num_dataloaders From bc844c40c3d37ef3388fbd9276849669dc0dda2e Mon Sep 17 00:00:00 2001 From: Shabie Iqbal Date: Sun, 14 Nov 2021 05:02:48 +0100 Subject: [PATCH 03/11] collect all metrics for each dataloader --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 27f012b207513..124c8cc7a5d85 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -160,11 +160,11 @@ def _filter_metrics_for_dataloader(dl_idx, metrics, metric_prefix="dataloader_id result = {} for k, v in metrics.items(): if metric_prefix not in 
k: + result[k] = v continue num_in_metric = int(re.search(r"\d+", k).group(0)) if num_in_metric == dl_idx: result[k] = v - break return result def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None: From 50ecc68a3866c6d6d39a943fcf3ab2b4f5d11cbc Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 10:40:48 +0000 Subject: [PATCH 04/11] update --- tests/loops/test_evaluation_loop.py | 22 --------------- .../logging_/test_eval_loop_logging.py | 27 +++++++++++++++++++ 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 35aadff1dd992..d6b2c15553fb9 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -130,25 +130,3 @@ def on_advance_end(self): trainer.test_loop.connect(TestLoop()) trainer.test(model) assert did_assert - - -def test_log_metrics_only_include_metrics_from_concerned_dataloader(tmpdir): - class LessBoringModel(BoringModel): - def test_step(self, batch, batch_idx, dataloader_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log("fake_test_acc", loss) - return {"y": loss} - - def test_epoch_end(self, outputs) -> None: - torch.stack([x["y"] for x in outputs[0]]).mean() - - num_dataloaders = 11 - test = RandomDataset(32, 128) - test_dataloaders = [DataLoader(test, batch_size=32)] * num_dataloaders - - model = LessBoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - - output = trainer.test(model, dataloaders=test_dataloaders) - assert sum(len(x) for x in output) == num_dataloaders diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 6ed40b5f03082..9414de93aaf00 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch +from torch.utils.data import DataLoader from pytorch_lightning 
import callbacks, Trainer from pytorch_lightning.loggers import TensorBoardLogger @@ -672,3 +673,29 @@ def val_dataloader(self): enable_model_summary=False, ) trainer.fit(model) + + +@pytest.mark.parametrize("num_dataloaders", [1, 2, 11]) +def test_log_metrics_only_include_metrics_from_concerned_dataloader(num_dataloaders, tmpdir): + class TestModel(BoringModel): + def test_step(self, batch, batch_idx, dataloader_idx=0): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("fake_test_acc", loss) + return {"y": loss} + + def test_epoch_end(self, *_) -> None: + pass + + test = RandomDataset(32, 2) + test_dataloaders = [DataLoader(test, batch_size=1)] * num_dataloaders + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + output = trainer.test(model, dataloaders=test_dataloaders) + assert sum(len(x) for x in output) == num_dataloaders + if num_dataloaders == 1: + assert "dataloader_idx" not in output[0] + else: + assert all(f"dataloader_idx_{idx}" == list(x.keys())[0].split("/")[1] for idx, x in enumerate(output)) From a4a1d463ffbdf34e734a9a610a2842b1d6440f76 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 15:49:10 +0000 Subject: [PATCH 05/11] update --- .../trainer/connectors/logger_connector/logger_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 59e92923ae30f..477d5db4e9b0a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -156,7 +156,9 @@ def update_eval_step_metrics(self) -> None: self._increment_eval_log_step() @staticmethod - def _filter_metrics_for_dataloader(dl_idx, metrics, metric_prefix="dataloader_idx"): + def _filter_metrics_for_dataloader( + dl_idx: int, metrics: Dict[str, int], 
metric_prefix: str = "dataloader_idx" + ) -> Dict[str, int]: result = {} for k, v in metrics.items(): if metric_prefix not in k: From 79ddeac41a13d4e0e64bdb592f6c9850220b7bcb Mon Sep 17 00:00:00 2001 From: shabie <30535146+shabie@users.noreply.github.com> Date: Mon, 15 Nov 2021 18:31:56 +0100 Subject: [PATCH 06/11] remove max_epochs as only call is to trainer.test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- tests/trainer/logging_/test_eval_loop_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 9414de93aaf00..7f3b4b7bf437d 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -691,7 +691,7 @@ def test_epoch_end(self, *_) -> None: test_dataloaders = [DataLoader(test, batch_size=1)] * num_dataloaders model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + trainer = Trainer(default_root_dir=tmpdir) output = trainer.test(model, dataloaders=test_dataloaders) assert sum(len(x) for x in output) == num_dataloaders From dacef9097a54ed2a52b6d4a0cb60d97a82da7925 Mon Sep 17 00:00:00 2001 From: Shabie Iqbal Date: Tue, 16 Nov 2021 17:44:18 +0100 Subject: [PATCH 07/11] update test for specificity --- tests/trainer/logging_/test_eval_loop_logging.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 7f3b4b7bf437d..27fe483497d1d 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -677,12 +677,15 @@ def val_dataloader(self): @pytest.mark.parametrize("num_dataloaders", [1, 2, 11]) def test_log_metrics_only_include_metrics_from_concerned_dataloader(num_dataloaders, tmpdir): + + 
metric_prefix = "fake_test_acc" + dataloader_prefix = "dataloader_idx" + class TestModel(BoringModel): def test_step(self, batch, batch_idx, dataloader_idx=0): output = self.layer(batch) loss = self.loss(batch, output) - self.log("fake_test_acc", loss) - return {"y": loss} + self.log(metric_prefix, loss) def test_epoch_end(self, *_) -> None: pass @@ -694,8 +697,11 @@ def test_epoch_end(self, *_) -> None: trainer = Trainer(default_root_dir=tmpdir) output = trainer.test(model, dataloaders=test_dataloaders) - assert sum(len(x) for x in output) == num_dataloaders + if num_dataloaders == 1: - assert "dataloader_idx" not in output[0] + assert dataloader_prefix not in output[0] else: - assert all(f"dataloader_idx_{idx}" == list(x.keys())[0].split("/")[1] for idx, x in enumerate(output)) + for idx, metric in enumerate(output): + expected_dl_idx_str = f"{metric_prefix}/{dataloader_prefix}_{idx}" + assert len(metric) == 1 + assert expected_dl_idx_str in metric From 51a9e641e2ba0f29e7247f7438865c04351f2c36 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 17 Nov 2021 17:56:37 +0100 Subject: [PATCH 08/11] Change into unit tests --- .../logging_/test_eval_loop_logging.py | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 27fe483497d1d..88229effbc8c9 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -20,10 +20,10 @@ import numpy as np import pytest import torch -from torch.utils.data import DataLoader from pytorch_lightning import callbacks, Trainer from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -675,33 +675,27 @@ def val_dataloader(self): 
trainer.fit(model) -@pytest.mark.parametrize("num_dataloaders", [1, 2, 11]) -def test_log_metrics_only_include_metrics_from_concerned_dataloader(num_dataloaders, tmpdir): - - metric_prefix = "fake_test_acc" - dataloader_prefix = "dataloader_idx" - - class TestModel(BoringModel): - def test_step(self, batch, batch_idx, dataloader_idx=0): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log(metric_prefix, loss) - - def test_epoch_end(self, *_) -> None: - pass - - test = RandomDataset(32, 2) - test_dataloaders = [DataLoader(test, batch_size=1)] * num_dataloaders - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir) - - output = trainer.test(model, dataloaders=test_dataloaders) - - if num_dataloaders == 1: - assert dataloader_prefix not in output[0] - else: - for idx, metric in enumerate(output): - expected_dl_idx_str = f"{metric_prefix}/{dataloader_prefix}_{idx}" - assert len(metric) == 1 - assert expected_dl_idx_str in metric +@pytest.mark.parametrize( + ["kwargs", "expected"], + [ + ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}), + ( + {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}}, + {"acc/dataloader_idx_0": 123}, + ), + ( + {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}}, + {"acc/dataloader_idx_10": 321}, + ), + ( + {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}}, + {"top_3_acc/dataloader_idx_3": 321}, + ), + # theoretical case, as `/dataloader_idx_3` would have been added + ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}), + ], +) +def test_filter_metrics_for_dataloader(kwargs, expected): + """Logged metrics should only include metrics from the concerned dataloader.""" + actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs) + assert actual == expected From 0789b4d2b733a03c39415131add034b33a4d9c01 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 18 Nov 2021 14:51:35 
+0000 Subject: [PATCH 09/11] update --- .../trainer/connectors/logger_connector/logger_connector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 477d5db4e9b0a..df841788d76f4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re from pprint import pprint from typing import Any, Dict, Iterable, List, Optional, Union @@ -164,8 +163,7 @@ def _filter_metrics_for_dataloader( if metric_prefix not in k: result[k] = v continue - num_in_metric = int(re.search(r"\d+", k).group(0)) - if num_in_metric == dl_idx: + if k.endswith(f"{metric_prefix}_{dl_idx}"): result[k] = v return result From 0c5665d021fd7794204c18dc525817f7bde89f4d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 18 Nov 2021 14:55:17 +0000 Subject: [PATCH 10/11] update --- .../connectors/logger_connector/logger_connector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index df841788d76f4..f24eadd101349 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pprint import pprint -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Union import torch +from torch import Tensor import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger @@ -156,8 +157,8 @@ def update_eval_step_metrics(self) -> None: @staticmethod def _filter_metrics_for_dataloader( - dl_idx: int, metrics: Dict[str, int], metric_prefix: str = "dataloader_idx" - ) -> Dict[str, int]: + dl_idx: int, metrics: Mapping[str, Union[Tensor, Dict[str, Tensor]]], metric_prefix: str = "dataloader_idx" + ) -> Mapping[str, Union[Tensor, Dict[str, Tensor]]]: result = {} for k, v in metrics.items(): if metric_prefix not in k: From ef63cddaabc50ce859da37d3ab2bc4a6c1bb521e Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 18 Nov 2021 15:03:22 +0000 Subject: [PATCH 11/11] typing --- .../connectors/logger_connector/logger_connector.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index f24eadd101349..640fc667705a8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pprint import pprint -from typing import Any, Dict, Iterable, List, Mapping, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch -from torch import Tensor import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger @@ -157,8 +156,8 @@ def update_eval_step_metrics(self) -> None: @staticmethod def _filter_metrics_for_dataloader( - dl_idx: int, metrics: Mapping[str, Union[Tensor, Dict[str, Tensor]]], metric_prefix: str = "dataloader_idx" - ) -> Mapping[str, Union[Tensor, Dict[str, Tensor]]]: + dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx" + ) -> Dict[str, Union[Any, Dict[str, Any]]]: result = {} for k, v in metrics.items(): if metric_prefix not in k: