Skip to content

Commit 3e979d0

Browse files
cuent and Borda committed
Add documentation to native metrics (#2144)
* add docs * add docs * Apply suggestions from code review * formatting * add docs Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka <[email protected]>
1 parent c3ad1ca commit 3e979d0

File tree

1 file changed

+224
-13
lines changed

1 file changed

+224
-13
lines changed

pytorch_lightning/metrics/functional/classification.py

Lines changed: 224 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99

1010
def to_onehot(tensor: torch.Tensor,
1111
n_classes: Optional[int] = None) -> torch.Tensor:
12-
""" Converts a dense label tensor to one-hot format
12+
"""
13+
Converts a dense label tensor to one-hot format
1314
1415
Args:
1516
tensor: dense label tensor, with shape [N, d1, d2, ...]
@@ -29,14 +30,31 @@ def to_onehot(tensor: torch.Tensor,
2930

3031

3132
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
32-
""" Converts a tensor of probabilities to a dense label tensor """
33+
"""
34+
Converts a tensor of probabilities to a dense label tensor
35+
36+
Args:
37+
tensor: probabilities to get the categorical label [N, d1, d2, ...]
38+
argmax_dim: dimension over which the argmax is applied (default: 1)
39+
40+
Return:
41+
A tensor with categorical labels [N, d2, ...]
42+
"""
3343
return torch.argmax(tensor, dim=argmax_dim)
3444

3545

3646
def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
3747
num_classes: Optional[int]) -> int:
38-
""" Returns the number of classes for a given prediction and
39-
target tensor
48+
"""
49+
Returns the number of classes for a given prediction and target tensor.
50+
51+
Args:
52+
pred: predicted values
53+
target: true labels
54+
num_classes: number of classes if known (default: None)
55+
56+
Return:
57+
An integer that represents the number of classes.
4058
"""
4159
if num_classes is None:
4260
if pred.ndim > target.ndim:
@@ -50,8 +68,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
5068
class_index: int, argmax_dim: int = 1
5169
) -> Tuple[torch.Tensor, torch.Tensor,
5270
torch.Tensor, torch.Tensor]:
53-
""" Calculates the number of true postive, false postive, true negative
54-
and false negative for a specfic class
71+
"""
72+
Calculates the number of true positive, false positive, true negative
73+
and false negative for a specific class
5574
5675
Args:
5776
pred: prediction tensor
@@ -63,6 +82,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
6382
argmax_dim: if pred is a tensor of probabilities, this indicates the
6483
axis the argmax transformation will be applied over
6584
85+
Return:
86+
Tensors in the following order: True Positive, False Positive, True Negative, False Negative
87+
6688
"""
6789
if pred.ndim == target.ndim + 1:
6890
pred = to_categorical(pred, argmax_dim=argmax_dim)
@@ -80,20 +102,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
80102
argmax_dim: int = 1
81103
) -> Tuple[torch.Tensor, torch.Tensor,
82104
torch.Tensor, torch.Tensor]:
83-
""" Calls the stat_scores function iteratively for all classes, thus
84-
calculating the number of true postive, false postive, true negative
85-
and false negative for each class
105+
"""
106+
Calls the stat_scores function iteratively for all classes, thus
107+
calculating the number of true positive, false positive, true negative
108+
and false negative for each class
86109
87110
Args:
88111
pred: prediction tensor
89-
90112
target: target tensor
91-
92113
num_classes: number of classes if known
93-
94114
argmax_dim: if pred is a tensor of probabilities, this indicates the
95115
axis the argmax transformation will be applied over
96116
117+
Return:
118+
Returns tensors for: tp, fp, tn, fn
119+
97120
"""
98121
num_classes = get_num_classes(pred=pred, target=target,
99122
num_classes=num_classes)
@@ -116,6 +139,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
116139
def accuracy(pred: torch.Tensor, target: torch.Tensor,
117140
num_classes: Optional[int] = None,
118141
reduction='elementwise_mean') -> torch.Tensor:
142+
"""
143+
Computes the accuracy classification score
144+
145+
Args:
146+
pred: predicted labels
147+
target: ground truth labels
148+
num_classes: number of classes
149+
reduction: a method for reducing accuracies over labels (default: takes the mean)
150+
Available reduction methods:
151+
152+
- elementwise_mean: takes the mean
153+
- none: pass array
154+
- sum: add elements
155+
156+
Return:
157+
A Tensor with the classification score.
158+
"""
119159
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
120160
num_classes=num_classes)
121161

@@ -129,6 +169,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
129169

130170
def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
131171
normalize: bool = False) -> torch.Tensor:
172+
"""
173+
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
174+
in group i that were predicted in group j.
175+
176+
Args:
177+
pred: estimated targets
178+
target: ground truth labels
179+
normalize: normalizes confusion matrix
180+
181+
Return:
182+
Tensor, confusion matrix C [num_classes, num_classes]
183+
"""
132184
num_classes = get_num_classes(pred, target, None)
133185

134186
d = target.size(-1)
@@ -149,6 +201,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
149201
num_classes: Optional[int] = None,
150202
reduction: str = 'elementwise_mean'
151203
) -> Tuple[torch.Tensor, torch.Tensor]:
204+
"""
205+
Computes precision and recall from the per-class stat scores
206+
207+
Args:
208+
pred: estimated probabilities
209+
target: ground-truth labels
210+
num_classes: number of classes
211+
reduction: method for reducing precision-recall values (default: takes the mean)
212+
Available reduction methods:
213+
214+
- elementwise_mean: takes the mean
215+
- none: pass array
216+
- sum: add elements
217+
218+
Return:
219+
Tuple of two tensors: precision and recall
220+
"""
152221
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
153222
target=target,
154223
num_classes=num_classes)
@@ -168,20 +237,77 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
168237
def precision(pred: torch.Tensor, target: torch.Tensor,
169238
num_classes: Optional[int] = None,
170239
reduction: str = 'elementwise_mean') -> torch.Tensor:
240+
"""
241+
Computes precision score.
242+
243+
Args:
244+
pred: estimated probabilities
245+
target: ground-truth labels
246+
num_classes: number of classes
247+
reduction: method for reducing precision values (default: takes the mean)
248+
Available reduction methods:
249+
250+
- elementwise_mean: takes the mean
251+
- none: pass array
252+
- sum: add elements
253+
254+
Return:
255+
Tensor with precision.
256+
"""
171257
return precision_recall(pred=pred, target=target,
172258
num_classes=num_classes, reduction=reduction)[0]
173259

174260

175261
def recall(pred: torch.Tensor, target: torch.Tensor,
176262
num_classes: Optional[int] = None,
177263
reduction: str = 'elementwise_mean') -> torch.Tensor:
264+
"""
265+
Computes recall score.
266+
267+
Args:
268+
pred: estimated probabilities
269+
target: ground-truth labels
270+
num_classes: number of classes
271+
reduction: method for reducing recall values (default: takes the mean)
272+
Available reduction methods:
273+
274+
- elementwise_mean: takes the mean
275+
- none: pass array
276+
- sum: add elements
277+
278+
Return:
279+
Tensor with recall.
280+
"""
178281
return precision_recall(pred=pred, target=target,
179282
num_classes=num_classes, reduction=reduction)[1]
180283

181284

182285
def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
183286
num_classes: Optional[int] = None,
184287
reduction: str = 'elementwise_mean') -> torch.Tensor:
288+
"""
289+
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
290+
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
291+
292+
Args:
293+
pred: estimated probabilities
294+
target: ground-truth labels
295+
beta: weights recall when combining the score.
296+
beta < 1: more weight to precision.
297+
beta > 1: more weight to recall
298+
beta = 0: only precision
299+
beta -> inf: only recall
300+
num_classes: number of classes
301+
reduction: method for reducing F-score (default: takes the mean)
302+
Available reduction methods:
303+
304+
- elementwise_mean: takes the mean
305+
- none: pass array
306+
- sum: add elements.
307+
308+
Return:
309+
Tensor with the value of F-score. It is a value between 0-1.
310+
"""
185311
prec, rec = precision_recall(pred=pred, target=target,
186312
num_classes=num_classes,
187313
reduction='none')
@@ -196,6 +322,23 @@ def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
196322
def f1_score(pred: torch.Tensor, target: torch.Tensor,
197323
num_classes: Optional[int] = None,
198324
reduction='elementwise_mean') -> torch.Tensor:
325+
"""
326+
Computes F1-score a.k.a F-measure.
327+
328+
Args:
329+
pred: estimated probabilities
330+
target: ground-truth labels
331+
num_classes: number of classes
332+
reduction: method for reducing F1-score (default: takes the mean)
333+
Available reduction methods:
334+
335+
- elementwise_mean: takes the mean
336+
- none: pass array
337+
- sum: add elements.
338+
339+
Return:
340+
Tensor containing F1-score
341+
"""
199342
return fbeta_score(pred=pred, target=target, beta=1.,
200343
num_classes=num_classes, reduction=reduction)
201344

@@ -251,6 +394,18 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
251394
pos_label: int = 1.) -> Tuple[torch.Tensor,
252395
torch.Tensor,
253396
torch.Tensor]:
397+
"""
398+
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
399+
400+
Args:
401+
pred: estimated probabilities
402+
target: ground-truth labels
403+
sample_weight: sample weights
404+
pos_label: the label for the positive class (default: 1)
405+
406+
Return:
407+
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
408+
"""
254409
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
255410
sample_weight=sample_weight,
256411
pos_label=pos_label)
@@ -282,6 +437,19 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
282437
) -> Tuple[Tuple[torch.Tensor,
283438
torch.Tensor,
284439
torch.Tensor]]:
440+
"""
441+
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
442+
443+
Args:
444+
pred: estimated probabilities
445+
target: ground-truth labels
446+
sample_weight: sample weights
447+
num_classes: number of classes (default: None, computes automatically from data)
448+
449+
Return:
450+
[num_classes, Tensor, Tensor, Tensor]: returns roc for each class.
451+
number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
452+
"""
285453
num_classes = get_num_classes(pred, target, num_classes)
286454

287455
class_roc_vals = []
@@ -301,6 +469,18 @@ def precision_recall_curve(pred: torch.Tensor,
301469
pos_label: int = 1.) -> Tuple[torch.Tensor,
302470
torch.Tensor,
303471
torch.Tensor]:
472+
"""
473+
Computes precision-recall pairs for different thresholds.
474+
475+
Args:
476+
pred: estimated probabilities
477+
target: ground-truth labels
478+
sample_weight: sample weights
479+
pos_label: the label for the positive class (default: 1.)
480+
481+
Return:
482+
[Tensor, Tensor, Tensor]: precision, recall, thresholds
483+
"""
304484
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
305485
sample_weight=sample_weight,
306486
pos_label=pos_label)
@@ -334,6 +514,18 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
334514
) -> Tuple[Tuple[torch.Tensor,
335515
torch.Tensor,
336516
torch.Tensor]]:
517+
"""
518+
Computes precision-recall pairs for different thresholds given multiclass scores.
519+
520+
Args:
521+
pred: estimated probabilities
522+
target: ground-truth labels
523+
sample_weight: sample weights
524+
num_classes: number of classes
525+
526+
Return:
527+
[num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds
528+
"""
337529
num_classes = get_num_classes(pred, target, num_classes)
338530

339531
class_pr_vals = []
@@ -350,6 +542,17 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
350542

351543

352544
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
545+
"""
546+
Computes Area Under the Curve (AUC) using the trapezoidal rule
547+
548+
Args:
549+
x: x-coordinates
550+
y: y-coordinates
551+
reorder: reorder coordinates, so they are increasing.
552+
553+
Return:
554+
AUC score (float)
555+
"""
353556
direction = 1.
354557

355558
if reorder:
@@ -400,6 +603,15 @@ def new_func(*args, **kwargs) -> torch.Tensor:
400603
def auroc(pred: torch.Tensor, target: torch.Tensor,
401604
sample_weight: Optional[Sequence] = None,
402605
pos_label: int = 1.) -> torch.Tensor:
606+
"""
607+
Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
608+
609+
Args:
610+
pred: estimated probabilities
611+
target: ground-truth labels
612+
sample_weight: sample weights
613+
pos_label: the label for the positive class (default: 1.)
614+
"""
403615
return roc(pred=pred, target=target, sample_weight=sample_weight,
404616
pos_label=pos_label)
405617

@@ -410,7 +622,6 @@ def average_precision(pred: torch.Tensor, target: torch.Tensor,
410622
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
411623
sample_weight=sample_weight,
412624
pos_label=pos_label)
413-
414625
# Return the step function integral
415626
# The following works because the last entry of precision is
416627
# guaranteed to be 1, as returned by precision_recall_curve

0 commit comments

Comments
 (0)