Skip to content

Commit abadc86

Browse files
committed
curve
1 parent b89decd commit abadc86

File tree

3 files changed

+26
-322
lines changed

3 files changed

+26
-322
lines changed

pytorch_lightning/metrics/classification/precision_recall_curve.py

Lines changed: 8 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -11,80 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, List, Optional, Tuple, Union
14+
from typing import Any, Optional
1515

16-
import torch
17-
from torchmetrics import Metric
16+
from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve
1817

19-
from pytorch_lightning.metrics.functional.precision_recall_curve import (
20-
_precision_recall_curve_compute,
21-
_precision_recall_curve_update,
22-
)
23-
from pytorch_lightning.utilities import rank_zero_warn
18+
from pytorch_lightning.utilities.deprecation import deprecated
2419

2520

26-
class PrecisionRecallCurve(Metric):
27-
"""
28-
Computes precision-recall pairs for different thresholds. Works for both
29-
binary and multiclass problems. In the case of multiclass, the values will
30-
be calculated based on a one-vs-the-rest approach.
31-
32-
Forward accepts
33-
34-
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
35-
with probabilities, where C is the number of classes.
36-
37-
- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
38-
39-
Args:
40-
num_classes: integer with number of classes. Not necessary to provide
41-
for binary problems.
42-
pos_label: integer determining the positive class. Default is ``None``
43-
which for binary problems is translated to 1. For multiclass problems
44-
this argument should not be set as we iteratively change it in the
45-
range [0,num_classes-1]
46-
compute_on_step:
47-
Forward only calls ``update()`` and returns None if this is set to False. default: True
48-
dist_sync_on_step:
49-
Synchronize metric state across processes at each ``forward()``
50-
before returning the value at the step. default: False
51-
process_group:
52-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
53-
54-
Example (binary case):
55-
56-
>>> from pytorch_lightning.metrics import PrecisionRecallCurve
57-
>>> pred = torch.tensor([0, 1, 2, 3])
58-
>>> target = torch.tensor([0, 1, 1, 0])
59-
>>> pr_curve = PrecisionRecallCurve(pos_label=1)
60-
>>> precision, recall, thresholds = pr_curve(pred, target)
61-
>>> precision
62-
tensor([0.6667, 0.5000, 0.0000, 1.0000])
63-
>>> recall
64-
tensor([1.0000, 0.5000, 0.0000, 0.0000])
65-
>>> thresholds
66-
tensor([1, 2, 3])
67-
68-
Example (multiclass case):
69-
70-
>>> from pytorch_lightning.metrics import PrecisionRecallCurve
71-
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
72-
... [0.05, 0.75, 0.05, 0.05, 0.05],
73-
... [0.05, 0.05, 0.75, 0.05, 0.05],
74-
... [0.05, 0.05, 0.05, 0.75, 0.05]])
75-
>>> target = torch.tensor([0, 1, 3, 2])
76-
>>> pr_curve = PrecisionRecallCurve(num_classes=5)
77-
>>> precision, recall, thresholds = pr_curve(pred, target)
78-
>>> precision # doctest: +NORMALIZE_WHITESPACE
79-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
80-
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
81-
>>> recall
82-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
83-
>>> thresholds
84-
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
85-
86-
"""
21+
class PrecisionRecallCurve(_PrecisionRecallCurve):
8722

23+
@deprecated(target=_PrecisionRecallCurve, ver_deprecate="1.3.0", ver_remove="1.5.0")
8824
def __init__(
8925
self,
9026
num_classes: Optional[int] = None,
@@ -93,60 +29,9 @@ def __init__(
9329
dist_sync_on_step: bool = False,
9430
process_group: Optional[Any] = None,
9531
):
96-
super().__init__(
97-
compute_on_step=compute_on_step,
98-
dist_sync_on_step=dist_sync_on_step,
99-
process_group=process_group,
100-
)
101-
102-
self.num_classes = num_classes
103-
self.pos_label = pos_label
104-
105-
self.add_state("preds", default=[], dist_reduce_fx=None)
106-
self.add_state("target", default=[], dist_reduce_fx=None)
107-
108-
rank_zero_warn(
109-
'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.'
110-
' For large datasets this may lead to large memory footprint.'
111-
)
112-
113-
def update(self, preds: torch.Tensor, target: torch.Tensor):
11432
"""
115-
Update state with predictions and targets.
116-
117-
Args:
118-
preds: Predictions from model
119-
target: Ground truth values
120-
"""
121-
preds, target, num_classes, pos_label = _precision_recall_curve_update(
122-
preds, target, self.num_classes, self.pos_label
123-
)
124-
self.preds.append(preds)
125-
self.target.append(target)
126-
self.num_classes = num_classes
127-
self.pos_label = pos_label
128-
129-
def compute(
130-
self
131-
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
132-
List[torch.Tensor]]]:
133-
"""
134-
Compute the precision-recall curve
135-
136-
Returns:
137-
3-element tuple containing
33+
This implementation refers to :class:`~torchmetrics.PrecisionRecallCurve`.
13834
139-
precision:
140-
tensor where element i is the precision of predictions with
141-
score >= thresholds[i] and the last element is 1.
142-
If multiclass, this is a list of such tensors, one for each class.
143-
recall:
144-
tensor where element i is the recall of predictions with
145-
score >= thresholds[i] and the last element is 0.
146-
If multiclass, this is a list of such tensors, one for each class.
147-
thresholds:
148-
Thresholds used for computing precision/recall scores
35+
.. deprecated::
36+
Use :class:`~torchmetrics.PrecisionRecallCurve`. Will be removed in v1.5.0.
14937
"""
150-
preds = torch.cat(self.preds, dim=0)
151-
target = torch.cat(self.target, dim=0)
152-
return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label)

pytorch_lightning/metrics/functional/precision_recall_curve.py

Lines changed: 6 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -14,213 +14,22 @@
1414
from typing import List, Optional, Sequence, Tuple, Union
1515

1616
import torch
17-
import torch.nn.functional as F
17+
from torchmetrics.functional import precision_recall_curve as _precision_recall_curve
1818

19-
from pytorch_lightning.utilities import rank_zero_warn
19+
from pytorch_lightning.utilities.deprecation import deprecated
2020

2121

22-
def _binary_clf_curve(
23-
preds: torch.Tensor,
24-
target: torch.Tensor,
25-
sample_weights: Optional[Sequence] = None,
26-
pos_label: int = 1.,
27-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
28-
"""
29-
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
30-
"""
31-
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
32-
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
33-
34-
# remove class dimension if necessary
35-
if preds.ndim > target.ndim:
36-
preds = preds[:, 0]
37-
desc_score_indices = torch.argsort(preds, descending=True)
38-
39-
preds = preds[desc_score_indices]
40-
target = target[desc_score_indices]
41-
42-
if sample_weights is not None:
43-
weight = sample_weights[desc_score_indices]
44-
else:
45-
weight = 1.
46-
47-
# pred typically has many tied values. Here we extract
48-
# the indices associated with the distinct values. We also
49-
# concatenate a value for the end of the curve.
50-
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
51-
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
52-
target = (target == pos_label).to(torch.long)
53-
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
54-
55-
if sample_weights is not None:
56-
# express fps as a cumsum to ensure fps is increasing even in
57-
# the presence of floating point errors
58-
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
59-
else:
60-
fps = 1 + threshold_idxs - tps
61-
62-
return fps, tps, preds[threshold_idxs]
63-
64-
65-
def _precision_recall_curve_update(
66-
preds: torch.Tensor,
67-
target: torch.Tensor,
68-
num_classes: Optional[int] = None,
69-
pos_label: Optional[int] = None,
70-
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
71-
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
72-
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
73-
# single class evaluation
74-
if len(preds.shape) == len(target.shape):
75-
num_classes = 1
76-
if pos_label is None:
77-
rank_zero_warn('`pos_label` automatically set 1.')
78-
pos_label = 1
79-
preds = preds.flatten()
80-
target = target.flatten()
81-
82-
# multi class evaluation
83-
if len(preds.shape) == len(target.shape) + 1:
84-
if pos_label is not None:
85-
rank_zero_warn(
86-
'Argument `pos_label` should be `None` when running'
87-
f' multiclass precision recall curve. Got {pos_label}'
88-
)
89-
if num_classes != preds.shape[1]:
90-
raise ValueError(
91-
f'Argument `num_classes` was set to {num_classes} in'
92-
f' metric `precision_recall_curve` but detected {preds.shape[1]}'
93-
' number of classes from predictions'
94-
)
95-
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
96-
target = target.flatten()
97-
98-
return preds, target, num_classes, pos_label
99-
100-
101-
def _precision_recall_curve_compute(
102-
preds: torch.Tensor,
103-
target: torch.Tensor,
104-
num_classes: int,
105-
pos_label: int,
106-
sample_weights: Optional[Sequence] = None,
107-
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
108-
List[torch.Tensor]]]:
109-
110-
if num_classes == 1:
111-
fps, tps, thresholds = _binary_clf_curve(
112-
preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label
113-
)
114-
115-
precision = tps / (tps + fps)
116-
recall = tps / tps[-1]
117-
118-
# stop when full recall attained
119-
# and reverse the outputs so recall is decreasing
120-
last_ind = torch.where(tps == tps[-1])[0][0]
121-
sl = slice(0, last_ind.item() + 1)
122-
123-
# need to call reversed explicitly, since including that to slice would
124-
# introduce negative strides that are not yet supported in pytorch
125-
precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)])
126-
127-
recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)])
128-
129-
thresholds = reversed(thresholds[sl]).clone()
130-
131-
return precision, recall, thresholds
132-
133-
# Recursively call per class
134-
precision, recall, thresholds = [], [], []
135-
for c in range(num_classes):
136-
preds_c = preds[:, c]
137-
res = precision_recall_curve(
138-
preds=preds_c,
139-
target=target,
140-
num_classes=1,
141-
pos_label=c,
142-
sample_weights=sample_weights,
143-
)
144-
precision.append(res[0])
145-
recall.append(res[1])
146-
thresholds.append(res[2])
147-
148-
return precision, recall, thresholds
149-
15022

23+
@deprecated(target=_precision_recall_curve, ver_deprecate="1.3.0", ver_remove="1.5.0")
15124
def precision_recall_curve(
15225
preds: torch.Tensor,
15326
target: torch.Tensor,
15427
num_classes: Optional[int] = None,
15528
pos_label: Optional[int] = None,
15629
sample_weights: Optional[Sequence] = None,
15730
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
158-
List[torch.Tensor]]]:
31+
List[torch.Tensor]],]:
15932
"""
160-
Computes precision-recall pairs for different thresholds.
161-
162-
Args:
163-
preds: predictions from model (probabilities)
164-
target: ground truth labels
165-
num_classes: integer with number of classes. Not necessary to provide
166-
for binary problems.
167-
pos_label: integer determining the positive class. Default is ``None``
168-
which for binary problems is translated to 1. For multiclass problems
169-
this argument should not be set as we iteratively change it in the
170-
range [0,num_classes-1]
171-
sample_weights: sample weights for each data point
172-
173-
Returns:
174-
3-element tuple containing
175-
176-
precision:
177-
tensor where element i is the precision of predictions with
178-
score >= thresholds[i] and the last element is 1.
179-
If multiclass, this is a list of such tensors, one for each class.
180-
recall:
181-
tensor where element i is the recall of predictions with
182-
score >= thresholds[i] and the last element is 0.
183-
If multiclass, this is a list of such tensors, one for each class.
184-
thresholds:
185-
Thresholds used for computing precision/recall scores
186-
187-
Example (binary case):
188-
189-
>>> from pytorch_lightning.metrics.functional import precision_recall_curve
190-
>>> pred = torch.tensor([0, 1, 2, 3])
191-
>>> target = torch.tensor([0, 1, 1, 0])
192-
>>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
193-
>>> precision
194-
tensor([0.6667, 0.5000, 0.0000, 1.0000])
195-
>>> recall
196-
tensor([1.0000, 0.5000, 0.0000, 0.0000])
197-
>>> thresholds
198-
tensor([1, 2, 3])
199-
200-
Raises:
201-
ValueError:
202-
If ``preds`` and ``target`` don't have the same number of dimensions,
203-
or one additional dimension for ``preds``.
204-
ValueError:
205-
If the number of classes deduced from ``preds`` is not the same as the
206-
``num_classes`` provided.
207-
208-
Example (multiclass case):
209-
210-
>>> from pytorch_lightning.metrics.functional import precision_recall_curve
211-
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
212-
... [0.05, 0.75, 0.05, 0.05, 0.05],
213-
... [0.05, 0.05, 0.75, 0.05, 0.05],
214-
... [0.05, 0.05, 0.05, 0.75, 0.05]])
215-
>>> target = torch.tensor([0, 1, 3, 2])
216-
>>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
217-
>>> precision # doctest: +NORMALIZE_WHITESPACE
218-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
219-
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
220-
>>> recall
221-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
222-
>>> thresholds
223-
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
33+
.. deprecated::
34+
Use :func:`torchmetrics.functional.precision_recall_curve`. Will be removed in v1.5.0.
22435
"""
225-
preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
226-
return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights)

0 commit comments

Comments
 (0)