Skip to content

Commit 6bade18

Browse files
tchaton and Ubuntu
authored and committed
[bug-fix] Metric reduction with Logging (#5150)
* add test * resolve bug * udpate test * wrongly copy / paste * update test * resolve a second bug Co-authored-by: Ubuntu <[email protected]>
1 parent 29c6bf5 commit 6bade18

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Monitor a metric and stop training when it stops improving.
2020
2121
"""
22+
import numbers
2223
import os
2324

2425
import numpy as np
@@ -201,8 +202,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
201202
# when in dev debugging
202203
trainer.dev_debugger.track_early_stopping_history(self, current)
203204

204-
if not isinstance(current, torch.Tensor):
205-
current = torch.tensor(current, device=pl_module.device)
205+
if current is not None:
206+
if isinstance(current, Metric):
207+
current = current.compute()
208+
elif isinstance(current, numbers.Number):
209+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
206210

207211
if trainer.use_tpu and _TPU_AVAILABLE:
208212
current = current.cpu()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
"""
2222

23+
import numbers
2324
import os
2425
import re
2526
from copy import deepcopy
@@ -32,8 +33,9 @@
3233

3334
from pytorch_lightning import _logger as log
3435
from pytorch_lightning.callbacks.base import Callback
35-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
36+
from pytorch_lightning.metrics.metric import Metric
3637
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
38+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
3739
from pytorch_lightning.utilities.cloud_io import get_filesystem
3840
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3941

@@ -580,8 +582,11 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
580582
epoch = metrics.get("epoch")
581583
step = metrics.get("step")
582584

583-
if not isinstance(current, torch.Tensor) and current is not None:
584-
current = torch.tensor(current, device=pl_module.device)
585+
if current is not None:
586+
if isinstance(current, Metric):
587+
current = current.compute()
588+
elif isinstance(current, numbers.Number):
589+
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
585590

586591
if self.check_monitor_top_k(current):
587592
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)

pytorch_lightning/core/step_result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
371371
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)
372372

373373
if options['forked']:
374-
result[dl_key] = self[k]
374+
if isinstance(self[k], Metric):
375+
result[dl_key] = self[k].compute().detach()
376+
else:
377+
result[dl_key] = self[k]
375378

376379
return result
377380

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from torch.utils.data import Dataset
2828

2929
import pytorch_lightning as pl
30-
from pytorch_lightning import Trainer, callbacks
31-
from pytorch_lightning.callbacks import ModelCheckpoint
30+
from pytorch_lightning import callbacks, Trainer
31+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
3232
from pytorch_lightning.core.lightning import LightningModule
3333
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
3434
from tests.base.deterministic_model import DeterministicModel

0 commit comments

Comments
 (0)