1515
1616import torch
1717
18- from pytorch_lightning . metrics . classification . stat_scores import StatScores
19- from pytorch_lightning . metrics . functional . precision_recall import _precision_compute , _recall_compute
18+ from torchmetrics import Precision as _Precision
19+ from torchmetrics import Recall as _Recall
2020
21+ from pytorch_lightning .utilities .deprecation import deprecated
2122
22- class Precision (StatScores ):
23- r"""
24- Computes `Precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
2523
26- .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
27-
28- Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
29- false positives respecitively. With the use of ``top_k`` parameter, this metric can
30- generalize to Precision@K.
31-
32- The reduction method (how the precision scores are aggregated) is controlled by the
33- ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
34- multi-dimensional multi-class case.
35-
36- Args:
37- num_classes:
38- Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
39- threshold:
40- Threshold probability value for transforming probability predictions to binary
41- (0,1) predictions, in the case of binary or multi-label inputs.
42- average:
43- Defines the reduction that is applied. Should be one of the following:
44-
45- - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes.
46- - ``'macro'``: Calculate the metric for each class separately, and average the
47- metrics accross classes (with equal weights for each class).
48- - ``'weighted'``: Calculate the metric for each class separately, and average the
49- metrics accross classes, weighting each class by its support (``tp + fn``).
50- - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
51- the metric for every class.
52- - ``'samples'``: Calculate the metric for each sample, and average the metrics
53- across samples (with equal weights for each sample).
54-
55- Note that what is considered a sample in the multi-dimensional multi-class case
56- depends on the value of ``mdmc_average``.
57- multilabel:
58- .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0.
59-
60- mdmc_average:
61- Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
62- ``average`` parameter). Should be one of the following:
63-
64- - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
65- multi-class.
66-
67- - ``'samplewise'``: In this case, the statistics are computed separately for each
68- sample on the ``N`` axis, and then averaged over samples.
69- The computation for each sample is done by treating the flattened extra axes ``...``
70- as the ``N`` dimension within the sample, and computing the metric for the sample based on that.
71-
72- - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
73- are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
74- were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
75-
76- ignore_index:
77- Integer specifying a target class to ignore. If given, this class index does not contribute
78- to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
79- or ``'none'``, the score for the ignored class will be returned as ``nan``.
80-
81- top_k:
82- Number of highest probability entries for each sample to convert to 1s - relevant
83- only for inputs with probability predictions. If this parameter is set for multi-label
84- inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
85- this parameter defaults to 1.
86-
87- Should be left unset (``None``) for inputs with label predictions.
88- is_multiclass:
89- Used only in certain special cases, where you want to treat inputs as a different type
90- than what they appear to be.
91-
92- compute_on_step:
93- Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
94- dist_sync_on_step:
95- Synchronize metric state across processes at each ``forward()``
96- before returning the value at the step
97- process_group:
98- Specify the process group on which synchronization is called.
99- default: ``None`` (which selects the entire world)
100- dist_sync_fn:
101- Callback that performs the allgather operation on the metric state. When ``None``, DDP
102- will be used to perform the allgather.
103-
104- Raises:
105- ValueError:
106- If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
107-
108- Example:
109-
110- >>> from pytorch_lightning.metrics import Precision
111- >>> preds = torch.tensor([2, 0, 2, 1])
112- >>> target = torch.tensor([1, 1, 2, 0])
113- >>> precision = Precision(average='macro', num_classes=3)
114- >>> precision(preds, target)
115- tensor(0.1667)
116- >>> precision = Precision(average='micro')
117- >>> precision(preds, target)
118- tensor(0.2500)
119-
120- """
24+ class Precision (_Precision ):
12125
26+ @deprecated (target = _Precision , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
12227 def __init__ (
12328 self ,
12429 num_classes : Optional [int ] = None ,
@@ -134,142 +39,17 @@ def __init__(
13439 process_group : Optional [Any ] = None ,
13540 dist_sync_fn : Callable = None ,
13641 ):
137- allowed_average = ["micro" , "macro" , "weighted" , "samples" , "none" , None ]
138- if average not in allowed_average :
139- raise ValueError (f"The `average` has to be one of { allowed_average } , got { average } ." )
140-
141- super ().__init__ (
142- reduce = "macro" if average in ["weighted" , "none" , None ] else average ,
143- mdmc_reduce = mdmc_average ,
144- threshold = threshold ,
145- top_k = top_k ,
146- num_classes = num_classes ,
147- is_multiclass = is_multiclass ,
148- ignore_index = ignore_index ,
149- compute_on_step = compute_on_step ,
150- dist_sync_on_step = dist_sync_on_step ,
151- process_group = process_group ,
152- dist_sync_fn = dist_sync_fn ,
153- )
154-
155- self .average = average
156-
157- def compute (self ) -> torch .Tensor :
15842 """
159- Computes the precision score based on inputs passed in to ``update`` previously .
43+ This implementation refers to :class:`~torchmetrics.Precision` .
16044
161- Return:
162- The shape of the returned tensor depends on the ``average`` parameter
163-
164- - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
165- - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
166- of classes
45+ .. deprecated::
46+ Use :class:`~torchmetrics.Precision`. Will be removed in v1.5.0.
16747 """
168- tp , fp , tn , fn = self ._get_final_stats ()
169- return _precision_compute (tp , fp , tn , fn , self .average , self .mdmc_reduce )
170-
17148
172- class Recall (StatScores ):
173- r"""
174- Computes `Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
17549
176- .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
177-
178- Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and
179- false negatives respecitively. With the use of ``top_k`` parameter, this metric can
180- generalize to Recall@K.
181-
182- The reduction method (how the recall scores are aggregated) is controlled by the
183- ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
184- multi-dimensional multi-class case.
185-
186- Args:
187- num_classes:
188- Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
189- threshold:
190- Threshold probability value for transforming probability predictions to binary
191- (0,1) predictions, in the case of binary or multi-label inputs.
192- average:
193- Defines the reduction that is applied. Should be one of the following:
194-
195- - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes.
196- - ``'macro'``: Calculate the metric for each class separately, and average the
197- metrics accross classes (with equal weights for each class).
198- - ``'weighted'``: Calculate the metric for each class separately, and average the
199- metrics accross classes, weighting each class by its support (``tp + fn``).
200- - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
201- the metric for every class.
202- - ``'samples'``: Calculate the metric for each sample, and average the metrics
203- across samples (with equal weights for each sample).
204-
205- Note that what is considered a sample in the multi-dimensional multi-class case
206- depends on the value of ``mdmc_average``.
207- multilabel:
208- .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0.
209-
210- mdmc_average:
211- Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
212- ``average`` parameter). Should be one of the following:
213-
214- - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
215- multi-class.
216-
217- - ``'samplewise'``: In this case, the statistics are computed separately for each
218- sample on the ``N`` axis, and then averaged over samples.
219- The computation for each sample is done by treating the flattened extra axes ``...``
220- as the ``N`` dimension within the sample, and computing the metric for the sample based on that.
221-
222- - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
223- are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
224- were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
225-
226- ignore_index:
227- Integer specifying a target class to ignore. If given, this class index does not contribute
228- to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
229- or ``'none'``, the score for the ignored class will be returned as ``nan``.
230-
231- top_k:
232- Number of highest probability entries for each sample to convert to 1s - relevant
233- only for inputs with probability predictions. If this parameter is set for multi-label
234- inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs,
235- this parameter defaults to 1.
236-
237- Should be left unset (``None``) for inputs with label predictions.
238-
239- is_multiclass:
240- Used only in certain special cases, where you want to treat inputs as a different type
241- than what they appear to be.
242-
243- compute_on_step:
244- Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
245- dist_sync_on_step:
246- Synchronize metric state across processes at each ``forward()``
247- before returning the value at the step
248- process_group:
249- Specify the process group on which synchronization is called.
250- default: ``None`` (which selects the entire world)
251- dist_sync_fn:
252- Callback that performs the allgather operation on the metric state. When ``None``, DDP
253- will be used to perform the allgather.
254-
255- Raises:
256- ValueError:
257- If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
258-
259- Example:
260-
261- >>> from pytorch_lightning.metrics import Recall
262- >>> preds = torch.tensor([2, 0, 2, 1])
263- >>> target = torch.tensor([1, 1, 2, 0])
264- >>> recall = Recall(average='macro', num_classes=3)
265- >>> recall(preds, target)
266- tensor(0.3333)
267- >>> recall = Recall(average='micro')
268- >>> recall(preds, target)
269- tensor(0.2500)
270-
271- """
50+ class Recall (_Recall ):
27251
52+ @deprecated (target = _Recall , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
27353 def __init__ (
27454 self ,
27555 num_classes : Optional [int ] = None ,
@@ -285,36 +65,9 @@ def __init__(
28565 process_group : Optional [Any ] = None ,
28666 dist_sync_fn : Callable = None ,
28767 ):
288- allowed_average = ["micro" , "macro" , "weighted" , "samples" , "none" , None ]
289- if average not in allowed_average :
290- raise ValueError (f"The `average` has to be one of { allowed_average } , got { average } ." )
291-
292- super ().__init__ (
293- reduce = "macro" if average in ["weighted" , "none" , None ] else average ,
294- mdmc_reduce = mdmc_average ,
295- threshold = threshold ,
296- top_k = top_k ,
297- num_classes = num_classes ,
298- is_multiclass = is_multiclass ,
299- ignore_index = ignore_index ,
300- compute_on_step = compute_on_step ,
301- dist_sync_on_step = dist_sync_on_step ,
302- process_group = process_group ,
303- dist_sync_fn = dist_sync_fn ,
304- )
305-
306- self .average = average
307-
308- def compute (self ) -> torch .Tensor :
30968 """
310- Computes the recall score based on inputs passed in to ``update`` previously.
311-
312- Return:
313- The shape of the returned tensor depends on the ``average`` parameter
69+ This implementation refers to :class:`~torchmetrics.Recall`.
31470
315- - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
316- - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
317- of classes
71+ .. deprecated::
72+ Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0.
31873 """
319- tp , fp , tn , fn = self ._get_final_stats ()
320- return _recall_compute (tp , fp , tn , fn , self .average , self .mdmc_reduce )
0 commit comments