Commit 7605c6f

tchaton authored and Ubuntu committed

[bug-fix] Metric reduction with Logging (#5150)

When a Metric object is logged through self.log, EarlyStopping and ModelCheckpoint could receive the Metric itself (or a plain Python number) as the monitored value and compare it directly. This commit reduces Metric objects with .compute() and promotes plain numbers to float tensors before any comparison, and makes get_forked_metrics return the computed, detached value of a forked Metric.

* add test
* resolve bug
* update test
* wrongly copy / paste
* update test
* resolve a second bug

Co-authored-by: Ubuntu <[email protected]>

1 parent 824fad8 commit 7605c6f

File tree

4 files changed: +66 -8 lines changed

pytorch_lightning/callbacks/early_stopping.py
pytorch_lightning/callbacks/model_checkpoint.py
pytorch_lightning/core/step_result.py
tests/trainer/logging_tests/test_train_loop_logging_1_0.py

pytorch_lightning/callbacks/early_stopping.py
Lines changed: 8 additions & 3 deletions

@@ -19,12 +19,14 @@
 Monitor a metric and stop training when it stops improving.

 """
+import numbers

 import numpy as np
 import torch

 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE
+from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn


 class EarlyStopping(Callback):

@@ -199,8 +201,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)

-        if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current, device=pl_module.device)
+        if current is not None:
+            if isinstance(current, Metric):
+                current = current.compute()
+            elif isinstance(current, numbers.Number):
+                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

         if trainer.use_tpu and _TPU_AVAILABLE:
             current = current.cpu()
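Both callbacks now normalize the monitored value the same way before comparing it. A minimal sketch of that shared logic, assuming the PyTorch Lightning API of this era; the helper name `_resolve_monitor_value` is hypothetical, not part of the patch:

```python
import numbers

import torch

from pytorch_lightning.metrics.metric import Metric


def _resolve_monitor_value(current, device):
    """Reduce a monitored value to a comparable tensor, mirroring the patch."""
    if current is None:
        return current  # nothing was logged for this monitor key
    if isinstance(current, Metric):
        return current.compute()  # reduce the accumulated metric state
    if isinstance(current, numbers.Number):
        # promote Python ints/floats to a float tensor on the module's device
        return torch.tensor(current, device=device, dtype=torch.float)
    return current  # already a tensor
```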

pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 7 additions & 2 deletions

@@ -20,6 +20,7 @@

 """

+import numbers
 import os
 import re
 from copy import deepcopy

@@ -32,6 +33,7 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.metrics.metric import Metric
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -581,8 +583,11 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         epoch = metrics.get("epoch")
         step = metrics.get("step")

-        if not isinstance(current, torch.Tensor) and current is not None:
-            current = torch.tensor(current, device=pl_module.device)
+        if current is not None:
+            if isinstance(current, Metric):
+                current = current.compute()
+            elif isinstance(current, numbers.Number):
+                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

         if self.check_monitor_top_k(current):
             self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
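The same guard matters for checkpointing because self.log can be handed a Metric object directly, in which case the monitored value arriving at the callback is the Metric itself rather than a tensor. A small sketch of the situation the patch covers, assuming the pytorch_lightning.metrics package used in the test below:

```python
import torch

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.metric import Metric

# Accumulate one batch: 3 of the 4 predictions match the targets.
acc = Accuracy()
acc(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0]))

current = acc  # what the callback may receive as the monitored value
if isinstance(current, Metric):
    current = current.compute()  # tensor(0.7500); safe for the top-k comparison
```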

pytorch_lightning/core/step_result.py
Lines changed: 4 additions & 1 deletion

@@ -371,7 +371,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
             dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

             if options['forked']:
-                result[dl_key] = self[k]
+                if isinstance(self[k], Metric):
+                    result[dl_key] = self[k].compute().detach()
+                else:
+                    result[dl_key] = self[k]

         return result
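This change appears to be the "second bug" from the commit message: a value logged with on_step=True and on_epoch=True is forked into step-level and epoch-level entries, and a Metric stored there must surface as its reduced value rather than as the Metric object. A short sketch of the intended result, under the same assumptions as above:

```python
import torch

from pytorch_lightning.metrics import Accuracy

acc = Accuracy()
acc(torch.tensor([1, 1]), torch.tensor([1, 0]))  # 1 of 2 correct

# A forked Metric entry is now replaced by its computed, detached value,
# so consumers of the result dict see a plain tensor with no autograd graph.
value = acc.compute().detach()
assert torch.equal(value, torch.tensor(0.5))
assert not value.requires_grad
```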

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from torch.utils.data import Dataset
2828

2929
import pytorch_lightning as pl
30-
from pytorch_lightning import Trainer, callbacks
31-
from pytorch_lightning.callbacks import ModelCheckpoint
30+
from pytorch_lightning import callbacks, Trainer
31+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3232
from pytorch_lightning.core.lightning import LightningModule
3333
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
3434
from tests.base.deterministic_model import DeterministicModel
@@ -856,3 +856,48 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
856856
'on_epoch_end': 5,
857857
'on_train_epoch_end': 6}
858858
assert trainer.callback_metrics == expected
859+
860+
861+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
862+
def test_metric_are_properly_reduced(tmpdir):
863+
class TestingModel(BoringModel):
864+
def __init__(self, *args, **kwargs):
865+
super().__init__()
866+
self.train_acc = pl.metrics.Accuracy()
867+
self.val_acc = pl.metrics.Accuracy()
868+
869+
def training_step(self, batch, batch_idx):
870+
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
871+
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
872+
return super().training_step(batch, batch_idx)
873+
874+
def validation_step(self, batch, batch_idx):
875+
preds = torch.tensor(0, device=self.device)
876+
targets = torch.tensor(1, device=self.device)
877+
if batch_idx < 8:
878+
targets = preds
879+
self.val_acc(preds, targets)
880+
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
881+
return super().validation_step(batch, batch_idx)
882+
883+
early_stop = EarlyStopping(monitor='val_acc', mode='max')
884+
885+
checkpoint = ModelCheckpoint(
886+
monitor='val_acc',
887+
save_last=True,
888+
save_top_k=2,
889+
mode='max',
890+
)
891+
892+
model = TestingModel()
893+
trainer = Trainer(
894+
default_root_dir=tmpdir,
895+
gpus=1,
896+
max_epochs=2,
897+
limit_train_batches=5,
898+
limit_val_batches=32,
899+
callbacks=[early_stop, checkpoint])
900+
trainer.fit(model)
901+
902+
assert trainer.callback_metrics["val_acc"] == 8 / 32.
903+
assert "train_acc" in trainer.callback_metrics
