|
13 | 13 | # limitations under the License. |
14 | 14 | from typing import Tuple, Optional |
15 | 15 |
|
| 16 | +import numpy as np |
16 | 17 | import torch |
17 | 18 |
|
18 | 19 | from pytorch_lightning.metrics.utils import to_onehot, select_topk |
@@ -249,7 +250,7 @@ def _check_classification_inputs( |
249 | 250 | is_multiclass: |
250 | 251 | Used only in certain special cases, where you want to treat inputs as a different type |
251 | 252 | than what they appear to be. See the parameter's |
252 | | - :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>` |
| 253 | + :ref:`documentation section <metrics:Using the is_multiclass parameter>` |
253 | 254 | for a more detailed explanation and examples. |
254 | 255 |
|
255 | 256 |
|
@@ -375,7 +376,7 @@ def _input_format_classification( |
375 | 376 | is_multiclass: |
376 | 377 | Used only in certain special cases, where you want to treat inputs as a different type |
377 | 378 | than what they appear to be. See the parameter's |
378 | | - :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>` |
| 379 | + :ref:`documentation section <metrics:Using the is_multiclass parameter>` |
379 | 380 | for a more detailed explanation and examples. |
380 | 381 |
|
381 | 382 |
|
@@ -437,3 +438,69 @@ def _input_format_classification( |
437 | 438 | preds, target = preds.squeeze(-1), target.squeeze(-1) |
438 | 439 |
|
439 | 440 | return preds.int(), target.int(), case |
| 441 | + |
| 442 | + |
| 443 | +def _reduce_stat_scores( |
| 444 | + numerator: torch.Tensor, |
| 445 | + denominator: torch.Tensor, |
| 446 | + weights: Optional[torch.Tensor], |
| 447 | + average: str, |
| 448 | + mdmc_average: Optional[str], |
| 449 | + zero_division: int = 0, |
| 450 | +) -> torch.Tensor: |
| 451 | + """ |
| 452 | + Reduces scores of type ``numerator/denominator`` or |
| 453 | + ``weights * (numerator/denominator)``, if ``average='weighted'``. |
| 454 | +
|
| 455 | + Args: |
| 456 | + numerator: A tensor with numerator numbers. |
| 457 | + denominator: A tensor with denominator numbers. If a denominator is |
| 458 | + negative, the class will be ignored (if averaging), or its score |
| 459 | + will be returned as ``nan`` (if ``average=None``). |
| 460 | + If the denominator is zero, then ``zero_division`` score will be |
| 461 | + used for those elements. |
| 462 | + weights: |
| 463 | + A tensor of weights to be used if ``average='weighted'``. |
| 464 | + average: |
| 465 | + The method to average the scores. Should be one of ``'micro'``, ``'macro'``, |
| 466 | + ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior |
| 467 | + corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\ |
| 468 | +model_evaluation.html#multiclass-and-multilabel-classification>`__. |
| 469 | + mdmc_average: |
| 470 | + The method to average the scores if inputs were multi-dimensional multi-class (MDMC). |
| 471 | + Should be either ``'global'`` or ``'samplewise'``. If inputs were not |
| 472 | + multi-dimensional multi-class, it should be ``None`` (default). |
| 473 | + zero_division: |
| 474 | + The value to use for the score if denominator equals zero. |
| 475 | + """ |
| 476 | + numerator, denominator = numerator.float(), denominator.float() |
| 477 | + zero_div_mask = denominator == 0 |
| 478 | + ignore_mask = denominator < 0 |
| 479 | + |
| 480 | + if weights is None: |
| 481 | + weights = torch.ones_like(denominator) |
| 482 | + else: |
| 483 | + weights = weights.float() |
| 484 | + |
| 485 | + numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) |
| 486 | + denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) |
| 487 | + weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) |
| 488 | + |
| 489 | + if average not in ["micro", "none", None]: |
| 490 | + weights = weights / weights.sum(dim=-1, keepdim=True) |
| 491 | + |
| 492 | + scores = weights * (numerator / denominator) |
| 493 | + |
| 494 | + # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' |
| 495 | + scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) |
| 496 | + |
| 497 | + if mdmc_average == "samplewise": |
| 498 | + scores = scores.mean(dim=0) |
| 499 | + ignore_mask = ignore_mask.sum(dim=0).bool() |
| 500 | + |
| 501 | + if average in ["none", None]: |
| 502 | + scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) |
| 503 | + else: |
| 504 | + scores = scores.sum() |
| 505 | + |
| 506 | + return scores |
0 commit comments