1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- from distutils .version import LooseVersion
1514from typing import Any , Callable , Optional
1615
17- import torch
18- from torchmetrics import Metric
16+ from torchmetrics import AUROC as _AUROC
1917
20- from pytorch_lightning .metrics .functional .auroc import _auroc_compute , _auroc_update
21- from pytorch_lightning .utilities import rank_zero_warn
18+ from pytorch_lightning .utilities .deprecation import deprecated
2219
2320
24- class AUROC (Metric ):
25- r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
26- <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_.
27- Works for both binary, multilabel and multiclass problems. In the case of
28- multiclass, the values will be calculated based on a one-vs-the-rest approach.
29-
30- Forward accepts
31-
32- - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
33- with probabilities, where C is the number of classes.
34-
35- - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
36-
37- For non-binary input, if the ``preds`` and ``target`` tensor have the same
38- size the input will be interpretated as multilabel and if ``preds`` have one
39- dimension more than the ``target`` tensor the input will be interpretated as
40- multiclass.
41-
42- Args:
43- num_classes: integer with number of classes. Not nessesary to provide
44- for binary problems.
45- pos_label: integer determining the positive class. Default is ``None``
46- which for binary problem is translate to 1. For multiclass problems
47- this argument should not be set as we iteratively change it in the
48- range [0,num_classes-1]
49- average:
50- - ``'macro'`` computes metric for each class and uniformly averages them
51- - ``'weighted'`` computes metric for each class and does a weighted-average,
52- where each class is weighted by their support (accounts for class imbalance)
53- - ``None`` computes and returns the metric per class
54- max_fpr:
55- If not ``None``, calculates standardized partial AUC over the
56- range [0, max_fpr]. Should be a float between 0 and 1.
57- compute_on_step:
58- Forward only calls ``update()`` and return None if this is set to False. default: True
59- dist_sync_on_step:
60- Synchronize metric state across processes at each ``forward()``
61- before returning the value at the step.
62- process_group:
63- Specify the process group on which synchronization is called. default: None (which selects the entire world)
64- dist_sync_fn:
65- Callback that performs the allgather operation on the metric state. When ``None``, DDP
66- will be used to perform the allgather
67-
68- Raises:
69- ValueError:
70- If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
71- ValueError:
72- If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
73- RuntimeError:
74- If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize``
75- which is not available below 1.6.
76- ValueError:
77- If the mode of data (binary, multi-label, multi-class) changes between batches.
78-
79- Example (binary case):
80-
81- >>> from pytorch_lightning.metrics import AUROC
82- >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
83- >>> target = torch.tensor([0, 0, 1, 1, 1])
84- >>> auroc = AUROC(pos_label=1)
85- >>> auroc(preds, target)
86- tensor(0.5000)
87-
88- Example (multiclass case):
89-
90- >>> from pytorch_lightning.metrics import AUROC
91- >>> preds = torch.tensor([[0.90, 0.05, 0.05],
92- ... [0.05, 0.90, 0.05],
93- ... [0.05, 0.05, 0.90],
94- ... [0.85, 0.05, 0.10],
95- ... [0.10, 0.10, 0.80]])
96- >>> target = torch.tensor([0, 1, 1, 2, 2])
97- >>> auroc = AUROC(num_classes=3)
98- >>> auroc(preds, target)
99- tensor(0.7778)
100-
101- """
21+ class AUROC (_AUROC ):
10222
23+ @deprecated (target = _AUROC , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
10324 def __init__ (
10425 self ,
10526 num_classes : Optional [int ] = None ,
@@ -111,74 +32,9 @@ def __init__(
11132 process_group : Optional [Any ] = None ,
11233 dist_sync_fn : Callable = None ,
11334 ):
114- super ().__init__ (
115- compute_on_step = compute_on_step ,
116- dist_sync_on_step = dist_sync_on_step ,
117- process_group = process_group ,
118- dist_sync_fn = dist_sync_fn ,
119- )
120-
121- self .num_classes = num_classes
122- self .pos_label = pos_label
123- self .average = average
124- self .max_fpr = max_fpr
125-
126- allowed_average = (None , 'macro' , 'weighted' )
127- if self .average not in allowed_average :
128- raise ValueError (
129- f'Argument `average` expected to be one of the following: { allowed_average } but got { average } '
130- )
131-
132- if self .max_fpr is not None :
133- if (not isinstance (max_fpr , float ) and 0 < max_fpr <= 1 ):
134- raise ValueError (f"`max_fpr` should be a float in range (0, 1], got: { max_fpr } " )
135-
136- if LooseVersion (torch .__version__ ) < LooseVersion ('1.6.0' ):
137- raise RuntimeError (
138- '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
139- )
140-
141- self .mode = None
142- self .add_state ("preds" , default = [], dist_reduce_fx = None )
143- self .add_state ("target" , default = [], dist_reduce_fx = None )
144-
145- rank_zero_warn (
146- 'Metric `AUROC` will save all targets and predictions in buffer.'
147- ' For large datasets this may lead to large memory footprint.'
148- )
149-
150- def update (self , preds : torch .Tensor , target : torch .Tensor ):
15135 """
152- Update state with predictions and targets .
36+ This implementation refers to :class:`~torchmetrics.AUROC` .
15337
154- Args:
155- preds: Predictions from model (probabilities, or labels)
156- target: Ground truth labels
157- """
158- preds , target , mode = _auroc_update (preds , target )
159-
160- self .preds .append (preds )
161- self .target .append (target )
162-
163- if self .mode is not None and self .mode != mode :
164- raise ValueError (
165- 'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
166- f' between batches from { self .mode } to { mode } '
167- )
168- self .mode = mode
169-
170- def compute (self ) -> torch .Tensor :
171- """
172- Computes AUROC based on inputs passed in to ``update`` previously.
38+ .. deprecated::
39+ Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0.
17340 """
174- preds = torch .cat (self .preds , dim = 0 )
175- target = torch .cat (self .target , dim = 0 )
176- return _auroc_compute (
177- preds ,
178- target ,
179- self .mode ,
180- self .num_classes ,
181- self .pos_label ,
182- self .average ,
183- self .max_fpr ,
184- )
0 commit comments