Skip to content

Commit ccffc34

Browse files
tadejsvteddykokerBordatchatonrohitgr7
authored
Classification metrics overhaul: accuracy metrics (2/n) (#4838)
* Add stuff * Change metrics documentation layout * Add stuff * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * Division with float * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <[email protected]> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <[email protected]> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <[email protected]> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <[email protected]> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <[email protected]> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <[email protected]> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <[email protected]> * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <[email protected]> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <[email protected]> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <[email protected]> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Suggestions from code review * Fix number in docs * Update pytorch_lightning/metrics/classification/accuracy.py * Replace topk by argsort in select_topk * Fix changelog * Add test for wrong params * Add Google Colab badges (#5111) * Add colab badges to notebook Add colab badges to notebook to notebooks 4 & 5 * Add colab badges Co-authored-by: chaton <[email protected]> * Fix hanging metrics tests (#5134) * Use torch.topk again as ddp hanging tests fixed in #5134 * Fix unwanted notebooks change * Fix too long line in hamming_distance * Apply suggestions from code review * Apply suggestions from code review * protect * Update CHANGELOG.md Co-authored-by: Teddy Koker <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: chaton <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Roger Shieh <[email protected]> Co-authored-by: Shachar Mirkin <[email protected]>
1 parent 0f36525 commit ccffc34

File tree

18 files changed

+609
-182
lines changed

18 files changed

+609
-182
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
13+
14+
- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
15+
16+
- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
1217

1318
### Changed
1419

docs/source/metrics.rst

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,12 @@ FBeta
292292
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
293293
:noindex:
294294

295+
Hamming Distance
296+
~~~~~~~~~~~~~~~~
297+
298+
.. autoclass:: pytorch_lightning.metrics.classification.HammingDistance
299+
:noindex:
300+
295301
Precision
296302
~~~~~~~~~
297303

@@ -323,10 +329,9 @@ Functional Metrics (Classification)
323329
accuracy [func]
324330
~~~~~~~~~~~~~~~
325331

326-
.. autofunction:: pytorch_lightning.metrics.functional.classification.accuracy
332+
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
327333
:noindex:
328334

329-
330335
auc [func]
331336
~~~~~~~~~~
332337

@@ -382,6 +387,11 @@ fbeta [func]
382387
.. autofunction:: pytorch_lightning.metrics.functional.fbeta
383388
:noindex:
384389

390+
hamming_distance [func]
391+
~~~~~~~~~~~~~~~~~~~~~~~
392+
393+
.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance
394+
:noindex:
385395

386396
iou [func]
387397
~~~~~~~~~~

pytorch_lightning/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pytorch_lightning.metrics.classification import ( # noqa: F401
1717
Accuracy,
18+
HammingDistance,
1819
Precision,
1920
Recall,
2021
ConfusionMatrix,

pytorch_lightning/metrics/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401
1616
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
1717
from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 # noqa: F401
18+
from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401
1819
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
1920
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
2021
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401

pytorch_lightning/metrics/classification/accuracy.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,57 @@
1616
import torch
1717

1818
from pytorch_lightning.metrics.metric import Metric
19-
from pytorch_lightning.metrics.utils import _input_format_classification
19+
from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute
2020

2121

2222
class Accuracy(Metric):
2323
r"""
2424
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
2525
26-
.. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i})
26+
.. math::
27+
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
2728
2829
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
29-
tensor of predictions. Works with binary, multiclass, and multilabel
30-
data. Accepts logits from a model output or integer class values in
31-
prediction. Works with multi-dimensional preds and target.
30+
tensor of predictions.
3231
33-
Forward accepts
32+
For multi-class and multi-dimensional multi-class data with probability predictions, the
33+
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
34+
top-K highest probability items are considered to find the correct label.
3435
35-
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
36-
- ``target`` (long tensor): ``(N, ...)``
36+
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
37+
accuracy by default, which counts all labels or sub-samples separately. This can be
38+
changed to subset accuracy (which requires all labels or sub-samples in the sample to
39+
be correctly predicted) by setting ``subset_accuracy=True``.
3740
38-
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
39-
This is the case for binary and multi-label logits.
40-
41-
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
41+
Accepts all input types listed in :ref:`metrics:Input types`.
4242
4343
Args:
4444
threshold:
45-
Threshold value for binary or multi-label logits. default: 0.5
45+
Threshold probability value for transforming probability predictions to binary
46+
`(0,1)` predictions, in the case of binary or multi-label inputs.
47+
top_k:
48+
Number of highest probability predictions considered to find the correct label, relevant
49+
only for (multi-dimensional) multi-class inputs with probability predictions. The
50+
default value (``None``) will be interpreted as 1 for these inputs.
51+
52+
Should be left at default (``None``) for all other types of inputs.
53+
subset_accuracy:
54+
Whether to compute subset accuracy for multi-label and multi-dimensional
55+
multi-class inputs (has no effect for other input types).
56+
57+
For multi-label inputs, if the parameter is set to `True`, then all labels for
58+
each sample must be correctly predicted for the sample to count as correct. If it
59+
is set to `False`, then all labels are counted separately - this is equivalent to
60+
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
61+
62+
For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all
63+
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
64+
If it is set to `False`, then all sub-samples are counter separately - this is equivalent,
65+
in the case of label predictions, to flattening the inputs beforehand (i.e.
66+
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
67+
still applies in both cases, if set.
4668
compute_on_step:
47-
Forward only calls ``update()`` and return None if this is set to False. default: True
69+
Forward only calls ``update()`` and return None if this is set to False.
4870
dist_sync_on_step:
4971
Synchronize metric state across processes at each ``forward()``
5072
before returning the value at the step. default: False
@@ -63,10 +85,19 @@ class Accuracy(Metric):
6385
>>> accuracy(preds, target)
6486
tensor(0.5000)
6587
88+
>>> target = torch.tensor([0, 1, 2])
89+
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
90+
>>> accuracy = Accuracy(top_k=2)
91+
>>> accuracy(preds, target)
92+
tensor(0.6667)
93+
6694
"""
95+
6796
def __init__(
6897
self,
6998
threshold: float = 0.5,
99+
top_k: Optional[int] = None,
100+
subset_accuracy: bool = False,
70101
compute_on_step: bool = True,
71102
dist_sync_on_step: bool = False,
72103
process_group: Optional[Any] = None,
@@ -82,24 +113,35 @@ def __init__(
82113
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
83114
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
84115

116+
if not 0 <= threshold <= 1:
117+
raise ValueError("The `threshold` should lie in the [0,1] interval.")
118+
119+
if top_k is not None and top_k <= 0:
120+
raise ValueError("The `top_k` should be an integer larger than 1.")
121+
85122
self.threshold = threshold
123+
self.top_k = top_k
124+
self.subset_accuracy = subset_accuracy
86125

87126
def update(self, preds: torch.Tensor, target: torch.Tensor):
88127
"""
89-
Update state with predictions and targets.
128+
Update state with predictions and targets. See :ref:`metrics:Input types` for more information
129+
on input types.
90130
91131
Args:
92-
preds: Predictions from model
93-
target: Ground truth values
132+
preds: Predictions from model (probabilities, or labels)
133+
target: Ground truth labels
94134
"""
95-
preds, target = _input_format_classification(preds, target, self.threshold)
96-
assert preds.shape == target.shape
97135

98-
self.correct += torch.sum(preds == target)
99-
self.total += target.numel()
136+
correct, total = _accuracy_update(
137+
preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy
138+
)
139+
140+
self.correct += correct
141+
self.total += total
100142

101-
def compute(self):
143+
def compute(self) -> torch.Tensor:
102144
"""
103-
Computes accuracy over state.
145+
Computes accuracy based on inputs passed in to ``update`` previously.
104146
"""
105-
return self.correct.float() / self.total
147+
return _accuracy_compute(self.correct, self.total)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Callable, Optional
15+
16+
import torch
17+
from pytorch_lightning.metrics.metric import Metric
18+
from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_update, _hamming_distance_compute
19+
20+
21+
class HammingDistance(Metric):
22+
r"""
23+
Computes the average `Hamming distance <https://en.wikipedia.org/wiki/Hamming_distance>`_ (also
24+
known as Hamming loss) between targets and predictions:
25+
26+
.. math::
27+
\text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}})
28+
29+
Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions,
30+
and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that
31+
tensor.
32+
33+
This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it
34+
treats each possible label separately - meaning that, for example, multi-class data is
35+
treated as if it were multi-label.
36+
37+
Accepts all input types listed in :ref:`metrics:Input types`.
38+
39+
Args:
40+
threshold:
41+
Threshold probability value for transforming probability predictions to binary
42+
`(0,1)` predictions, in the case of binary or multi-label inputs.
43+
compute_on_step:
44+
Forward only calls ``update()`` and return None if this is set to False.
45+
dist_sync_on_step:
46+
Synchronize metric state across processes at each ``forward()``
47+
before returning the value at the step.
48+
process_group:
49+
Specify the process group on which synchronization is called. default: None (which selects the entire world)
50+
dist_sync_fn:
51+
Callback that performs the allgather operation on the metric state. When ``None``, DDP
52+
will be used to perform the all gather.
53+
54+
Example:
55+
56+
>>> from pytorch_lightning.metrics import HammingDistance
57+
>>> target = torch.tensor([[0, 1], [1, 1]])
58+
>>> preds = torch.tensor([[0, 1], [0, 1]])
59+
>>> hamming_distance = HammingDistance()
60+
>>> hamming_distance(preds, target)
61+
tensor(0.2500)
62+
63+
"""
64+
65+
def __init__(
66+
self,
67+
threshold: float = 0.5,
68+
compute_on_step: bool = True,
69+
dist_sync_on_step: bool = False,
70+
process_group: Optional[Any] = None,
71+
dist_sync_fn: Callable = None,
72+
):
73+
super().__init__(
74+
compute_on_step=compute_on_step,
75+
dist_sync_on_step=dist_sync_on_step,
76+
process_group=process_group,
77+
dist_sync_fn=dist_sync_fn,
78+
)
79+
80+
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
81+
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
82+
83+
if not 0 <= threshold <= 1:
84+
raise ValueError("The `threshold` should lie in the [0,1] interval.")
85+
self.threshold = threshold
86+
87+
def update(self, preds: torch.Tensor, target: torch.Tensor):
88+
"""
89+
Update state with predictions and targets. See :ref:`metrics:Input types` for more information
90+
on input types.
91+
92+
Args:
93+
preds: Predictions from model (probabilities, or labels)
94+
target: Ground truth labels
95+
"""
96+
correct, total = _hamming_distance_update(preds, target, self.threshold)
97+
98+
self.correct += correct
99+
self.total += total
100+
101+
def compute(self) -> torch.Tensor:
102+
"""
103+
Computes hamming distance based on inputs passed in to ``update`` previously.
104+
"""
105+
return _hamming_distance_compute(self.correct, self.total)

pytorch_lightning/metrics/classification/helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,11 @@ def _input_format_classification(
405405
else:
406406
preds, target = preds.squeeze(), target.squeeze()
407407

408+
# Convert half precision tensors to full precision, as not all ops are supported
409+
# for example, min() is not supported
410+
if preds.dtype == torch.float16:
411+
preds = preds.float()
412+
408413
case = _check_classification_inputs(
409414
preds,
410415
target,

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401
1515
from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
16-
accuracy,
1716
auc,
1817
auroc,
1918
dice_score,
@@ -32,8 +31,10 @@
3231
)
3332
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
3433
# TODO: unify metrics between class and functional, add below
34+
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
3535
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
3636
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
37+
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
3738
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
3839
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
3940
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401

0 commit comments

Comments
 (0)