99
1010def to_onehot (tensor : torch .Tensor ,
1111 n_classes : Optional [int ] = None ) -> torch .Tensor :
12- """ Converts a dense label tensor to one-hot format
12+ """
13+ Converts a dense label tensor to one-hot format
1314
1415 Args:
1516 tensor: dense label tensor, with shape [N, d1, d2, ...]
@@ -29,14 +30,31 @@ def to_onehot(tensor: torch.Tensor,
2930
3031
3132def to_categorical (tensor : torch .Tensor , argmax_dim : int = 1 ) -> torch .Tensor :
32- """ Converts a tensor of probabilities to a dense label tensor """
33+ """
34+ Converts a tensor of probabilities to a dense label tensor
35+
36+ Args:
37+ tensor: probabilities to get the categorical label [N, d1, d2, ...]
38+ argmax_dim: dimension to apply (default: 1)
39+
40+ Return:
41+ A tensor with categorical labels [N, d2, ...]
42+ """
3343 return torch .argmax (tensor , dim = argmax_dim )
3444
3545
3646def get_num_classes (pred : torch .Tensor , target : torch .Tensor ,
3747 num_classes : Optional [int ]) -> int :
38- """ Returns the number of classes for a given prediction and
39- target tensor
48+ """
49+ Returns the number of classes for a given prediction and target tensor.
50+
51+ Args:
52+ pred: predicted values
53+ target: true labels
54+ num_classes: number of classes if known (default: None)
55+
56+ Return:
57+ An integer that represents the number of classes.
4058 """
4159 if num_classes is None :
4260 if pred .ndim > target .ndim :
@@ -50,8 +68,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
5068 class_index : int , argmax_dim : int = 1
5169 ) -> Tuple [torch .Tensor , torch .Tensor ,
5270 torch .Tensor , torch .Tensor ]:
53- """ Calculates the number of true postive, false postive, true negative
54- and false negative for a specfic class
71+ """
72+ Calculates the number of true positive, falsepositivee, true negative
73+ and false negative for a specific class
5574
5675 Args:
5776 pred: prediction tensor
@@ -63,6 +82,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
6382 argmax_dim: if pred is a tensor of probabilities, this indicates the
6483 axis the argmax transformation will be applied over
6584
85+ Return:
86+ Tensors in the following order: True Positive, False Positive, True Negative, False Negative
87+
6688 """
6789 if pred .ndim == target .ndim + 1 :
6890 pred = to_categorical (pred , argmax_dim = argmax_dim )
@@ -80,20 +102,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
80102 argmax_dim : int = 1
81103 ) -> Tuple [torch .Tensor , torch .Tensor ,
82104 torch .Tensor , torch .Tensor ]:
83- """ Calls the stat_scores function iteratively for all classes, thus
84- calculating the number of true postive, false postive, true negative
85- and false negative for each class
105+ """
106+ Calls the stat_scores function iteratively for all classes, thus
107+ calculating the number of true postive, false postive, true negative
108+ and false negative for each class
86109
87110 Args:
88111 pred: prediction tensor
89-
90112 target: target tensor
91-
92113 class_index: class to calculate over
93-
94114 argmax_dim: if pred is a tensor of probabilities, this indicates the
95115 axis the argmax transformation will be applied over
96116
117+ Return:
118+ Returns tensors for: tp, fp, tn, fn
119+
97120 """
98121 num_classes = get_num_classes (pred = pred , target = target ,
99122 num_classes = num_classes )
@@ -116,6 +139,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
116139def accuracy (pred : torch .Tensor , target : torch .Tensor ,
117140 num_classes : Optional [int ] = None ,
118141 reduction = 'elementwise_mean' ) -> torch .Tensor :
142+ """
143+ Computes the accuracy classification score
144+
145+ Args:
146+ pred: predicted labels
147+ target: ground truth labels
148+ num_classes: number of classes
149+ reduction: a method for reducing accuracies over labels (default: takes the mean)
150+ Available reduction methods:
151+
152+ - elementwise_mean: takes the mean
153+ - none: pass array
154+ - sum: add elements
155+
156+ Return:
157+ A Tensor with the classification score.
158+ """
119159 tps , fps , tns , fns = stat_scores_multiple_classes (pred = pred , target = target ,
120160 num_classes = num_classes )
121161
@@ -129,6 +169,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
129169
130170def confusion_matrix (pred : torch .Tensor , target : torch .Tensor ,
131171 normalize : bool = False ) -> torch .Tensor :
172+ """
173+ Computes the confusion matrix C where each entry C_{i,j} is the number of observations
174+ in group i that were predicted in group j.
175+
176+ Args:
177+ pred: estimated targets
178+ target: ground truth labels
179+ normalize: normalizes confusion matrix
180+
181+ Return:
182+ Tensor, confusion matrix C [num_classes, num_classes ]
183+ """
132184 num_classes = get_num_classes (pred , target , None )
133185
134186 d = target .size (- 1 )
@@ -149,6 +201,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
149201 num_classes : Optional [int ] = None ,
150202 reduction : str = 'elementwise_mean'
151203 ) -> Tuple [torch .Tensor , torch .Tensor ]:
204+ """
205+ Computes precision and recall for different thresholds
206+
207+ Args:
208+ pred: estimated probabilities
209+ target: ground-truth labels
210+ num_classes: number of classes
211+ reduction: method for reducing precision-recall values (default: takes the mean)
212+ Available reduction methods:
213+
214+ - elementwise_mean: takes the mean
215+ - none: pass array
216+ - sum: add elements
217+
218+ Return:
219+ Tensor with precision and recall
220+ """
152221 tps , fps , tns , fns = stat_scores_multiple_classes (pred = pred ,
153222 target = target ,
154223 num_classes = num_classes )
@@ -168,20 +237,77 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
168237def precision (pred : torch .Tensor , target : torch .Tensor ,
169238 num_classes : Optional [int ] = None ,
170239 reduction : str = 'elementwise_mean' ) -> torch .Tensor :
240+ """
241+ Computes precision score.
242+
243+ Args:
244+ pred: estimated probabilities
245+ target: ground-truth labels
246+ num_classes: number of classes
247+ reduction: method for reducing precision values (default: takes the mean)
248+ Available reduction methods:
249+
250+ - elementwise_mean: takes the mean
251+ - none: pass array
252+ - sum: add elements
253+
254+ Return:
255+ Tensor with precision.
256+ """
171257 return precision_recall (pred = pred , target = target ,
172258 num_classes = num_classes , reduction = reduction )[0 ]
173259
174260
175261def recall (pred : torch .Tensor , target : torch .Tensor ,
176262 num_classes : Optional [int ] = None ,
177263 reduction : str = 'elementwise_mean' ) -> torch .Tensor :
264+ """
265+ Computes recall score.
266+
267+ Args:
268+ pred: estimated probabilities
269+ target: ground-truth labels
270+ num_classes: number of classes
271+ reduction: method for reducing recall values (default: takes the mean)
272+ Available reduction methods:
273+
274+ - elementwise_mean: takes the mean
275+ - none: pass array
276+ - sum: add elements
277+
278+ Return:
279+ Tensor with recall.
280+ """
178281 return precision_recall (pred = pred , target = target ,
179282 num_classes = num_classes , reduction = reduction )[1 ]
180283
181284
182285def fbeta_score (pred : torch .Tensor , target : torch .Tensor , beta : float ,
183286 num_classes : Optional [int ] = None ,
184287 reduction : str = 'elementwise_mean' ) -> torch .Tensor :
288+ """
289+ Computes the F-beta score which is a weighted harmonic mean of precision and recall.
290+ It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
291+
292+ Args:
293+ pred: estimated probabilities
294+ target: ground-truth labels
295+ beta: weights recall when combining the score.
296+ beta < 1: more weight to precision.
297+ beta > 1 more weight to recall
298+ beta = 0: only precision
299+ beta -> inf: only recall
300+ num_classes: number of classes
301+ reduction: method for reducing F-score (default: takes the mean)
302+ Available reduction methods:
303+
304+ - elementwise_mean: takes the mean
305+ - none: pass array
306+ - sum: add elements.
307+
308+ Return:
309+ Tensor with the value of F-score. It is a value between 0-1.
310+ """
185311 prec , rec = precision_recall (pred = pred , target = target ,
186312 num_classes = num_classes ,
187313 reduction = 'none' )
@@ -196,6 +322,23 @@ def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
196322def f1_score (pred : torch .Tensor , target : torch .Tensor ,
197323 num_classes : Optional [int ] = None ,
198324 reduction = 'elementwise_mean' ) -> torch .Tensor :
325+ """
326+ Computes F1-score a.k.a F-measure.
327+
328+ Args:
329+ pred: estimated probabilities
330+ target: ground-truth labels
331+ num_classes: number of classes
332+ reduction: method for reducing F1-score (default: takes the mean)
333+ Available reduction methods:
334+
335+ - elementwise_mean: takes the mean
336+ - none: pass array
337+ - sum: add elements.
338+
339+ Return:
340+ Tensor containing F1-score
341+ """
199342 return fbeta_score (pred = pred , target = target , beta = 1. ,
200343 num_classes = num_classes , reduction = reduction )
201344
@@ -251,6 +394,18 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
251394 pos_label : int = 1. ) -> Tuple [torch .Tensor ,
252395 torch .Tensor ,
253396 torch .Tensor ]:
397+ """
398+ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
399+
400+ Args:
401+ pred: estimated probabilities
402+ target: ground-truth labels
403+ sample_weight: sample weights
404+ pos_label: the label for the positive class (default: 1)
405+
406+ Return:
407+ [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
408+ """
254409 fps , tps , thresholds = _binary_clf_curve (pred = pred , target = target ,
255410 sample_weight = sample_weight ,
256411 pos_label = pos_label )
@@ -282,6 +437,19 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
282437 ) -> Tuple [Tuple [torch .Tensor ,
283438 torch .Tensor ,
284439 torch .Tensor ]]:
440+ """
441+ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
442+
443+ Args:
444+ pred: estimated probabilities
445+ target: ground-truth labels
446+ sample_weight: sample weights
447+ num_classes: number of classes (default: None, computes automatically from data)
448+
449+ Return:
450+ [num_classes, Tensor, Tensor, Tensor]: returns roc for each class.
451+ number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
452+ """
285453 num_classes = get_num_classes (pred , target , num_classes )
286454
287455 class_roc_vals = []
@@ -301,6 +469,18 @@ def precision_recall_curve(pred: torch.Tensor,
301469 pos_label : int = 1. ) -> Tuple [torch .Tensor ,
302470 torch .Tensor ,
303471 torch .Tensor ]:
472+ """
473+ Computes precision-recall pairs for different thresholds.
474+
475+ Args:
476+ pred: estimated probabilities
477+ target: ground-truth labels
478+ sample_weight: sample weights
479+ pos_label: the label for the positive class (default: 1.)
480+
481+ Return:
482+ [Tensor, Tensor, Tensor]: precision, recall, thresholds
483+ """
304484 fps , tps , thresholds = _binary_clf_curve (pred = pred , target = target ,
305485 sample_weight = sample_weight ,
306486 pos_label = pos_label )
@@ -334,6 +514,18 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
334514 ) -> Tuple [Tuple [torch .Tensor ,
335515 torch .Tensor ,
336516 torch .Tensor ]]:
517+ """
518+ Computes precision-recall pairs for different thresholds given a multiclass scores.
519+
520+ Args:
521+ pred: estimated probabilities
522+ target: ground-truth labels
523+ sample_weight: sample weight
524+ num_classes: number of classes
525+
526+ Return:
527+ [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds
528+ """
337529 num_classes = get_num_classes (pred , target , num_classes )
338530
339531 class_pr_vals = []
@@ -350,6 +542,17 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
350542
351543
352544def auc (x : torch .Tensor , y : torch .Tensor , reorder : bool = True ):
545+ """
546+ Computes Area Under the Curve (AUC) using the trapezoidal rule
547+
548+ Args:
549+ x: x-coordinates
550+ y: y-coordinates
551+ reorder: reorder coordinates, so they are increasing.
552+
553+ Return:
554+ AUC score (float)
555+ """
353556 direction = 1.
354557
355558 if reorder :
@@ -400,6 +603,15 @@ def new_func(*args, **kwargs) -> torch.Tensor:
400603def auroc (pred : torch .Tensor , target : torch .Tensor ,
401604 sample_weight : Optional [Sequence ] = None ,
402605 pos_label : int = 1. ) -> torch .Tensor :
606+ """
607+ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
608+
609+ Args:
610+ pred: estimated probabilities
611+ target: ground-truth labels
612+ sample_weight: sample weights
613+ pos_label: the label for the positive class (default: 1.)
614+ """
403615 return roc (pred = pred , target = target , sample_weight = sample_weight ,
404616 pos_label = pos_label )
405617
@@ -410,7 +622,6 @@ def average_precision(pred: torch.Tensor, target: torch.Tensor,
410622 precision , recall , _ = precision_recall_curve (pred = pred , target = target ,
411623 sample_weight = sample_weight ,
412624 pos_label = pos_label )
413-
414625 # Return the step function integral
415626 # The following works because the last entry of precision is
416627 # guaranteed to be 1, as returned by precision_recall_curve
0 commit comments