|
14 | 14 | from typing import List, Optional, Sequence, Tuple, Union |
15 | 15 |
|
16 | 16 | import torch |
17 | | -import torch.nn.functional as F |
| 17 | +from torchmetrics.functional import precision_recall_curve as _precision_recall_curve |
18 | 18 |
|
19 | | -from pytorch_lightning.utilities import rank_zero_warn |
| 19 | +from pytorch_lightning.utilities.deprecation import deprecated |
20 | 20 |
|
21 | 21 |
|
22 | | -def _binary_clf_curve( |
23 | | - preds: torch.Tensor, |
24 | | - target: torch.Tensor, |
25 | | - sample_weights: Optional[Sequence] = None, |
26 | | - pos_label: int = 1., |
27 | | -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
28 | | - """ |
29 | | - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py |
30 | | - """ |
31 | | - if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): |
32 | | - sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) |
33 | | - |
34 | | - # remove class dimension if necessary |
35 | | - if preds.ndim > target.ndim: |
36 | | - preds = preds[:, 0] |
37 | | - desc_score_indices = torch.argsort(preds, descending=True) |
38 | | - |
39 | | - preds = preds[desc_score_indices] |
40 | | - target = target[desc_score_indices] |
41 | | - |
42 | | - if sample_weights is not None: |
43 | | - weight = sample_weights[desc_score_indices] |
44 | | - else: |
45 | | - weight = 1. |
46 | | - |
47 | | - # pred typically has many tied values. Here we extract |
48 | | - # the indices associated with the distinct values. We also |
49 | | - # concatenate a value for the end of the curve. |
50 | | - distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] |
51 | | - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) |
52 | | - target = (target == pos_label).to(torch.long) |
53 | | - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] |
54 | | - |
55 | | - if sample_weights is not None: |
56 | | - # express fps as a cumsum to ensure fps is increasing even in |
57 | | - # the presence of floating point errors |
58 | | - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] |
59 | | - else: |
60 | | - fps = 1 + threshold_idxs - tps |
61 | | - |
62 | | - return fps, tps, preds[threshold_idxs] |
63 | | - |
64 | | - |
65 | | -def _precision_recall_curve_update( |
66 | | - preds: torch.Tensor, |
67 | | - target: torch.Tensor, |
68 | | - num_classes: Optional[int] = None, |
69 | | - pos_label: Optional[int] = None, |
70 | | -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: |
71 | | - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): |
72 | | - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") |
73 | | - # single class evaluation |
74 | | - if len(preds.shape) == len(target.shape): |
75 | | - num_classes = 1 |
76 | | - if pos_label is None: |
77 | | - rank_zero_warn('`pos_label` automatically set 1.') |
78 | | - pos_label = 1 |
79 | | - preds = preds.flatten() |
80 | | - target = target.flatten() |
81 | | - |
82 | | - # multi class evaluation |
83 | | - if len(preds.shape) == len(target.shape) + 1: |
84 | | - if pos_label is not None: |
85 | | - rank_zero_warn( |
86 | | - 'Argument `pos_label` should be `None` when running' |
87 | | - f' multiclass precision recall curve. Got {pos_label}' |
88 | | - ) |
89 | | - if num_classes != preds.shape[1]: |
90 | | - raise ValueError( |
91 | | - f'Argument `num_classes` was set to {num_classes} in' |
92 | | - f' metric `precision_recall_curve` but detected {preds.shape[1]}' |
93 | | - ' number of classes from predictions' |
94 | | - ) |
95 | | - preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) |
96 | | - target = target.flatten() |
97 | | - |
98 | | - return preds, target, num_classes, pos_label |
99 | | - |
100 | | - |
101 | | -def _precision_recall_curve_compute( |
102 | | - preds: torch.Tensor, |
103 | | - target: torch.Tensor, |
104 | | - num_classes: int, |
105 | | - pos_label: int, |
106 | | - sample_weights: Optional[Sequence] = None, |
107 | | -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], |
108 | | - List[torch.Tensor]]]: |
109 | | - |
110 | | - if num_classes == 1: |
111 | | - fps, tps, thresholds = _binary_clf_curve( |
112 | | - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label |
113 | | - ) |
114 | | - |
115 | | - precision = tps / (tps + fps) |
116 | | - recall = tps / tps[-1] |
117 | | - |
118 | | - # stop when full recall attained |
119 | | - # and reverse the outputs so recall is decreasing |
120 | | - last_ind = torch.where(tps == tps[-1])[0][0] |
121 | | - sl = slice(0, last_ind.item() + 1) |
122 | | - |
123 | | - # need to call reversed explicitly, since including that to slice would |
124 | | - # introduce negative strides that are not yet supported in pytorch |
125 | | - precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) |
126 | | - |
127 | | - recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) |
128 | | - |
129 | | - thresholds = reversed(thresholds[sl]).clone() |
130 | | - |
131 | | - return precision, recall, thresholds |
132 | | - |
133 | | - # Recursively call per class |
134 | | - precision, recall, thresholds = [], [], [] |
135 | | - for c in range(num_classes): |
136 | | - preds_c = preds[:, c] |
137 | | - res = precision_recall_curve( |
138 | | - preds=preds_c, |
139 | | - target=target, |
140 | | - num_classes=1, |
141 | | - pos_label=c, |
142 | | - sample_weights=sample_weights, |
143 | | - ) |
144 | | - precision.append(res[0]) |
145 | | - recall.append(res[1]) |
146 | | - thresholds.append(res[2]) |
147 | | - |
148 | | - return precision, recall, thresholds |
149 | | - |
150 | 22 |
|
| 23 | +@deprecated(target=_precision_recall_curve, ver_deprecate="1.3.0", ver_remove="1.5.0") |
151 | 24 | def precision_recall_curve( |
152 | 25 | preds: torch.Tensor, |
153 | 26 | target: torch.Tensor, |
154 | 27 | num_classes: Optional[int] = None, |
155 | 28 | pos_label: Optional[int] = None, |
156 | 29 | sample_weights: Optional[Sequence] = None, |
157 | 30 | ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], |
158 | | - List[torch.Tensor]]]: |
| 31 | + List[torch.Tensor]],]: |
159 | 32 | """ |
160 | | - Computes precision-recall pairs for different thresholds. |
161 | | -
|
162 | | - Args: |
163 | | - preds: predictions from model (probabilities) |
164 | | - target: ground truth labels |
165 | | - num_classes: integer with number of classes. Not nessesary to provide |
166 | | - for binary problems. |
167 | | - pos_label: integer determining the positive class. Default is ``None`` |
168 | | - which for binary problem is translate to 1. For multiclass problems |
169 | | - this argument should not be set as we iteratively change it in the |
170 | | - range [0,num_classes-1] |
171 | | - sample_weights: sample weights for each data point |
172 | | -
|
173 | | - Returns: |
174 | | - 3-element tuple containing |
175 | | -
|
176 | | - precision: |
177 | | - tensor where element i is the precision of predictions with |
178 | | - score >= thresholds[i] and the last element is 1. |
179 | | - If multiclass, this is a list of such tensors, one for each class. |
180 | | - recall: |
181 | | - tensor where element i is the recall of predictions with |
182 | | - score >= thresholds[i] and the last element is 0. |
183 | | - If multiclass, this is a list of such tensors, one for each class. |
184 | | - thresholds: |
185 | | - Thresholds used for computing precision/recall scores |
186 | | -
|
187 | | - Example (binary case): |
188 | | -
|
189 | | - >>> from pytorch_lightning.metrics.functional import precision_recall_curve |
190 | | - >>> pred = torch.tensor([0, 1, 2, 3]) |
191 | | - >>> target = torch.tensor([0, 1, 1, 0]) |
192 | | - >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1) |
193 | | - >>> precision |
194 | | - tensor([0.6667, 0.5000, 0.0000, 1.0000]) |
195 | | - >>> recall |
196 | | - tensor([1.0000, 0.5000, 0.0000, 0.0000]) |
197 | | - >>> thresholds |
198 | | - tensor([1, 2, 3]) |
199 | | -
|
200 | | - Raises: |
201 | | - ValueError: |
202 | | - If ``preds`` and ``target`` don't have the same number of dimensions, |
203 | | - or one additional dimension for ``preds``. |
204 | | - ValueError: |
205 | | - If the number of classes deduced from ``preds`` is not the same as the |
206 | | - ``num_classes`` provided. |
207 | | -
|
208 | | - Example (multiclass case): |
209 | | -
|
210 | | - >>> from pytorch_lightning.metrics.functional import precision_recall_curve |
211 | | - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], |
212 | | - ... [0.05, 0.75, 0.05, 0.05, 0.05], |
213 | | - ... [0.05, 0.05, 0.75, 0.05, 0.05], |
214 | | - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) |
215 | | - >>> target = torch.tensor([0, 1, 3, 2]) |
216 | | - >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5) |
217 | | - >>> precision # doctest: +NORMALIZE_WHITESPACE |
218 | | - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), |
219 | | - tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] |
220 | | - >>> recall |
221 | | - [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] |
222 | | - >>> thresholds |
223 | | - [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] |
| 33 | + .. deprecated:: |
| 34 | + Use :func:`torchmetrics.functional.accuracy`. Will be removed in v1.5.0. |
224 | 35 | """ |
225 | | - preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) |
226 | | - return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) |
0 commit comments