diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d979b41893f5..6790579f4d0bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213)) +- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237)) + ### Removed diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index abf661e46fb1b..a831611fb9593 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -806,7 +806,9 @@ def auc( Args: x: x-coordinates y: y-coordinates - reorder: reorder coordinates, so they are increasing + reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is + used internally to sort `x` which may in some cases cause inaccuracies in the result. + WARNING: Deprecated and will be removed in v1.1. Return: Tensor containing AUC score (float) @@ -821,6 +823,11 @@ def auc( direction = 1. if reorder: + rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1" + " Note that when `reorder` is True, the unstable algorithm of torch.argsort is" + " used internally to sort 'x' which may in some cases cause inaccuracies" + " in the result.", + DeprecationWarning) # can't use lexsort here since it is not implemented for torch order = torch.argsort(x) x, y = x[order], y[order] @@ -830,8 +837,9 @@ def auc( if (dx, 0).all(): direction = -1. else: - raise ValueError("Reordering is not turned on, and " - "the x array is not increasing: %s" % x) + # TODO: Update message on removing reorder + raise ValueError("Reorder is not turned on, and the 'x' array is" + f" neither increasing or decreasing: {x}") return direction * torch.trapz(y, x) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 70eb2a709b195..c8a7b1d270e35 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -5,6 +5,7 @@ import torch from tests.base import EvalModelTemplate +from pytorch_lightning.metrics.functional.classification import auc from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -63,3 +64,8 @@ def test_dataloader(self): def test_end(self, outputs): return {'test_loss': torch.tensor(0.7)} + + +def test_auc_reorder_remove_in_v1_1_0(): + with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'): + _ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True)