Skip to content

Commit 1867e61

Browse files
committed
prune accuracy
1 parent b341b53 commit 1867e61

File tree

4 files changed

+40
-382
lines changed

4 files changed

+40
-382
lines changed

pytorch_lightning/metrics/classification/accuracy.py

Lines changed: 12 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -14,89 +14,16 @@
1414
from typing import Any, Callable, Optional
1515

1616
import torch
17+
from torchmetrics import Accuracy as __Accuracy
1718

18-
from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update
19-
from pytorch_lightning.metrics.metric import Metric
19+
from pytorch_lightning.utilities import rank_zero_warn
2020

2121

22-
class Accuracy(Metric):
22+
class Accuracy(__Accuracy):
2323
r"""
24-
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:
25-
26-
.. math::
27-
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
28-
29-
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
30-
tensor of predictions.
31-
32-
For multi-class and multi-dimensional multi-class data with probability predictions, the
33-
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
34-
top-K highest probability items are considered to find the correct label.
35-
36-
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
37-
accuracy by default, which counts all labels or sub-samples separately. This can be
38-
changed to subset accuracy (which requires all labels or sub-samples in the sample to
39-
be correctly predicted) by setting ``subset_accuracy=True``.
40-
41-
Args:
42-
threshold:
43-
Threshold probability value for transforming probability predictions to binary
44-
(0,1) predictions, in the case of binary or multi-label inputs.
45-
top_k:
46-
Number of highest probability predictions considered to find the correct label, relevant
47-
only for (multi-dimensional) multi-class inputs with probability predictions. The
48-
default value (``None``) will be interpreted as 1 for these inputs.
49-
50-
Should be left at default (``None``) for all other types of inputs.
51-
subset_accuracy:
52-
Whether to compute subset accuracy for multi-label and multi-dimensional
53-
multi-class inputs (has no effect for other input types).
54-
55-
- For multi-label inputs, if the parameter is set to ``True``, then all labels for
56-
each sample must be correctly predicted for the sample to count as correct. If it
57-
is set to ``False``, then all labels are counted separately - this is equivalent to
58-
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
59-
60-
- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
61-
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
62-
If it is set to ``False``, then all sub-samples are counter separately - this is equivalent,
63-
in the case of label predictions, to flattening the inputs beforehand (i.e.
64-
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
65-
still applies in both cases, if set.
66-
67-
compute_on_step:
68-
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
69-
dist_sync_on_step:
70-
Synchronize metric state across processes at each ``forward()``
71-
before returning the value at the step
72-
process_group:
73-
Specify the process group on which synchronization is called.
74-
default: ``None`` (which selects the entire world)
75-
dist_sync_fn:
76-
Callback that performs the allgather operation on the metric state. When ``None``, DDP
77-
will be used to perform the allgather
78-
79-
Raises:
80-
ValueError:
81-
If ``threshold`` is not between ``0`` and ``1``.
82-
ValueError:
83-
If ``top_k`` is not an ``integer`` larger than ``0``.
84-
85-
Example:
86-
87-
>>> from pytorch_lightning.metrics import Accuracy
88-
>>> target = torch.tensor([0, 1, 2, 3])
89-
>>> preds = torch.tensor([0, 2, 1, 3])
90-
>>> accuracy = Accuracy()
91-
>>> accuracy(preds, target)
92-
tensor(0.5000)
93-
94-
>>> target = torch.tensor([0, 1, 2])
95-
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
96-
>>> accuracy = Accuracy(top_k=2)
97-
>>> accuracy(preds, target)
98-
tensor(0.6667)
24+
This implementation refers to :class:`~torchmetrics.Accuracy`.
9925
26+
.. warning:: This metric is deprecated, use ``torchmetrics.Accuracy``. Will be removed in v1.5.0.
10027
"""
10128

10229
def __init__(
@@ -109,44 +36,16 @@ def __init__(
10936
process_group: Optional[Any] = None,
11037
dist_sync_fn: Callable = None,
11138
):
39+
rank_zero_warn(
40+
"This `Accuracy` was deprecated in v1.3.0 in favor of `torchmetrics.Accuracy`."
41+
" It will be removed in v1.5.0", DeprecationWarning
42+
)
11243
super().__init__(
44+
threshold=threshold,
45+
top_k=top_k,
46+
subset_accuracy=subset_accuracy,
11347
compute_on_step=compute_on_step,
11448
dist_sync_on_step=dist_sync_on_step,
11549
process_group=process_group,
11650
dist_sync_fn=dist_sync_fn,
11751
)
118-
119-
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
120-
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
121-
122-
if not 0 < threshold < 1:
123-
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
124-
125-
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
126-
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")
127-
128-
self.threshold = threshold
129-
self.top_k = top_k
130-
self.subset_accuracy = subset_accuracy
131-
132-
def update(self, preds: torch.Tensor, target: torch.Tensor):
133-
"""
134-
Update state with predictions and targets.
135-
136-
Args:
137-
preds: Predictions from model (probabilities, or labels)
138-
target: Ground truth labels
139-
"""
140-
141-
correct, total = _accuracy_update(
142-
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
143-
)
144-
145-
self.correct += correct
146-
self.total += total
147-
148-
def compute(self) -> torch.Tensor:
149-
"""
150-
Computes accuracy based on inputs passed in to ``update`` previously.
151-
"""
152-
return _accuracy_compute(self.correct, self.total)

pytorch_lightning/metrics/functional/accuracy.py

Lines changed: 17 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -11,41 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Optional, Tuple
14+
from typing import Optional
1515

1616
import torch
17+
from torchmetrics.functional import accuracy as __accuracy
1718

18-
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
19-
20-
21-
def _accuracy_update(
22-
preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool
23-
) -> Tuple[torch.Tensor, torch.Tensor]:
24-
25-
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
26-
27-
if mode == DataType.MULTILABEL and top_k:
28-
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
29-
30-
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
31-
correct = (preds == target).all(dim=1).sum()
32-
total = torch.tensor(target.shape[0], device=target.device)
33-
elif mode == DataType.MULTILABEL and not subset_accuracy:
34-
correct = (preds == target).sum()
35-
total = torch.tensor(target.numel(), device=target.device)
36-
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
37-
correct = (preds * target).sum()
38-
total = target.sum()
39-
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
40-
sample_correct = (preds * target).sum(dim=(1, 2))
41-
correct = (sample_correct == target.shape[2]).sum()
42-
total = torch.tensor(target.shape[0], device=target.device)
43-
44-
return correct, total
45-
46-
47-
def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
48-
return correct.float() / total
19+
from pytorch_lightning.utilities import rank_zero_warn
4920

5021

5122
def accuracy(
@@ -55,68 +26,20 @@ def accuracy(
5526
top_k: Optional[int] = None,
5627
subset_accuracy: bool = False,
5728
) -> torch.Tensor:
58-
r"""Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
59-
60-
.. math::
61-
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
62-
63-
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
64-
tensor of predictions.
65-
66-
For multi-class and multi-dimensional multi-class data with probability predictions, the
67-
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
68-
top-K highest probability items are considered to find the correct label.
69-
70-
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
71-
accuracy by default, which counts all labels or sub-samples separately. This can be
72-
changed to subset accuracy (which requires all labels or sub-samples in the sample to
73-
be correctly predicted) by setting ``subset_accuracy=True``.
74-
75-
Args:
76-
preds: Predictions from model (probabilities, or labels)
77-
target: Ground truth labels
78-
threshold:
79-
Threshold probability value for transforming probability predictions to binary
80-
(0,1) predictions, in the case of binary or multi-label inputs.
81-
top_k:
82-
Number of highest probability predictions considered to find the correct label, relevant
83-
only for (multi-dimensional) multi-class inputs with probability predictions. The
84-
default value (``None``) will be interpreted as 1 for these inputs.
85-
86-
Should be left at default (``None``) for all other types of inputs.
87-
subset_accuracy:
88-
Whether to compute subset accuracy for multi-label and multi-dimensional
89-
multi-class inputs (has no effect for other input types).
90-
91-
- For multi-label inputs, if the parameter is set to ``True``, then all labels for
92-
each sample must be correctly predicted for the sample to count as correct. If it
93-
is set to ``False``, then all labels are counted separately - this is equivalent to
94-
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
95-
96-
- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
97-
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
98-
If it is set to ``False``, then all sub-samples are counter separately - this is equivalent,
99-
in the case of label predictions, to flattening the inputs beforehand (i.e.
100-
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
101-
still applies in both cases, if set.
102-
103-
Raises:
104-
ValueError:
105-
If ``top_k`` parameter is set for ``multi-label`` inputs.
106-
107-
Example:
108-
109-
>>> from pytorch_lightning.metrics.functional import accuracy
110-
>>> target = torch.tensor([0, 1, 2, 3])
111-
>>> preds = torch.tensor([0, 2, 1, 3])
112-
>>> accuracy(preds, target)
113-
tensor(0.5000)
29+
r"""
30+
This implementation refers to :class:`~torchmetrics.functional.accuracy`.
11431
115-
>>> target = torch.tensor([0, 1, 2])
116-
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
117-
>>> accuracy(preds, target, top_k=2)
118-
tensor(0.6667)
32+
.. warning:: This metric is deprecated, use ``torchmetrics.functional.accuracy``. Will be removed in v1.5.0.
11933
"""
12034

121-
correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy)
122-
return _accuracy_compute(correct, total)
35+
rank_zero_warn(
36+
"This `accuracy` was deprecated in v1.3.0 in favor of `torchmetrics.functional.accuracy`."
37+
" It will be removed in v1.5.0", DeprecationWarning
38+
)
39+
return __accuracy(
40+
preds=preds,
41+
target=target,
42+
threshold=threshold,
43+
top_k=top_k,
44+
subset_accuracy=subset_accuracy,
45+
)

tests/deprecated_api/test_remove_1-5.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from unittest import mock
1616

1717
import pytest
18+
import torch
1819
from torch import optim
1920

2021
from pytorch_lightning import Callback, Trainer
@@ -111,3 +112,13 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
111112
ModelCheckpoint(dirpath=tmpdir)
112113
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
113114
ModelCheckpoint(dirpath=tmpdir, period=1)
115+
116+
117+
def test_v1_5_0_deprecated_metric_accuracy():
118+
from pytorch_lightning.metrics import Accuracy
119+
with pytest.deprecated_call(match='It will be removed in v1.5.0'):
120+
Accuracy()
121+
122+
from pytorch_lightning.metrics.functional.accuracy import accuracy
123+
with pytest.deprecated_call(match='It will be removed in v1.5.0'):
124+
accuracy(preds=torch.tensor([0, 1]), target=torch.tensor([0, 1]))

0 commit comments

Comments
 (0)