Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))

- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))

### Removed


Expand Down
14 changes: 11 additions & 3 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,9 @@ def auc(
Args:
x: x-coordinates
y: y-coordinates
reorder: reorder coordinates, so they are increasing
reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is
used internally to sort `x` which may in some cases cause inaccuracies in the result.
WARNING: Deprecated and will be removed in v1.1.

Return:
Tensor containing AUC score (float)
Expand All @@ -821,6 +823,11 @@ def auc(
direction = 1.

if reorder:
rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1"
" Note that when `reorder` is True, the unstable algorithm of torch.argsort is"
" used internally to sort 'x' which may in some cases cause inaccuracies"
" in the result.",
DeprecationWarning)
# can't use lexsort here since it is not implemented for torch
order = torch.argsort(x)
x, y = x[order], y[order]
Expand All @@ -830,8 +837,9 @@ def auc(
if (dx, 0).all():
direction = -1.
else:
raise ValueError("Reordering is not turned on, and "
"the x array is not increasing: %s" % x)
# TODO: Update message on removing reorder
raise ValueError("Reorder is not turned on, and the 'x' array is"
f" neither increasing or decreasing: {x}")

return direction * torch.trapz(y, x)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from tests.base import EvalModelTemplate
from pytorch_lightning.metrics.functional.classification import auc

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -63,3 +64,8 @@ def test_dataloader(self):

def test_end(self, outputs):
return {'test_loss': torch.tensor(0.7)}


def test_auc_reorder_remove_in_v1_1_0():
with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'):
_ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True)