Commit 8e76b3f: drop duplicate metrics

Parent: e2c404b

6 files changed, 23 insertions(+), 123 deletions(-)

pytorch_lightning/metrics/classification/average_precision.py (2 additions, 3 deletions)

@@ -92,9 +92,8 @@ def __init__(
         self.add_state("target", default=[], dist_reduce_fx=None)
 
         rank_zero_warn(
-            'Metric `AveragePrecision` will save all targets and'
-            ' predictions in buffer. For large datasets this may lead'
-            ' to large memory footprint.'
+            'Metric `AveragePrecision` will save all targets and predictions in buffer.'
+            ' For large datasets this may lead to large memory footprint.'
         )
 
     def update(self, preds: torch.Tensor, target: torch.Tensor):

pytorch_lightning/metrics/classification/precision_recall_curve.py (2 additions, 3 deletions)

@@ -102,9 +102,8 @@ def __init__(
         self.add_state("target", default=[], dist_reduce_fx=None)
 
         rank_zero_warn(
-            'Metric `PrecisionRecallCurve` will save all targets and'
-            ' predictions in buffer. For large datasets this may lead'
-            ' to large memory footprint.'
+            'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.'
+            ' For large datasets this may lead to large memory footprint.'
         )
 
     def update(self, preds: torch.Tensor, target: torch.Tensor):

pytorch_lightning/metrics/classification/roc.py (2 additions, 3 deletions)

@@ -105,9 +105,8 @@ def __init__(
         self.add_state("target", default=[], dist_reduce_fx=None)
 
         rank_zero_warn(
-            'Metric `ROC` will save all targets and'
-            ' predictions in buffer. For large datasets this may lead'
-            ' to large memory footprint.'
+            'Metric `ROC` will save all targets and predictions in buffer.'
+            ' For large datasets this may lead to large memory footprint.'
        )
 
     def update(self, preds: torch.Tensor, target: torch.Tensor):
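
All three warnings exist for the same reason: each metric registers plain-list states via `add_state("target", default=[], dist_reduce_fx=None)` and appends every batch, so memory grows with dataset size until `compute()` is called. A minimal sketch of that behaviour, assuming `AveragePrecision` is importable from `pytorch_lightning.metrics.classification` (as the file layout suggests) and that its binary defaults apply:

import torch
from pytorch_lightning.metrics.classification import AveragePrecision

metric = AveragePrecision()                  # emits the rank_zero_warn shown above
for _ in range(10):                          # one iteration per batch
    preds = torch.rand(32)                   # probabilities for the positive class
    target = (torch.rand(32) > 0.5).long()   # binary ground truth
    metric.update(preds, target)             # appends the whole batch to the buffers
ap = metric.compute()                        # consumes everything accumulated so far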

pytorch_lightning/metrics/functional/classification.py (5 additions, 104 deletions)

@@ -17,6 +17,8 @@
 import torch
 from torch.nn import functional as F
 
+from pytorch_lightning.metrics.functional import roc
+from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
 from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce
 from pytorch_lightning.utilities import rank_zero_warn
 
@@ -332,107 +334,6 @@ def recall(
                             num_classes=num_classes, class_reduction=class_reduction)[1]
 
 
-def _binary_clf_curve(
-        pred: torch.Tensor,
-        target: torch.Tensor,
-        sample_weight: Optional[Sequence] = None,
-        pos_label: int = 1.,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
-    """
-    if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
-        sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)
-
-    # remove class dimension if necessary
-    if pred.ndim > target.ndim:
-        pred = pred[:, 0]
-    desc_score_indices = torch.argsort(pred, descending=True)
-
-    pred = pred[desc_score_indices]
-    target = target[desc_score_indices]
-
-    if sample_weight is not None:
-        weight = sample_weight[desc_score_indices]
-    else:
-        weight = 1.
-
-    # pred typically has many tied values. Here we extract
-    # the indices associated with the distinct values. We also
-    # concatenate a value for the end of the curve.
-    distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
-    threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
-
-    target = (target == pos_label).to(torch.long)
-    tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
-
-    if sample_weight is not None:
-        # express fps as a cumsum to ensure fps is increasing even in
-        # the presence of floating point errors
-        fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
-    else:
-        fps = 1 + threshold_idxs - tps
-
-    return fps, tps, pred[threshold_idxs]
-
-
-# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
-def __roc(
-        pred: torch.Tensor,
-        target: torch.Tensor,
-        sample_weight: Optional[Sequence] = None,
-        pos_label: int = 1.,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
-
-    .. warning:: Deprecated
-
-    Args:
-        pred: estimated probabilities
-        target: ground-truth labels
-        sample_weight: sample weights
-        pos_label: the label for the positive class
-
-    Return:
-        false-positive rate (fpr), true-positive rate (tpr), thresholds
-
-    Example:
-
-        >>> x = torch.tensor([0, 1, 2, 3])
-        >>> y = torch.tensor([0, 1, 1, 1])
-        >>> fpr, tpr, thresholds = __roc(x, y)
-        >>> fpr
-        tensor([0., 0., 0., 0., 1.])
-        >>> tpr
-        tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
-        >>> thresholds
-        tensor([4, 3, 2, 1, 0])
-
-    """
-    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
-                                             sample_weight=sample_weight,
-                                             pos_label=pos_label)
-
-    # Add an extra threshold position
-    # to make sure that the curve starts at (0, 0)
-    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
-    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
-    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
-
-    if fps[-1] <= 0:
-        raise ValueError("No negative samples in targets, false positive value should be meaningless")
-
-    fpr = fps / fps[-1]
-
-    if tps[-1] <= 0:
-        raise ValueError("No positive samples in targets, true positive value should be meaningless")
-
-    tpr = tps / tps[-1]
-
-    return fpr, tpr, thresholds
-
-
 # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
 def __multiclass_roc(
         pred: torch.Tensor,
@@ -474,7 +375,7 @@ def __multiclass_roc(
     for c in range(num_classes):
         pred_c = pred[:, c]
 
-        class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))
+        class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1))
 
     return tuple(class_roc_vals)
 
@@ -589,7 +490,7 @@ def auroc(
 
     @auc_decorator(reorder=True)
     def _auroc(pred, target, sample_weight, pos_label):
-        return __roc(pred, target, sample_weight, pos_label)
+        return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1)
 
     return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
 
@@ -642,7 +543,7 @@ def multiclass_auroc(
 
    @multiclass_auc_decorator(reorder=False)
    def _multiclass_auroc(pred, target, sample_weight, num_classes):
-        return __multiclass_roc(pred, target, sample_weight, num_classes)
+        return roc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)
 
    class_aurocs = _multiclass_auroc(pred=pred, target=target,
                                     sample_weight=sample_weight,
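
The bulk of the deletion removes two private duplicates: `_binary_clf_curve`, now imported from `pytorch_lightning.metrics.functional.precision_recall_curve`, and the deprecated binary `__roc`, whose call sites switch to the general `roc`. A sketch of the replacement call path, reusing the inputs and expected outputs from the removed doctest; it assumes `roc(..., num_classes=1)` keeps the old `(fpr, tpr, thresholds)` return format, which is exactly what the drop-in edits above rely on:

import torch
from pytorch_lightning.metrics.functional import roc

x = torch.tensor([0, 1, 2, 3])
y = torch.tensor([0, 1, 1, 1])
# Assumption: with num_classes=1 the general `roc` is a drop-in for the old __roc.
fpr, tpr, thresholds = roc(preds=x, target=y, pos_label=1, num_classes=1)
# Per the removed doctest:
# fpr        -> tensor([0., 0., 0., 0., 1.])
# tpr        -> tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
# thresholds -> tensor([4, 3, 2, 1, 0])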

pytorch_lightning/metrics/functional/explained_variance.py (10 additions, 8 deletions)

@@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup
     return preds, target
 
 
-def _explained_variance_compute(preds: torch.Tensor,
-                                target: torch.Tensor,
-                                multioutput: str = 'uniform_average',
-                                ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+def _explained_variance_compute(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    multioutput: str = 'uniform_average',
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     diff_avg = torch.mean(target - preds, dim=0)
     numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)
 
@@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor,
     return torch.sum(denominator / denom_sum * output_scores)
 
 
-def explained_variance(preds: torch.Tensor,
-                       target: torch.Tensor,
-                       multioutput: str = 'uniform_average',
-                       ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+def explained_variance(
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    multioutput: str = 'uniform_average',
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     """
     Computes explained variance.
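
This is a pure signature reflow; behaviour is unchanged. For reference, a usage sketch of the public function (the import path follows the file above; the expected value can be checked by hand against the formula in `_explained_variance_compute`):

import torch
from pytorch_lightning.metrics.functional.explained_variance import explained_variance

target = torch.tensor([3.0, -0.5, 2.0, 7.0])   # ground truth
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])     # model outputs
score = explained_variance(preds, target, multioutput='uniform_average')
# score -> tensor(0.9572), i.e. 1 - Var(target - preds) / Var(target)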

tests/metrics/functional/test_classification.py (2 additions, 2 deletions)

@@ -17,13 +17,13 @@
     accuracy,
     precision,
     recall,
-    _binary_clf_curve,
     dice_score,
     auroc,
     multiclass_auroc,
     auc,
     iou,
 )
+from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
 from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical
 
 
@@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
     if sample_weight is not None:
         sample_weight = torch.ones_like(pred) * sample_weight
 
-    fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
+    fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
 
     assert isinstance(tps, torch.Tensor)
     assert isinstance(fps, torch.Tensor)
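
The test keeps its coverage but now exercises the shared helper, whose keyword names change from `pred`/`sample_weight` to `preds`/`sample_weights`. A worked sketch on a toy input; the expected counts follow by hand from the deleted implementation above, assuming the relocated helper computes the same curve and defaults `sample_weights` to None:

import torch
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve

preds = torch.tensor([0., 1., 2., 3.])
target = torch.tensor([0, 1, 1, 1])
fps, tps, thresh = _binary_clf_curve(preds=preds, target=target, pos_label=1)
# Sorted by descending score the targets are [1, 1, 1, 0], so:
# tps    -> [1, 2, 3, 3]    cumulative true positives per distinct threshold
# fps    -> [0, 0, 0, 1]    cumulative false positives per distinct threshold
# thresh -> [3., 2., 1., 0.]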
