 # limitations under the License.
 from typing import Any, Optional
 
-import torch
-from torchmetrics import Metric
+from torchmetrics import F1 as _F1
+from torchmetrics import FBeta as _FBeta
 
-from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.deprecation import deprecated
 
 
-class FBeta(Metric):
-    r"""
-    Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_, specifically:
-
-    .. math::
-        F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
-        {(\beta^2 * \text{precision}) + \text{recall}}
-
-    Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
-    Accepts probabilities from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target.
-
-    Forward accepts
-
-    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
-    - ``target`` (long tensor): ``(N, ...)``
-
-    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
-    to convert into integer labels. This is the case for binary and multi-label probabilities.
-
-    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
-
-    Args:
-        num_classes: Number of classes in the dataset.
-        beta: Beta coefficient in the F measure.
-        threshold:
-            Threshold value for binary or multi-label probabilities. default: 0.5
-
-        average:
-            - ``'micro'`` computes metric globally
-            - ``'macro'`` computes metric for each class and uniformly averages them
-            - ``'weighted'`` computes metric for each class and does a weighted-average,
-              where each class is weighted by their support (accounts for class imbalance)
-            - ``'none'`` or ``None`` computes and returns the metric per class
-
-        multilabel: If predictions are from multilabel classification.
-        compute_on_step:
-            Forward only calls ``update()`` and return None if this is set to False. default: True
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step. default: False
-        process_group:
-            Specify the process group on which synchronization is called. default: None (which selects the entire world)
-
-    Raises:
-        ValueError:
-            If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``.
-
-    Example:
-
-        >>> from pytorch_lightning.metrics import FBeta
-        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
-        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> f_beta = FBeta(num_classes=3, beta=0.5)
-        >>> f_beta(preds, target)
-        tensor(0.3333)
-
-    """
+class FBeta(_FBeta):
 
+    @deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0")
     def __init__(
         self,
        num_classes: int,
@@ -90,103 +33,17 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
     ):
-        super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-        )
-
-        self.num_classes = num_classes
-        self.beta = beta
-        self.threshold = threshold
-        self.average = average
-        self.multilabel = multilabel
-
-        allowed_average = ("micro", "macro", "weighted", "none", None)
-        if self.average not in allowed_average:
-            raise ValueError(
-                'Argument `average` expected to be one of the following:'
-                f' {allowed_average} but got {self.average}'
-            )
-
-        self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor):
-        """
-        Update state with predictions and targets.
-
-        Args:
-            preds: Predictions from model
-            target: Ground truth values
         """
-        true_positives, predicted_positives, actual_positives = _fbeta_update(
-            preds, target, self.num_classes, self.threshold, self.multilabel
-        )
-
-        self.true_positives += true_positives
-        self.predicted_positives += predicted_positives
-        self.actual_positives += actual_positives
+        This implementation refers to :class:`~torchmetrics.FBeta`.
 
-    def compute(self) -> torch.Tensor:
+        .. deprecated::
+            Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0.
         """
-        Computes fbeta over state.
-        """
-        return _fbeta_compute(
-            self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
-        )
-
-
-class F1(FBeta):
-    """
-    Computes F1 metric. F1 metrics correspond to a harmonic mean of the
-    precision and recall scores.
-
-    Works with binary, multiclass, and multilabel data.
-    Accepts logits from a model output or integer class values in prediction.
-    Works with multi-dimensional preds and target.
 
-    Forward accepts
 
-    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
-    - ``target`` (long tensor): ``(N, ...)``
-
-    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
-    This is the case for binary and multi-label logits.
-
-    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
-
-    Args:
-        num_classes: Number of classes in the dataset.
-        threshold:
-            Threshold value for binary or multi-label logits. default: 0.5
-
-        average:
-            - ``'micro'`` computes metric globally
-            - ``'macro'`` computes metric for each class and uniformly averages them
-            - ``'weighted'`` computes metric for each class and does a weighted-average,
-              where each class is weighted by their support (accounts for class imbalance)
-            - ``'none'`` or ``None`` computes and returns the metric per class
-
-        multilabel: If predictions are from multilabel classification.
-        compute_on_step:
-            Forward only calls ``update()`` and returns None if this is set to False. default: True
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step. default: False
-        process_group:
-            Specify the process group on which synchronization is called. default: None (which selects the entire world)
-
-    Example:
-        >>> from pytorch_lightning.metrics import F1
-        >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
-        >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
-        >>> f1 = F1(num_classes=3)
-        >>> f1(preds, target)
-        tensor(0.3333)
-    """
+class F1(_F1):
 
+    @deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0")
     def __init__(
         self,
         num_classes: int,
@@ -197,16 +54,9 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
     ):
-        if multilabel is not False:
-            rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.')
+        """
+        This implementation refers to :class:`~torchmetrics.F1`.
 
-        super().__init__(
-            num_classes=num_classes,
-            beta=1.0,
-            threshold=threshold,
-            average=average,
-            multilabel=multilabel,
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-        )
+        .. deprecated::
+            Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
+        """
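
After this change the pytorch_lightning.metrics wrappers above only forward their arguments to torchmetrics until their removal in v1.5.0, so the natural migration is to import the metric classes from torchmetrics directly. A minimal sketch follows, reusing the tensors and expected outputs from the removed doctests; it assumes torchmetrics is installed alongside Lightning and passes average="micro" explicitly rather than relying on the library default:

    >>> import torch
    >>> from torchmetrics import F1, FBeta  # preferred over the deprecated pytorch_lightning.metrics classes
    >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
    >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
    >>> f_beta = FBeta(num_classes=3, beta=0.5, average="micro")
    >>> f_beta(preds, target)
    tensor(0.3333)
    >>> f1 = F1(num_classes=3, average="micro")
    >>> f1(preds, target)
    tensor(0.3333)

With micro averaging on single-label multiclass input, precision and recall both equal plain accuracy (here 2 of 6 predictions are correct), so the beta value does not change the result in this toy example.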