Skip to content

Commit d757f8b

Browse files
tchaton authored and Ubuntu committed
[bug-fix] Metric reduction with Logging (#5150)
* add test * resolve bug * update test * wrongly copy / paste * update test * resolve a second bug Co-authored-by: Ubuntu <[email protected]>
1 parent 0e81284 commit d757f8b

File tree

4 files changed

+66
-8
lines changed

4 files changed

+66
-8
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
Monitor a metric and stop training when it stops improving.
2020
2121
"""
22+
import numbers
2223

2324
import numpy as np
2425
import torch
2526

2627
from pytorch_lightning.callbacks.base import Callback
27-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE
28+
from pytorch_lightning.metrics.metric import Metric
29+
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn
2830

2931

3032
class EarlyStopping(Callback):
@@ -199,8 +201,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
199201
# when in dev debugging
200202
trainer.dev_debugger.track_early_stopping_history(self, current)
201203

202-
if not isinstance(current, torch.Tensor):
203-
current = torch.tensor(current, device=pl_module.device)
204+
if current is not None:
205+
if isinstance(current, Metric):
206+
current = current.compute()
207+
elif isinstance(current, numbers.Number):
208+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
204209

205210
if trainer.use_tpu and _TPU_AVAILABLE:
206211
current = current.cpu()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
"""
2222

23+
import numbers
2324
import os
2425
import re
2526
from copy import deepcopy
@@ -32,6 +33,7 @@
3233

3334
from pytorch_lightning import _logger as log
3435
from pytorch_lightning.callbacks.base import Callback
36+
from pytorch_lightning.metrics.metric import Metric
3537
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
3638
from pytorch_lightning.utilities.cloud_io import get_filesystem
3739
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -551,8 +553,11 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
551553
epoch = metrics.get("epoch")
552554
step = metrics.get("step")
553555

554-
if not isinstance(current, torch.Tensor) and current is not None:
555-
current = torch.tensor(current, device=pl_module.device)
556+
if current is not None:
557+
if isinstance(current, Metric):
558+
current = current.compute()
559+
elif isinstance(current, numbers.Number):
560+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
556561

557562
if self.check_monitor_top_k(current):
558563
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)

pytorch_lightning/core/step_result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
371371
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)
372372

373373
if options['forked']:
374-
result[dl_key] = self[k]
374+
if isinstance(self[k], Metric):
375+
result[dl_key] = self[k].compute().detach()
376+
else:
377+
result[dl_key] = self[k]
375378

376379
return result
377380

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from torch.utils.data import Dataset
2828

2929
import pytorch_lightning as pl
30-
from pytorch_lightning import Trainer, callbacks
31-
from pytorch_lightning.callbacks import ModelCheckpoint
30+
from pytorch_lightning import callbacks, Trainer
31+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3232
from pytorch_lightning.core.lightning import LightningModule
3333
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
3434
from tests.base.deterministic_model import DeterministicModel
@@ -857,3 +857,48 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
857857
'on_epoch_end': 5,
858858
'on_train_epoch_end': 6}
859859
assert trainer.callback_metrics == expected
860+
861+
862+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_metric_are_properly_reduced(tmpdir):
    """Regression test: ``Metric`` objects logged via ``self.log`` must be
    reduced (``compute()``) before EarlyStopping / ModelCheckpoint compare them.

    The validation stream is rigged so exactly 8 of the 32 single-sample
    batches are correct, giving a known epoch accuracy of 8/32.
    """

    class TestingModel(BoringModel):
        def __init__(self, *args, **kwargs):
            super().__init__()
            # One stateful metric per stage; both are logged as Metric objects.
            self.train_acc = pl.metrics.Accuracy()
            self.val_acc = pl.metrics.Accuracy()

        def training_step(self, batch, batch_idx):
            # Random predictions/targets — value is irrelevant, only that the
            # metric key shows up in callback_metrics.
            rand_preds = torch.rand(1, 3, device=self.device)
            rand_targets = torch.randint(0, 2, (1,), device=self.device)
            self.train_acc(rand_preds, rand_targets)
            self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            preds = torch.tensor(0, device=self.device)
            targets = torch.tensor(1, device=self.device)
            # First 8 batches match, the remaining 24 do not → accuracy 8/32.
            if batch_idx < 8:
                targets = preds
            self.val_acc(preds, targets)
            self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
            return super().validation_step(batch, batch_idx)

    # Both callbacks monitor the Metric-valued key, exercising the reduction
    # paths in EarlyStopping and ModelCheckpoint.
    early_stopping = EarlyStopping(monitor='val_acc', mode='max')
    checkpoint_cb = ModelCheckpoint(
        monitor='val_acc',
        save_last=True,
        save_top_k=2,
        mode='max',
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=1,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=32,
        callbacks=[early_stopping, checkpoint_cb])
    trainer.fit(TestingModel())

    # The reduced (computed) value, not the raw Metric object, must be exposed.
    assert trainer.callback_metrics["val_acc"] == 8 / 32.
    assert "train_acc" in trainer.callback_metrics

0 commit comments

Comments (0)