@@ -144,14 +144,19 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
144144
145145 Args:
146146 pred: predicted labels
147+
147148 target: ground truth labels
149+
148150 num_classes: number of classes
151+
149152 reduction: method for reducing accuracies over labels (default: takes the mean)
150153
151154 Available reduction methods:
152155
153156 - elementwise_mean: takes the mean
157+
154158 - none: pass array
159+
155160 - sum: add elements
156161
157162 Output:
@@ -170,6 +175,20 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
170175
171176def confusion_matrix (pred : torch .Tensor , target : torch .Tensor ,
172177 normalize : bool = False ) -> torch .Tensor :
178+ '''
179+ Computes the confusion matrix C where each entry C_{i,j} is the number of observations
180+ in group i that were predicted in group j.
181+
182+ Args:
183+ pred: estimated targets
184+
185+ target: groud truth labels
186+
187+ normalize: normalizes confusion matrix
188+
189+ Output:
190+ Tensor, confusion matrix C [num_classes, num_classes ]
191+ '''
173192 num_classes = get_num_classes (pred , target , None )
174193
175194 d = target .size (- 1 )
@@ -190,6 +209,30 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
190209 num_classes : Optional [int ] = None ,
191210 reduction : str = 'elementwise_mean'
192211 ) -> Tuple [torch .Tensor , torch .Tensor ]:
212+ '''
213+ Computes precision and recall for different thresholds
214+
215+ Args:
216+
217+ pred: estimated probabilities
218+
219+ target: ground-truth labels
220+
221+ num_classes: number of classes
222+
223+ reduction: method for reducing precision-recall values (default: takes the mean)
224+
225+ Available reduction methods:
226+
227+ - elementwise_mean: takes the mean
228+
229+ - none: pass array
230+
231+ - sum: add elements
232+
233+ Output:
234+ Tensor with precision and recall
235+ '''
193236 tps , fps , tns , fns = stat_scores_multiple_classes (pred = pred ,
194237 target = target ,
195238 num_classes = num_classes )
@@ -209,13 +252,61 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
209252def precision (pred : torch .Tensor , target : torch .Tensor ,
210253 num_classes : Optional [int ] = None ,
211254 reduction : str = 'elementwise_mean' ) -> torch .Tensor :
255+ '''
256+ Computes precision score.
257+
258+ Args:
259+
260+ pred: estimated probabilities
261+
262+ target: ground-truth labels
263+
264+ num_classes: number of classes
265+
266+ reduction: method for reducing precision values (default: takes the mean)
267+
268+ Available reduction methods:
269+
270+ - elementwise_mean: takes the mean
271+
272+ - none: pass array
273+
274+ - sum: add elements
275+
276+ Output:
277+ Tensor with precision.
278+ '''
212279 return precision_recall (pred = pred , target = target ,
213280 num_classes = num_classes , reduction = reduction )[0 ]
214281
215282
216283def recall (pred : torch .Tensor , target : torch .Tensor ,
217284 num_classes : Optional [int ] = None ,
218285 reduction : str = 'elementwise_mean' ) -> torch .Tensor :
286+ '''
287+ Computes recall score.
288+
289+ Args:
290+
291+ pred: estimated probabilities
292+
293+ target: ground-truth labels
294+
295+ num_classes: number of classes
296+
297+ reduction: method for reducing recall values (default: takes the mean)
298+
299+ Available reduction methods:
300+
301+ - elementwise_mean: takes the mean
302+
303+ - none: pass array
304+
305+ - sum: add elements
306+
307+ Output:
308+ Tensor with recall.
309+ '''
219310 return precision_recall (pred = pred , target = target ,
220311 num_classes = num_classes , reduction = reduction )[1 ]
221312
0 commit comments