Skip to content

Commit 81fd33b

Browse files
tchaton and Ubuntu authored
[bug-fix] Metric reduction with Logging (#5150)
* add test * resolve bug * update test * wrongly copy / paste * update test * resolve a second bug Co-authored-by: Ubuntu <[email protected]>
1 parent 8d1ca4c commit 81fd33b

File tree

4 files changed

+67
-9
lines changed

4 files changed

+67
-9
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919
Monitor a metric and stop training when it stops improving.
2020
2121
"""
22+
import numbers
2223
import os
2324

2425
import numpy as np
2526
import torch
2627

2728
from pytorch_lightning import _logger as log
2829
from pytorch_lightning.callbacks.base import Callback
29-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
30+
from pytorch_lightning.metrics.metric import Metric
31+
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn
3032

3133

3234
class EarlyStopping(Callback):
@@ -201,8 +203,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
201203
# when in dev debugging
202204
trainer.dev_debugger.track_early_stopping_history(self, current)
203205

204-
if not isinstance(current, torch.Tensor):
205-
current = torch.tensor(current, device=pl_module.device)
206+
if current is not None:
207+
if isinstance(current, Metric):
208+
current = current.compute()
209+
elif isinstance(current, numbers.Number):
210+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
206211

207212
if trainer.use_tpu and TPU_AVAILABLE:
208213
current = current.cpu()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
"""
2222

23+
import numbers
2324
import os
2425
import re
2526
from copy import deepcopy
@@ -32,8 +33,9 @@
3233

3334
from pytorch_lightning import _logger as log
3435
from pytorch_lightning.callbacks.base import Callback
35-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
36+
from pytorch_lightning.metrics.metric import Metric
3637
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
38+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
3739
from pytorch_lightning.utilities.cloud_io import get_filesystem
3840
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3941

@@ -580,8 +582,11 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
580582
epoch = metrics.get("epoch")
581583
step = metrics.get("step")
582584

583-
if not isinstance(current, torch.Tensor) and current is not None:
584-
current = torch.tensor(current, device=pl_module.device)
585+
if current is not None:
586+
if isinstance(current, Metric):
587+
current = current.compute()
588+
elif isinstance(current, numbers.Number):
589+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
585590

586591
if self.check_monitor_top_k(current):
587592
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)

pytorch_lightning/core/step_result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
371371
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)
372372

373373
if options['forked']:
374-
result[dl_key] = self[k]
374+
if isinstance(self[k], Metric):
375+
result[dl_key] = self[k].compute().detach()
376+
else:
377+
result[dl_key] = self[k]
375378

376379
return result
377380

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from torch.utils.data import Dataset
2828

2929
import pytorch_lightning as pl
30-
from pytorch_lightning import Trainer, callbacks
31-
from pytorch_lightning.callbacks import ModelCheckpoint
30+
from pytorch_lightning import callbacks, Trainer
31+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3232
from pytorch_lightning.core.lightning import LightningModule
3333
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
3434
from tests.base.deterministic_model import DeterministicModel
@@ -810,3 +810,48 @@ def on_train_epoch_end(self, *_):
810810
trainer.fit(model)
811811
assert model.epoch_end_called
812812
assert model.on_train_epoch_end_called
813+
814+
815+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
816+
def test_metric_are_properly_reduced(tmpdir):
817+
class TestingModel(BoringModel):
818+
def __init__(self, *args, **kwargs):
819+
super().__init__()
820+
self.train_acc = pl.metrics.Accuracy()
821+
self.val_acc = pl.metrics.Accuracy()
822+
823+
def training_step(self, batch, batch_idx):
824+
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
825+
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
826+
return super().training_step(batch, batch_idx)
827+
828+
def validation_step(self, batch, batch_idx):
829+
preds = torch.tensor(0, device=self.device)
830+
targets = torch.tensor(1, device=self.device)
831+
if batch_idx < 8:
832+
targets = preds
833+
self.val_acc(preds, targets)
834+
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
835+
return super().validation_step(batch, batch_idx)
836+
837+
early_stop = EarlyStopping(monitor='val_acc', mode='max')
838+
839+
checkpoint = ModelCheckpoint(
840+
monitor='val_acc',
841+
save_last=True,
842+
save_top_k=2,
843+
mode='max',
844+
)
845+
846+
model = TestingModel()
847+
trainer = Trainer(
848+
default_root_dir=tmpdir,
849+
gpus=1,
850+
max_epochs=2,
851+
limit_train_batches=5,
852+
limit_val_batches=32,
853+
callbacks=[early_stop, checkpoint])
854+
trainer.fit(model)
855+
856+
assert trainer.callback_metrics["val_acc"] == 8 / 32.
857+
assert "train_acc" in trainer.callback_metrics

0 commit comments

Comments (0)