Commit c8f605e

Authored by tadejsv, tchaton and Borda
Classification metrics overhaul: precision & recall (4/n) (#4842)
* Add stuff
* Change metrics documentation layout
* Add stuff
* Add stat scores
* Change testing utils
* Replace len(*.shape) with *.ndim
* More descriptive error message for input formatting
* Replace movedim with permute
* PEP 8 compliance
* WIP
* Add reduce_scores function
* Temporarily add back legacy class_reduce
* Division with float
* PEP 8 compliance
* Remove precision recall
* Replace movedim with permute
* Add back tests
* Add empty newlines
* Add precision recall back
* Add empty line
* Fix permute
* Fix some issues with old versions of PyTorch
* Style changes in error messages
* More error message style improvements
* Fix typo in docs
* Add more descriptive variable names in utils
* Change internal var names
* Revert unwanted changes
* Revert unwanted changes pt 2
* Update metrics interface
* Add top_k parameter
* Add back reduce function
* Add stuff
* PEP8
* Add deprecation
* PEP8
* Deprecate param
* PEP8
* Fix and simplify testing for older PT versions
* Update Changelog
* Remove redundant import
* Add tests to increase coverage
* Remove zero_division
* Fix zero_division
* Add zero_div + edge case tests
* Reorder cls metric args
* Add back quotes for is_multiclass
* Add precision_recall and tests
* PEP8
* Fix docs
* Fix docs
* Update
* Change precision_recall output
* PEP8/isort
* Add method _get_final_stats
* Fix depr test
* Add comment to deprecation tests
* isort
* Apply suggestions from code review
  Co-authored-by: Jirka Borovec <[email protected]>
* Add typing to test
* Add match str to pytest.raises

Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 1ff6b18 commit c8f605e

File tree

13 files changed: +1185 −296 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))
 
 
+- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
+
+
+
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
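
The changelog entry above is the headline feature: with `top_k`, a multi-class prediction counts as correct when the target class appears among the `k` highest-scoring classes. A minimal usage sketch, assuming the post-PR functional interface (`average` and `top_k` keyword arguments); exact defaults may differ between releases:

import torch
from pytorch_lightning.metrics.functional import precision, recall

# 3 samples, 4 classes: each row holds per-class probabilities.
preds = torch.tensor([
    [0.10, 0.50, 0.30, 0.10],
    [0.25, 0.25, 0.40, 0.10],
    [0.60, 0.10, 0.20, 0.10],
])
target = torch.tensor([2, 0, 0])

# Standard (top-1) micro-averaged scores vs. their @2 generalizations:
# with top_k=2 a sample is a hit if the target class is among the two
# highest-scoring classes.
print(recall(preds, target, average='micro'))               # Recall (top-1)
print(recall(preds, target, average='micro', top_k=2))      # Recall@2
print(precision(preds, target, average='micro', top_k=2))   # Precision@2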

docs/source/metrics.rst

Lines changed: 5 additions & 5 deletions

@@ -382,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
     ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
 
 
-Using the ``is_multiclass`` parameter
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Using the is_multiclass parameter
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
 but are actually binary/multi-label - for example, if both predictions and targets are
@@ -602,14 +602,14 @@ roc [func]
 precision [func]
 ~~~~~~~~~~~~~~~~
 
-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision
+.. autofunction:: pytorch_lightning.metrics.functional.precision
     :noindex:
 
 
 precision_recall [func]
 ~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall
+.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
     :noindex:
 
 
@@ -623,7 +623,7 @@ precision_recall_curve [func]
 recall [func]
 ~~~~~~~~~~~~~
 
-.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
+.. autofunction:: pytorch_lightning.metrics.functional.recall
     :noindex:
 
 select_topk [func]
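
The renamed docs section covers `is_multiclass`, which overrides input-type inference when data that looks like 2-class (or multi-dimensional) multi-class input is really meant as binary/multi-label, or the reverse. A hedged sketch of the idea, assuming `is_multiclass` is accepted by the new `precision` function as in this PR series (the tensors here are illustrative):

import torch
from pytorch_lightning.metrics.functional import precision

# Values in {0, 1} everywhere: these tensors *look* like they could be
# 2-class multi-class inputs, but we mean them as plain binary labels.
preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 0, 0, 1])

# is_multiclass=False forces binary handling (scores for the positive
# class); is_multiclass=True would instead treat the same tensors as
# 2-class multi-class input.
print(precision(preds, target, is_multiclass=False))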

pytorch_lightning/metrics/classification/helpers.py

Lines changed: 69 additions & 2 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple, Optional
 
+import numpy as np
 import torch
 
 from pytorch_lightning.metrics.utils import to_onehot, select_topk
@@ -249,7 +250,7 @@ def _check_classification_inputs(
         is_multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be. See the parameter's
-            :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
+            :ref:`documentation section <metrics:Using the is_multiclass parameter>`
             for a more detailed explanation and examples.
 
 
@@ -375,7 +376,7 @@ def _input_format_classification(
         is_multiclass:
             Used only in certain special cases, where you want to treat inputs as a different type
             than what they appear to be. See the parameter's
-            :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
+            :ref:`documentation section <metrics:Using the is_multiclass parameter>`
             for a more detailed explanation and examples.
 
 
@@ -437,3 +438,69 @@ def _input_format_classification(
         preds, target = preds.squeeze(-1), target.squeeze(-1)
 
     return preds.int(), target.int(), case
+
+
+def _reduce_stat_scores(
+    numerator: torch.Tensor,
+    denominator: torch.Tensor,
+    weights: Optional[torch.Tensor],
+    average: str,
+    mdmc_average: Optional[str],
+    zero_division: int = 0,
+) -> torch.Tensor:
+    """
+    Reduces scores of type ``numerator/denominator`` or
+    ``weights * (numerator/denominator)``, if ``average='weighted'``.
+
+    Args:
+        numerator: A tensor with numerator numbers.
+        denominator: A tensor with denominator numbers. If a denominator is
+            negative, the class will be ignored (if averaging), or its score
+            will be returned as ``nan`` (if ``average=None``).
+            If the denominator is zero, then ``zero_division`` score will be
+            used for those elements.
+        weights:
+            A tensor of weights to be used if ``average='weighted'``.
+        average:
+            The method to average the scores. Should be one of ``'micro'``, ``'macro'``,
+            ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior
+            corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\
+model_evaluation.html#multiclass-and-multilabel-classification>`__.
+        mdmc_average:
+            The method to average the scores if inputs were multi-dimensional multi-class (MDMC).
+            Should be either ``'global'`` or ``'samplewise'``. If inputs were not
+            multi-dimensional multi-class, it should be ``None`` (default).
+        zero_division:
+            The value to use for the score if denominator equals zero.
+    """
+    numerator, denominator = numerator.float(), denominator.float()
+    zero_div_mask = denominator == 0
+    ignore_mask = denominator < 0
+
+    if weights is None:
+        weights = torch.ones_like(denominator)
+    else:
+        weights = weights.float()
+
+    numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
+    denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
+    weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
+
+    if average not in ["micro", "none", None]:
+        weights = weights / weights.sum(dim=-1, keepdim=True)
+
+    scores = weights * (numerator / denominator)
+
+    # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
+    scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
+
+    if mdmc_average == "samplewise":
+        scores = scores.mean(dim=0)
+        ignore_mask = ignore_mask.sum(dim=0).bool()
+
+    if average in ["none", None]:
+        scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
+    else:
+        scores = scores.sum()
+
+    return scores
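
To make the new reduction helper concrete, here is a small walk-through of `_reduce_stat_scores` on per-class precision ingredients (numerator = tp, denominator = tp + fp). The tensors and the choice of weights are illustrative assumptions; the helper is private, so this is a sketch of its semantics rather than a supported API:

import torch
from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores

# Per-class counts for 3 classes: numerator = tp, denominator = tp + fp.
# Class 2 was never predicted (denominator 0), so it receives the
# zero_division score; a negative denominator would mark a class ignored.
tp = torch.tensor([3.0, 1.0, 0.0])
tp_fp = torch.tensor([4.0, 2.0, 0.0])

# 'macro': unweighted mean over classes -> (3/4 + 1/2 + 0) / 3
macro = _reduce_stat_scores(
    numerator=tp, denominator=tp_fp, weights=None,
    average='macro', mdmc_average=None, zero_division=0,
)

# 'weighted': classes weighted by `weights` (normalized internally);
# here tp stands in for per-class support, purely for illustration.
weighted = _reduce_stat_scores(
    numerator=tp, denominator=tp_fp, weights=tp,
    average='weighted', mdmc_average=None,
)

# average='none': per-class scores; ignored classes come back as nan.
per_class = _reduce_stat_scores(
    numerator=tp, denominator=tp_fp, weights=None,
    average='none', mdmc_average=None,
)
print(macro, weighted, per_class)

Note how `zero_division` only fills elements whose denominator is exactly zero, while negative denominators flow into the ignore mask, which weighted averaging renormalizes away and `average=None` surfaces as `nan`.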
