1818
1919from pytorch_lightning .metrics .functional .auc import auc as __auc
2020from pytorch_lightning .metrics .functional .auroc import auroc as __auroc
21- from pytorch_lightning .metrics .functional .average_precision import average_precision as __ap
2221from pytorch_lightning .metrics .functional .iou import iou as __iou
23- from pytorch_lightning .metrics .functional .precision_recall_curve import _binary_clf_curve
24- from pytorch_lightning .metrics .functional .precision_recall_curve import precision_recall_curve as __prc
25- from pytorch_lightning .metrics .functional .roc import roc as __roc
26- from pytorch_lightning .metrics .utils import class_reduce
27- from pytorch_lightning .metrics .utils import get_num_classes as __gnc
28- from pytorch_lightning .metrics .utils import reduce
29- from pytorch_lightning .metrics .utils import to_categorical as __tc
30- from pytorch_lightning .metrics .utils import to_onehot as __to
22+ from pytorch_lightning .metrics .utils import class_reduce , get_num_classes , reduce , to_categorical
3123from pytorch_lightning .utilities import rank_zero_warn
3224
3325
34- def to_onehot (
35- tensor : torch .Tensor ,
36- num_classes : Optional [int ] = None ,
37- ) -> torch .Tensor :
38- """
39- Converts a dense label tensor to one-hot format
40-
41- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
42- """
43- rank_zero_warn (
44- "This `to_onehot` was deprecated in v1.1.0 in favor of"
45- " `from pytorch_lightning.metrics.utils import to_onehot`."
46- " It will be removed in v1.3.0" , DeprecationWarning
47- )
48- return __to (tensor , num_classes )
49-
50-
51- def to_categorical (tensor : torch .Tensor , argmax_dim : int = 1 ) -> torch .Tensor :
52- """
53- Converts a tensor of probabilities to a dense label tensor
54-
55- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`
56-
57- """
58- rank_zero_warn (
59- "This `to_categorical` was deprecated in v1.1.0 in favor of"
60- " `from pytorch_lightning.metrics.utils import to_categorical`."
61- " It will be removed in v1.3.0" , DeprecationWarning
62- )
63- return __tc (tensor )
64-
65-
66- def get_num_classes (
67- pred : torch .Tensor ,
68- target : torch .Tensor ,
69- num_classes : Optional [int ] = None ,
70- ) -> int :
71- """
72- Calculates the number of classes for a given prediction and target tensor.
73-
74- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`
75-
76- """
77- rank_zero_warn (
78- "This `get_num_classes` was deprecated in v1.1.0 in favor of"
79- " `from pytorch_lightning.metrics.utils import get_num_classes`."
80- " It will be removed in v1.3.0" , DeprecationWarning
81- )
82- return __gnc (pred , target , num_classes )
83-
84-
8526def stat_scores (
8627 pred : torch .Tensor ,
8728 target : torch .Tensor ,
@@ -122,6 +63,7 @@ def stat_scores(
12263 return tp , fp , tn , fn , sup
12364
12465
66+ # todo: remove in 1.4
12567def stat_scores_multiple_classes (
12668 pred : torch .Tensor ,
12769 target : torch .Tensor ,
@@ -210,6 +152,7 @@ def _confmat_normalize(cm):
210152 return cm
211153
212154
155+ # todo: remove in 1.4
213156def precision_recall (
214157 pred : torch .Tensor ,
215158 target : torch .Tensor ,
@@ -268,6 +211,7 @@ def precision_recall(
268211 return precision , recall
269212
270213
214+ # todo: remove in 1.4
271215def precision (
272216 pred : torch .Tensor ,
273217 target : torch .Tensor ,
@@ -311,6 +255,7 @@ def precision(
311255 return precision_recall (pred = pred , target = target , num_classes = num_classes , class_reduction = class_reduction )[0 ]
312256
313257
258+ # todo: remove in 1.4
314259def recall (
315260 pred : torch .Tensor ,
316261 target : torch .Tensor ,
@@ -353,128 +298,7 @@ def recall(
353298 return precision_recall (pred = pred , target = target , num_classes = num_classes , class_reduction = class_reduction )[1 ]
354299
355300
356- # todo: remove in 1.3
357- def roc (
358- pred : torch .Tensor ,
359- target : torch .Tensor ,
360- sample_weight : Optional [Sequence ] = None ,
361- pos_label : int = 1. ,
362- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
363- """
364- Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
365-
366- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
367- """
368- rank_zero_warn (
369- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
370- " `from pytorch_lightning.metrics.functional.roc import roc`."
371- " It will be removed in v1.3.0" , DeprecationWarning
372- )
373- return __roc (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
374-
375-
376- # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
377- def _roc (
378- pred : torch .Tensor ,
379- target : torch .Tensor ,
380- sample_weight : Optional [Sequence ] = None ,
381- pos_label : int = 1. ,
382- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
383- """
384- Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
385-
386- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
387-
388- Example:
389-
390- >>> x = torch.tensor([0, 1, 2, 3])
391- >>> y = torch.tensor([0, 1, 1, 1])
392- >>> fpr, tpr, thresholds = _roc(x, y)
393- >>> fpr
394- tensor([0., 0., 0., 0., 1.])
395- >>> tpr
396- tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
397- >>> thresholds
398- tensor([4, 3, 2, 1, 0])
399-
400- """
401- rank_zero_warn (
402- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
403- " `from pytorch_lightning.metrics.functional.roc import roc`."
404- " It will be removed in v1.3.0" , DeprecationWarning
405- )
406- fps , tps , thresholds = _binary_clf_curve (pred , target , sample_weights = sample_weight , pos_label = pos_label )
407-
408- # Add an extra threshold position
409- # to make sure that the curve starts at (0, 0)
410- tps = torch .cat ([torch .zeros (1 , dtype = tps .dtype , device = tps .device ), tps ])
411- fps = torch .cat ([torch .zeros (1 , dtype = fps .dtype , device = fps .device ), fps ])
412- thresholds = torch .cat ([thresholds [0 ][None ] + 1 , thresholds ])
413-
414- if fps [- 1 ] <= 0 :
415- raise ValueError ("No negative samples in targets, false positive value should be meaningless" )
416-
417- fpr = fps / fps [- 1 ]
418-
419- if tps [- 1 ] <= 0 :
420- raise ValueError ("No positive samples in targets, true positive value should be meaningless" )
421-
422- tpr = tps / tps [- 1 ]
423-
424- return fpr , tpr , thresholds
425-
426-
427- # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
428- def multiclass_roc (
429- pred : torch .Tensor ,
430- target : torch .Tensor ,
431- sample_weight : Optional [Sequence ] = None ,
432- num_classes : Optional [int ] = None ,
433- ) -> Tuple [Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]:
434- """
435- Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
436-
437- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
438-
439- Args:
440- pred: estimated probabilities
441- target: ground-truth labels
442- sample_weight: sample weights
443- num_classes: number of classes (default: None, computes automatically from data)
444-
445- Return:
446- returns roc for each class.
447- Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
448-
449- Example:
450-
451- >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
452- ... [0.05, 0.85, 0.05, 0.05],
453- ... [0.05, 0.05, 0.85, 0.05],
454- ... [0.05, 0.05, 0.05, 0.85]])
455- >>> target = torch.tensor([0, 1, 3, 2])
456- >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
457- ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
458- (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
459- (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
460- (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
461- """
462- rank_zero_warn (
463- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
464- " `from pytorch_lightning.metrics.functional.roc import roc`."
465- " It will be removed in v1.3.0" , DeprecationWarning
466- )
467- num_classes = get_num_classes (pred , target , num_classes )
468-
469- class_roc_vals = []
470- for c in range (num_classes ):
471- pred_c = pred [:, c ]
472-
473- class_roc_vals .append (_roc (pred = pred_c , target = target , sample_weight = sample_weight , pos_label = c ))
474-
475- return tuple (class_roc_vals )
476-
477-
301+ # todo: remove in 1.4
478302def auc (
479303 x : torch .Tensor ,
480304 y : torch .Tensor ,
@@ -508,6 +332,7 @@ def auc(
508332 return __auc (x , y )
509333
510334
335+ # todo: remove in 1.4
511336def auc_decorator () -> Callable :
512337 rank_zero_warn ("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0" , DeprecationWarning )
513338
@@ -524,6 +349,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
524349 return wrapper
525350
526351
352+ # todo: remove in 1.4
527353def multiclass_auc_decorator () -> Callable :
528354 rank_zero_warn (
529355 "This `multiclass_auc_decorator` was deprecated in v1.2.0."
@@ -546,6 +372,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
546372 return wrapper
547373
548374
375+ # todo: remove in 1.4
549376def auroc (
550377 pred : torch .Tensor ,
551378 target : torch .Tensor ,
@@ -588,6 +415,7 @@ def auroc(
588415 )
589416
590417
418+ # todo: remove in 1.4
591419def multiclass_auroc (
592420 pred : torch .Tensor ,
593421 target : torch .Tensor ,
@@ -767,68 +595,3 @@ def iou(
767595 num_classes = num_classes ,
768596 reduction = reduction
769597 )
770-
771-
772- # todo: remove in 1.3
773- def precision_recall_curve (
774- pred : torch .Tensor ,
775- target : torch .Tensor ,
776- sample_weight : Optional [Sequence ] = None ,
777- pos_label : int = 1. ,
778- ):
779- """
780- Computes precision-recall pairs for different thresholds.
781-
782- .. warning :: Deprecated in favor of
783- :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
784- """
785- rank_zero_warn (
786- "This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
787- " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
788- " It will be removed in v1.3.0" , DeprecationWarning
789- )
790- return __prc (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
791-
792-
793- # todo: remove in 1.3
794- def multiclass_precision_recall_curve (
795- pred : torch .Tensor ,
796- target : torch .Tensor ,
797- sample_weight : Optional [Sequence ] = None ,
798- num_classes : Optional [int ] = None ,
799- ):
800- """
801- Computes precision-recall pairs for different thresholds given a multiclass scores.
802-
803- .. warning :: Deprecated in favor of
804- :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
805- """
806- rank_zero_warn (
807- "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
808- " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
809- " It will be removed in v1.3.0" , DeprecationWarning
810- )
811- if num_classes is None :
812- num_classes = get_num_classes (pred , target , num_classes )
813- return __prc (preds = pred , target = target , sample_weights = sample_weight , num_classes = num_classes )
814-
815-
816- # todo: remove in 1.3
817- def average_precision (
818- pred : torch .Tensor ,
819- target : torch .Tensor ,
820- sample_weight : Optional [Sequence ] = None ,
821- pos_label : int = 1. ,
822- ):
823- """
824- Compute average precision from prediction scores.
825-
826- .. warning :: Deprecated in favor of
827- :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
828- """
829- rank_zero_warn (
830- "This `average_precision` was deprecated in v1.1.0 in favor of"
831- " `pytorch_lightning.metrics.functional.average_precision import average_precision`."
832- " It will be removed in v1.3.0" , DeprecationWarning
833- )
834- return __ap (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
0 commit comments