Skip to content

Commit 4d6b032

Browse files
cuentBorda
authored andcommitted
add docs
1 parent b2c877a commit 4d6b032

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

pytorch_lightning/metrics/functional/classification.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,19 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
144144
145145
Args:
146146
pred: predicted labels
147+
147148
target: ground truth labels
149+
148150
num_classes: number of classes
151+
149152
reduction: method for reducing accuracies over labels (default: takes the mean)
150153
151154
Available reduction methods:
152155
153156
- elementwise_mean: takes the mean
157+
154158
- none: pass array
159+
155160
- sum: add elements
156161
157162
Output:
@@ -170,6 +175,20 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
170175

171176
def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
172177
normalize: bool = False) -> torch.Tensor:
178+
'''
179+
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
180+
in group i that were predicted in group j.
181+
182+
Args:
183+
pred: estimated targets
184+
185+
target: groud truth labels
186+
187+
normalize: normalizes confusion matrix
188+
189+
Output:
190+
Tensor, confusion matrix C [num_classes, num_classes ]
191+
'''
173192
num_classes = get_num_classes(pred, target, None)
174193

175194
d = target.size(-1)
@@ -190,6 +209,30 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
190209
num_classes: Optional[int] = None,
191210
reduction: str = 'elementwise_mean'
192211
) -> Tuple[torch.Tensor, torch.Tensor]:
212+
'''
213+
Computes precision and recall for different thresholds
214+
215+
Args:
216+
217+
pred: estimated probabilities
218+
219+
target: ground-truth labels
220+
221+
num_classes: number of classes
222+
223+
reduction: method for reducing precision-recall values (default: takes the mean)
224+
225+
Available reduction methods:
226+
227+
- elementwise_mean: takes the mean
228+
229+
- none: pass array
230+
231+
- sum: add elements
232+
233+
Output:
234+
Tensor with precision and recall
235+
'''
193236
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
194237
target=target,
195238
num_classes=num_classes)
@@ -209,13 +252,61 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
209252
def precision(pred: torch.Tensor, target: torch.Tensor,
210253
num_classes: Optional[int] = None,
211254
reduction: str = 'elementwise_mean') -> torch.Tensor:
255+
'''
256+
Computes precision score.
257+
258+
Args:
259+
260+
pred: estimated probabilities
261+
262+
target: ground-truth labels
263+
264+
num_classes: number of classes
265+
266+
reduction: method for reducing precision values (default: takes the mean)
267+
268+
Available reduction methods:
269+
270+
- elementwise_mean: takes the mean
271+
272+
- none: pass array
273+
274+
- sum: add elements
275+
276+
Output:
277+
Tensor with precision.
278+
'''
212279
return precision_recall(pred=pred, target=target,
213280
num_classes=num_classes, reduction=reduction)[0]
214281

215282

216283
def recall(pred: torch.Tensor, target: torch.Tensor,
217284
num_classes: Optional[int] = None,
218285
reduction: str = 'elementwise_mean') -> torch.Tensor:
286+
'''
287+
Computes recall score.
288+
289+
Args:
290+
291+
pred: estimated probabilities
292+
293+
target: ground-truth labels
294+
295+
num_classes: number of classes
296+
297+
reduction: method for reducing recall values (default: takes the mean)
298+
299+
Available reduction methods:
300+
301+
- elementwise_mean: takes the mean
302+
303+
- none: pass array
304+
305+
- sum: add elements
306+
307+
Output:
308+
Tensor with recall.
309+
'''
219310
return precision_recall(pred=pred, target=target,
220311
num_classes=num_classes, reduction=reduction)[1]
221312

0 commit comments

Comments
 (0)