Skip to content

Commit a48ca18

Browse files
tchaton authored and Ubuntu committed
[bug-fix] Metric reduction with Logging (#5150)
* add test * resolve bug * udpate test * wrongly copy / paste * update test * resolve a second bug Co-authored-by: Ubuntu <[email protected]>
1 parent 0f36525 commit a48ca18

File tree

4 files changed

+64
-7
lines changed

4 files changed

+64
-7
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -199,8 +199,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
199199
# when in dev debugging
200200
trainer.dev_debugger.track_early_stopping_history(self, current)
201201

202-
if not isinstance(current, torch.Tensor):
203-
current = torch.tensor(current, device=pl_module.device)
202+
if current is not None:
203+
if isinstance(current, Metric):
204+
current = current.compute()
205+
elif isinstance(current, numbers.Number):
206+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
204207

205208
if trainer.use_tpu and _TPU_AVAILABLE:
206209
current = current.cpu()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
2121
"""
2222

23+
import numbers
2324
import os
2425
import re
2526
from copy import deepcopy
@@ -32,6 +33,8 @@
3233

3334
from pytorch_lightning import _logger as log
3435
from pytorch_lightning.callbacks.base import Callback
36+
from pytorch_lightning.metrics.metric import Metric
37+
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
3538
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
3639
from pytorch_lightning.utilities.cloud_io import get_filesystem
3740
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -574,8 +577,11 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
574577
epoch = metrics.get("epoch")
575578
step = metrics.get("step")
576579

577-
if not isinstance(current, torch.Tensor) and current is not None:
578-
current = torch.tensor(current, device=pl_module.device)
580+
if current is not None:
581+
if isinstance(current, Metric):
582+
current = current.compute()
583+
elif isinstance(current, numbers.Number):
584+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
579585

580586
if self.check_monitor_top_k(current):
581587
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)

pytorch_lightning/core/step_result.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
367367
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)
368368

369369
if options['forked']:
370-
result[dl_key] = self[k]
370+
if isinstance(self[k], Metric):
371+
result[dl_key] = self[k].compute().detach()
372+
else:
373+
result[dl_key] = self[k]
371374

372375
return result
373376

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 47 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@
2626
from torch.utils.data import Dataset
2727

2828
import pytorch_lightning as pl
29-
from pytorch_lightning import Trainer, callbacks
30-
from pytorch_lightning.callbacks import ModelCheckpoint
29+
from pytorch_lightning import callbacks, Trainer
30+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3131
from pytorch_lightning.core.lightning import LightningModule
3232
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
3333
from tests.base.deterministic_model import DeterministicModel
@@ -817,3 +817,48 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
817817
'on_epoch_end': 5,
818818
'on_train_epoch_end': 6}
819819
assert trainer.callback_metrics == expected
820+
821+
822+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
823+
def test_metric_are_properly_reduced(tmpdir):
824+
class TestingModel(BoringModel):
825+
def __init__(self, *args, **kwargs):
826+
super().__init__()
827+
self.train_acc = pl.metrics.Accuracy()
828+
self.val_acc = pl.metrics.Accuracy()
829+
830+
def training_step(self, batch, batch_idx):
831+
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
832+
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
833+
return super().training_step(batch, batch_idx)
834+
835+
def validation_step(self, batch, batch_idx):
836+
preds = torch.tensor(0, device=self.device)
837+
targets = torch.tensor(1, device=self.device)
838+
if batch_idx < 8:
839+
targets = preds
840+
self.val_acc(preds, targets)
841+
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
842+
return super().validation_step(batch, batch_idx)
843+
844+
early_stop = EarlyStopping(monitor='val_acc', mode='max')
845+
846+
checkpoint = ModelCheckpoint(
847+
monitor='val_acc',
848+
save_last=True,
849+
save_top_k=2,
850+
mode='max',
851+
)
852+
853+
model = TestingModel()
854+
trainer = Trainer(
855+
default_root_dir=tmpdir,
856+
gpus=1,
857+
max_epochs=2,
858+
limit_train_batches=5,
859+
limit_val_batches=32,
860+
callbacks=[early_stop, checkpoint])
861+
trainer.fit(model)
862+
863+
assert trainer.callback_metrics["val_acc"] == 8 / 32.
864+
assert "train_acc" in trainer.callback_metrics

0 commit comments

Comments (0)