From ffa0e2fec5c430601d5b258065f02f7b1c06fe9c Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 15 Dec 2020 21:14:00 +0100 Subject: [PATCH 1/6] add test --- .../test_train_loop_logging_1_0.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 0c27d8909d760..78bf3fc4455ea 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -771,3 +771,28 @@ def on_train_epoch_end(self, *_): trainer.fit(model) assert model.epoch_end_called assert model.on_train_epoch_end_called + + +def test_metric_are_properly_reduced(tmpdir): + class TestingModel(BoringModel): + def __init__(self, *args, **kwargs): + super().__init__() + self.acc = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) + self.log('train_acc', self.acc, on_step=True, on_epoch=True) + return super().training_step(batch, batch_idx) + + + def validation_step(self, batch, batch_idx): + self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) + self.log('val_acc', self.acc, on_step=True, on_epoch=True) + return super().validation_step(batch, batch_idx) + + model = TestingModel() + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + max_epochs=1) + trainer.fit(model) From fcb1ca5681bb9cd229ac1315fa50992f96478fea Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 16 Dec 2020 10:37:30 +0100 Subject: [PATCH 2/6] resolve bug --- pytorch_lightning/callbacks/early_stopping.py | 11 ++++++++--- .../callbacks/model_checkpoint.py | 11 ++++++++--- .../test_train_loop_logging_1_0.py | 19 +++++++++++++++---- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py 
b/pytorch_lightning/callbacks/early_stopping.py index 88f1881643c9a..4125a924cb2c5 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,6 +19,7 @@ Monitor a metric and stop training when it stops improving. """ +import numbers import os import numpy as np @@ -26,7 +27,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn class EarlyStopping(Callback): @@ -201,8 +203,11 @@ def _run_early_stopping_check(self, trainer, pl_module): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - if not isinstance(current, torch.Tensor): - current = torch.tensor(current, device=pl_module.device) + if current is not None: + if isinstance(current, Metric): + current = current.compute() + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=pl_module.device, dtype=torch.float) if trainer.use_tpu and TPU_AVAILABLE: current = current.cpu() diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1354f7f5056b3..b13413ef69e9a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,6 +20,7 @@ """ +import numbers import os import re from copy import deepcopy @@ -32,8 +33,9 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from 
pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -573,8 +575,11 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): epoch = metrics.get("epoch") step = metrics.get("step") - if not isinstance(current, torch.Tensor) and current is not None: - current = torch.tensor(current, device=pl_module.device) + if current is not None: + if isinstance(current, Metric): + current = current.compute() + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=pl_module.device, dtype=torch.float) if self.check_monitor_top_k(current): self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 78bf3fc4455ea..2fe394a99efcf 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -27,7 +27,7 @@ import pytorch_lightning as pl from pytorch_lightning import Trainer, callbacks -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset from tests.base.deterministic_model import DeterministicModel @@ -773,6 +773,7 @@ def on_train_epoch_end(self, *_): assert model.on_train_epoch_end_called +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") def test_metric_are_properly_reduced(tmpdir): class TestingModel(BoringModel): def __init__(self, *args, **kwargs): @@ -781,18 +782,28 @@ def __init__(self, *args, **kwargs): def training_step(self, batch, batch_idx): self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - 
self.log('train_acc', self.acc, on_step=True, on_epoch=True) + self.log('train_acc', self.acc, on_step=False, on_epoch=True) return super().training_step(batch, batch_idx) - def validation_step(self, batch, batch_idx): self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - self.log('val_acc', self.acc, on_step=True, on_epoch=True) + self.log('val_acc', self.acc, on_step=False, on_epoch=True) return super().validation_step(batch, batch_idx) + early_stop = EarlyStopping(monitor='val_acc', mode='max') + + checkpoint = ModelCheckpoint( + monitor='val_acc', + save_last=True, + save_top_k=5, + mode='max', + ) + model = TestingModel() trainer = Trainer( default_root_dir=tmpdir, gpus=1, max_epochs=1) trainer.fit(model) + + import pdb; pdb.set_trace() From 3939855a0051e94cf37da153ff7f04f405ebdbbf Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 16 Dec 2020 10:39:04 +0100 Subject: [PATCH 3/6] update test --- tests/trainer/logging_tests/test_train_loop_logging_1_0.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 2fe394a99efcf..64909610be7a7 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -806,4 +806,5 @@ def validation_step(self, batch, batch_idx): max_epochs=1) trainer.fit(model) - import pdb; pdb.set_trace() + assert "val_acc" in trainer.callback_metrics + assert "train_acc" in trainer.callback_metrics From 421255ef2c198e2d91cd882bbee9efb05dad44c3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 16 Dec 2020 10:46:38 +0100 Subject: [PATCH 4/6] wrongly copy / paste --- tests/trainer/logging_tests/test_train_loop_logging_1_0.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 64909610be7a7..201dededaf569 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -803,7 +803,8 @@ def validation_step(self, batch, batch_idx): trainer = Trainer( default_root_dir=tmpdir, gpus=1, - max_epochs=1) + max_epochs=2, + callbacks=[early_stop, checkpoint]) trainer.fit(model) assert "val_acc" in trainer.callback_metrics From 2aa76edc615d04d5dd4b5f142def455ddb84ada5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 16 Dec 2020 11:20:36 +0100 Subject: [PATCH 5/6] update test --- tests/trainer/logging_tests/test_train_loop_logging_1_0.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 201dededaf569..5d1cf0768801f 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -26,7 +26,7 @@ from torch.utils.data import Dataset import pytorch_lightning as pl -from pytorch_lightning import Trainer, callbacks +from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset @@ -782,12 +782,12 @@ def __init__(self, *args, **kwargs): def training_step(self, batch, batch_idx): self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - self.log('train_acc', self.acc, on_step=False, on_epoch=True) + self.log('train_acc', self.acc, on_step=True, on_epoch=True) return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - 
self.log('val_acc', self.acc, on_step=False, on_epoch=True) + self.log('val_acc', self.acc, on_step=True, on_epoch=True) return super().validation_step(batch, batch_idx) early_stop = EarlyStopping(monitor='val_acc', mode='max') From ba23d77c3334e5e7df72f85b9d237271f4e38b91 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 16 Dec 2020 11:13:46 +0000 Subject: [PATCH 6/6] resolve a second bug --- pytorch_lightning/core/step_result.py | 5 ++++- .../test_train_loop_logging_1_0.py | 21 ++++++++++++------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 142fe9048cb0e..2418178c77dc8 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False): dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) if options['forked']: - result[dl_key] = self[k] + if isinstance(self[k], Metric): + result[dl_key] = self[k].compute().detach() + else: + result[dl_key] = self[k] return result diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 5d1cf0768801f..7ee48af32cc6e 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -778,16 +778,21 @@ def test_metric_are_properly_reduced(tmpdir): class TestingModel(BoringModel): def __init__(self, *args, **kwargs): super().__init__() - self.acc = pl.metrics.Accuracy() + self.train_acc = pl.metrics.Accuracy() + self.val_acc = pl.metrics.Accuracy() def training_step(self, batch, batch_idx): - self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - self.log('train_acc', self.acc, on_step=True, on_epoch=True) + self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) + 
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True) return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - self.acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - self.log('val_acc', self.acc, on_step=True, on_epoch=True) + preds = torch.tensor(0, device=self.device) + targets = torch.tensor(1, device=self.device) + if batch_idx < 8: + targets = preds + self.val_acc(preds, targets) + self.log('val_acc', self.val_acc, on_step=True, on_epoch=True) return super().validation_step(batch, batch_idx) early_stop = EarlyStopping(monitor='val_acc', mode='max') @@ -795,7 +800,7 @@ def validation_step(self, batch, batch_idx): checkpoint = ModelCheckpoint( monitor='val_acc', save_last=True, - save_top_k=5, + save_top_k=2, mode='max', ) @@ -804,8 +809,10 @@ def validation_step(self, batch, batch_idx): default_root_dir=tmpdir, gpus=1, max_epochs=2, + limit_train_batches=5, + limit_val_batches=32, callbacks=[early_stop, checkpoint]) trainer.fit(model) - assert "val_acc" in trainer.callback_metrics + assert trainer.callback_metrics["val_acc"] == 8 / 32. assert "train_acc" in trainer.callback_metrics