diff --git a/CHANGELOG.md b/CHANGELOG.md
index c102be008d0f3..ccd296d3b0e6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `ConfusionMatrix` class interface ([#4348](https://github.com/PyTorchLightning/pytorch-lightning/pull/4348))
 
+- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
+
 ### Changed
 
 - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index c232ea216e885..de3cd01c33e9b 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -271,6 +271,13 @@ auroc [func]
     :noindex:
 
 
+multiclass_auroc [func]
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_auroc
+    :noindex:
+
+
 average_precision [func]
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py
index 189e3945e1a56..620072b44a2da 100644
--- a/pytorch_lightning/metrics/functional/__init__.py
+++ b/pytorch_lightning/metrics/functional/__init__.py
@@ -21,6 +21,7 @@
     fbeta_score,
     multiclass_precision_recall_curve,
     multiclass_roc,
+    multiclass_auroc,
     precision,
     precision_recall,
     precision_recall_curve,
diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 51b80308a579c..aec1b47096e26 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -817,13 +817,14 @@ def new_func(*args, **kwargs) -> torch.Tensor:
 
 def multiclass_auc_decorator(reorder: bool = True) -> Callable:
     def wrapper(func_to_decorate: Callable) -> Callable:
+        @wraps(func_to_decorate)
         def new_func(*args, **kwargs) -> torch.Tensor:
             results = []
             for class_result in func_to_decorate(*args, **kwargs):
                 x, y = class_result[:2]
                 results.append(auc(x, y, reorder=reorder))
 
-            return torch.cat(results)
+            return torch.stack(results)
 
         return new_func
 
@@ -858,7 +859,7 @@ def auroc(
     if any(target > 1):
         raise ValueError('AUROC metric is meant for binary classification, but'
                          ' target tensor contains value different from 0 and 1.'
-                         ' Multiclass is currently not supported.')
+                         ' Use `multiclass_auroc` for multiclass classification.')
 
     @auc_decorator(reorder=True)
     def _auroc(pred, target, sample_weight, pos_label):
@@ -867,6 +868,62 @@ def _auroc(pred, target, sample_weight, pos_label):
     return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
 
 
+def multiclass_auroc(
+    pred: torch.Tensor,
+    target: torch.Tensor,
+    sample_weight: Optional[Sequence] = None,
+    num_classes: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass
+    prediction scores.
+
+    Args:
+        pred: estimated probabilities, with shape [N, C]
+        target: ground-truth labels, with shape [N,]
+        sample_weight: sample weights
+        num_classes: number of classes (default: None, computes automatically from data)
+
+    Return:
+        Tensor containing ROCAUC score
+
+    Example:
+
+        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
+        ...                      [0.05, 0.85, 0.05, 0.05],
+        ...                      [0.05, 0.05, 0.85, 0.05],
+        ...                      [0.05, 0.05, 0.05, 0.85]])
+        >>> target = torch.tensor([0, 1, 3, 2])
+        >>> multiclass_auroc(pred, target)  # doctest: +NORMALIZE_WHITESPACE
+        tensor(0.6667)
+    """
+    if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)):
+        raise ValueError(
+            "Multiclass AUROC metric expects the target scores to be"
+            " probabilities, i.e. they should sum up to 1.0 over classes")
+
+    if torch.unique(target).size(0) != pred.size(1):
+        raise ValueError(
+            f"Number of classes found in 'target' ({torch.unique(target).size(0)})"
+            f" does not equal the number of columns in 'pred' ({pred.size(1)})."
+            " Multiclass AUROC is not defined when all of the classes do not"
+            " occur in the target labels.")
+
+    if num_classes is not None and num_classes != pred.size(1):
+        raise ValueError(
+            f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
+            f" the number of classes passed in 'num_classes' ({num_classes}).")
+
+    @multiclass_auc_decorator(reorder=False)
+    def _multiclass_auroc(pred, target, sample_weight, num_classes):
+        return multiclass_roc(pred, target, sample_weight, num_classes)
+
+    class_aurocs = _multiclass_auroc(pred=pred, target=target,
+                                     sample_weight=sample_weight,
+                                     num_classes=num_classes)
+    return torch.mean(class_aurocs)
+
+
 def average_precision(
     pred: torch.Tensor,
     target: torch.Tensor,
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index cab543f8dd6e6..12eb8555b10aa 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -30,6 +30,7 @@
     dice_score,
     average_precision,
     auroc,
+    multiclass_auroc,
     precision_recall_curve,
     roc,
     auc,
@@ -316,6 +317,47 @@ def test_auroc(pred, target, expected):
     assert score == expected
 
 
+def test_multiclass_auroc():
+    with pytest.raises(ValueError,
+                       match=r".*probabilities, i.e. they should sum up to 1.0 over classes"):
+        _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9],
+                                                [1.0, 0]]),
+                             target=torch.tensor([0, 1]))
+
+    with pytest.raises(ValueError,
+                       match=r".*not defined when all of the classes do not occur in the target.*"):
+        _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1),
+                             target=torch.tensor([1, 0, 1, 0]))
+
+    with pytest.raises(ValueError,
+                       match=r".*does not equal the number of classes passed in 'num_classes'.*"):
+        _ = multiclass_auroc(pred=torch.rand((5, 4)).softmax(dim=1),
+                             target=torch.tensor([0, 1, 2, 2, 3]),
+                             num_classes=6)
+
+
+@pytest.mark.parametrize('n_cls', [2, 5, 10, 50])
+def test_multiclass_auroc_against_sklearn(n_cls):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    n_samples = 300
+    pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1)
+    target = torch.randint(n_cls, (n_samples,), device=device)
+    # Make sure target includes all class labels so that multiclass AUROC is defined
+    target[10:10 + n_cls] = torch.arange(n_cls)
+
+    pl_score = multiclass_auroc(pred, target)
+    # For the binary case, sklearn expects an (n_samples,) array of probabilities of
+    # the positive class
+    pred = pred[:, 1] if n_cls == 2 else pred
+    sk_score = sk_roc_auc_score(target.cpu().detach().numpy(),
+                                pred.cpu().detach().numpy(),
+                                multi_class="ovr")
+
+    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+    assert torch.allclose(sk_score, pl_score)
+
+
 @pytest.mark.parametrize(['x', 'y', 'expected'], [
    pytest.param([0, 1], [0, 1], 0.5),
    pytest.param([1, 0], [0, 1], 0.5),
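
A minimal usage sketch of the metric introduced by this diff, assuming a pytorch-lightning checkout at this PR's revision, where `multiclass_auroc` is re-exported from `pytorch_lightning.metrics.functional`; the input tensors below are illustrative only, not taken from the PR.

    import torch
    from pytorch_lightning.metrics.functional import multiclass_auroc

    # Predictions must be row-normalized class probabilities of shape [N, C]
    pred = torch.softmax(torch.randn(8, 3), dim=1)
    # Every class has to occur at least once in `target`,
    # otherwise the metric raises a ValueError
    target = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])

    # Returns the unweighted mean of the per-class one-vs-rest AUROCs
    score = multiclass_auroc(pred, target)
    print(score)  # scalar tensor; the exact value depends on the random predictions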