|
17 | 17 | import torch |
18 | 18 | from torch.nn import functional as F |
19 | 19 |
|
| 20 | +from pytorch_lightning.metrics.functional import roc |
| 21 | +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve |
20 | 22 | from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce |
21 | 23 | from pytorch_lightning.utilities import rank_zero_warn |
22 | 24 |
|
@@ -332,107 +334,6 @@ def recall( |
332 | 334 | num_classes=num_classes, class_reduction=class_reduction)[1] |
333 | 335 |
|
334 | 336 |
|
335 | | -def _binary_clf_curve( |
336 | | - pred: torch.Tensor, |
337 | | - target: torch.Tensor, |
338 | | - sample_weight: Optional[Sequence] = None, |
339 | | - pos_label: int = 1., |
340 | | -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
341 | | - """ |
342 | | - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py |
343 | | - """ |
344 | | - if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): |
345 | | - sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) |
346 | | - |
347 | | - # remove class dimension if necessary |
348 | | - if pred.ndim > target.ndim: |
349 | | - pred = pred[:, 0] |
350 | | - desc_score_indices = torch.argsort(pred, descending=True) |
351 | | - |
352 | | - pred = pred[desc_score_indices] |
353 | | - target = target[desc_score_indices] |
354 | | - |
355 | | - if sample_weight is not None: |
356 | | - weight = sample_weight[desc_score_indices] |
357 | | - else: |
358 | | - weight = 1. |
359 | | - |
360 | | - # pred typically has many tied values. Here we extract |
361 | | - # the indices associated with the distinct values. We also |
362 | | - # concatenate a value for the end of the curve. |
363 | | - distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] |
364 | | - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) |
365 | | - |
366 | | - target = (target == pos_label).to(torch.long) |
367 | | - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] |
368 | | - |
369 | | - if sample_weight is not None: |
370 | | - # express fps as a cumsum to ensure fps is increasing even in |
371 | | - # the presence of floating point errors |
372 | | - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] |
373 | | - else: |
374 | | - fps = 1 + threshold_idxs - tps |
375 | | - |
376 | | - return fps, tps, pred[threshold_idxs] |
377 | | - |
378 | | - |
379 | | -# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py |
380 | | -def __roc( |
381 | | - pred: torch.Tensor, |
382 | | - target: torch.Tensor, |
383 | | - sample_weight: Optional[Sequence] = None, |
384 | | - pos_label: int = 1., |
385 | | -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
386 | | - """ |
387 | | - Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. |
388 | | -
|
389 | | - .. warning:: Deprecated |
390 | | -
|
391 | | - Args: |
392 | | - pred: estimated probabilities |
393 | | - target: ground-truth labels |
394 | | - sample_weight: sample weights |
395 | | - pos_label: the label for the positive class |
396 | | -
|
397 | | - Return: |
398 | | - false-positive rate (fpr), true-positive rate (tpr), thresholds |
399 | | -
|
400 | | - Example: |
401 | | -
|
402 | | - >>> x = torch.tensor([0, 1, 2, 3]) |
403 | | - >>> y = torch.tensor([0, 1, 1, 1]) |
404 | | - >>> fpr, tpr, thresholds = __roc(x, y) |
405 | | - >>> fpr |
406 | | - tensor([0., 0., 0., 0., 1.]) |
407 | | - >>> tpr |
408 | | - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) |
409 | | - >>> thresholds |
410 | | - tensor([4, 3, 2, 1, 0]) |
411 | | -
|
412 | | - """ |
413 | | - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, |
414 | | - sample_weight=sample_weight, |
415 | | - pos_label=pos_label) |
416 | | - |
417 | | - # Add an extra threshold position |
418 | | - # to make sure that the curve starts at (0, 0) |
419 | | - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) |
420 | | - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) |
421 | | - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) |
422 | | - |
423 | | - if fps[-1] <= 0: |
424 | | - raise ValueError("No negative samples in targets, false positive value should be meaningless") |
425 | | - |
426 | | - fpr = fps / fps[-1] |
427 | | - |
428 | | - if tps[-1] <= 0: |
429 | | - raise ValueError("No positive samples in targets, true positive value should be meaningless") |
430 | | - |
431 | | - tpr = tps / tps[-1] |
432 | | - |
433 | | - return fpr, tpr, thresholds |
434 | | - |
435 | | - |
436 | 337 | # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py |
437 | 338 | def __multiclass_roc( |
438 | 339 | pred: torch.Tensor, |
@@ -474,7 +375,7 @@ def __multiclass_roc( |
474 | 375 | for c in range(num_classes): |
475 | 376 | pred_c = pred[:, c] |
476 | 377 |
|
477 | | - class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) |
| 378 | + class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1)) |
478 | 379 |
|
479 | 380 | return tuple(class_roc_vals) |
480 | 381 |
|
@@ -589,7 +490,7 @@ def auroc( |
589 | 490 |
|
590 | 491 | @auc_decorator(reorder=True) |
591 | 492 | def _auroc(pred, target, sample_weight, pos_label): |
592 | | - return __roc(pred, target, sample_weight, pos_label) |
| 493 | + return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1) |
593 | 494 |
|
594 | 495 | return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) |
595 | 496 |
|
@@ -642,7 +543,7 @@ def multiclass_auroc( |
642 | 543 |
|
643 | 544 | @multiclass_auc_decorator(reorder=False) |
644 | 545 | def _multiclass_auroc(pred, target, sample_weight, num_classes): |
645 | | - return __multiclass_roc(pred, target, sample_weight, num_classes) |
| 546 | + return roc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) |
646 | 547 |
|
647 | 548 | class_aurocs = _multiclass_auroc(pred=pred, target=target, |
648 | 549 | sample_weight=sample_weight, |
|
0 commit comments