1616import torch
1717
1818from pytorch_lightning .metrics .metric import Metric
19- from pytorch_lightning .metrics .utils import _input_format_classification
19+ from pytorch_lightning .metrics .functional . accuracy import _accuracy_update , _accuracy_compute
2020
2121
2222class Accuracy (Metric ):
2323 r"""
2424 Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
2525
26- .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i})
26+ .. math::
27+ \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
2728
2829 Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
29- tensor of predictions. Works with binary, multiclass, and multilabel
30- data. Accepts logits from a model output or integer class values in
31- prediction. Works with multi-dimensional preds and target.
30+ tensor of predictions.
3231
33- Forward accepts
32+ For multi-class and multi-dimensional multi-class data with probability predictions, the
33+ parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
34+ top-K highest probability items are considered to find the correct label.
3435
35- - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
36- - ``target`` (long tensor): ``(N, ...)``
36+ For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
37+ accuracy by default, which counts all labels or sub-samples separately. This can be
38+ changed to subset accuracy (which requires all labels or sub-samples in the sample to
39+ be correctly predicted) by setting ``subset_accuracy=True``.
3740
38- If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
39- This is the case for binary and multi-label logits.
40-
41- If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
41+ Accepts all input types listed in :ref:`metrics:Input types`.
4242
4343 Args:
4444 threshold:
45- Threshold value for binary or multi-label logits. default: 0.5
45+ Threshold probability value for transforming probability predictions to binary
46+ `(0,1)` predictions, in the case of binary or multi-label inputs.
47+ top_k:
48+ Number of highest probability predictions considered to find the correct label, relevant
49+ only for (multi-dimensional) multi-class inputs with probability predictions. The
50+ default value (``None``) will be interpreted as 1 for these inputs.
51+
52+ Should be left at default (``None``) for all other types of inputs.
53+ subset_accuracy:
54+ Whether to compute subset accuracy for multi-label and multi-dimensional
55+ multi-class inputs (has no effect for other input types).
56+
57+ For multi-label inputs, if the parameter is set to `True`, then all labels for
58+ each sample must be correctly predicted for the sample to count as correct. If it
59+ is set to `False`, then all labels are counted separately - this is equivalent to
60+ flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
61+
62+ For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all
63+ sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
64+ If it is set to `False`, then all sub-samples are counted separately - this is equivalent,
65+ in the case of label predictions, to flattening the inputs beforehand (i.e.
66+ ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
67+ still applies in both cases, if set.
4668 compute_on_step:
47- Forward only calls ``update()`` and return None if this is set to False. default: True
69+ Forward only calls ``update()`` and return None if this is set to False.
4870 dist_sync_on_step:
4971 Synchronize metric state across processes at each ``forward()``
5072 before returning the value at the step. default: False
@@ -63,10 +85,19 @@ class Accuracy(Metric):
6385 >>> accuracy(preds, target)
6486 tensor(0.5000)
6587
88+ >>> target = torch.tensor([0, 1, 2])
89+ >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
90+ >>> accuracy = Accuracy(top_k=2)
91+ >>> accuracy(preds, target)
92+ tensor(0.6667)
93+
6694 """
95+
6796 def __init__ (
6897 self ,
6998 threshold : float = 0.5 ,
99+ top_k : Optional [int ] = None ,
100+ subset_accuracy : bool = False ,
70101 compute_on_step : bool = True ,
71102 dist_sync_on_step : bool = False ,
72103 process_group : Optional [Any ] = None ,
@@ -82,24 +113,35 @@ def __init__(
82113 self .add_state ("correct" , default = torch .tensor (0 ), dist_reduce_fx = "sum" )
83114 self .add_state ("total" , default = torch .tensor (0 ), dist_reduce_fx = "sum" )
84115
116+ if not 0 <= threshold <= 1 :
117+ raise ValueError ("The `threshold` should lie in the [0,1] interval." )
118+
119+ if top_k is not None and top_k <= 0 :
120+ raise ValueError ("The `top_k` should be an integer larger than 1." )
121+
85122 self .threshold = threshold
123+ self .top_k = top_k
124+ self .subset_accuracy = subset_accuracy
86125
87126 def update (self , preds : torch .Tensor , target : torch .Tensor ):
88127 """
89- Update state with predictions and targets.
128+ Update state with predictions and targets. See :ref:`metrics:Input types` for more information
129+ on input types.
90130
91131 Args:
92- preds: Predictions from model
93- target: Ground truth values
132+ preds: Predictions from model (probabilities, or labels)
133+ target: Ground truth labels
94134 """
95- preds , target = _input_format_classification (preds , target , self .threshold )
96- assert preds .shape == target .shape
97135
98- self .correct += torch .sum (preds == target )
99- self .total += target .numel ()
136+ correct , total = _accuracy_update (
137+ preds , target , threshold = self .threshold , top_k = self .top_k , subset_accuracy = self .subset_accuracy
138+ )
139+
140+ self .correct += correct
141+ self .total += total
100142
101- def compute (self ):
143+ def compute (self ) -> torch . Tensor :
102144 """
103- Computes accuracy over state .
145+ Computes accuracy based on inputs passed in to ``update`` previously .
104146 """
105- return self .correct . float () / self .total
147+ return _accuracy_compute ( self .correct , self .total )
0 commit comments