Skip to content

Commit 90d1d9f

Browse files
Bordas-rog
andauthored
drop deprecated reorder from AUC (#5004)
* drop deprecated reorder from AUC * chlog * fix * fix * simple * fix * fix * fix Co-authored-by: Roger Shieh <[email protected]>
1 parent 20b806a commit 90d1d9f

File tree

3 files changed

+15
-35
lines changed

3 files changed

+15
-35
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
126126

127127
### Removed
128128

129+
- Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004))
130+
129131

130132

131133
### Fixed

pytorch_lightning/metrics/functional/classification.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -482,17 +482,13 @@ def __multiclass_roc(
482482
def auc(
483483
x: torch.Tensor,
484484
y: torch.Tensor,
485-
reorder: bool = True
486485
) -> torch.Tensor:
487486
"""
488487
Computes Area Under the Curve (AUC) using the trapezoidal rule
489488
490489
Args:
491490
x: x-coordinates
492491
y: y-coordinates
493-
reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is
494-
used internally to sort `x` which may in some cases cause inaccuracies in the result.
495-
WARNING: Deprecated and will be removed in v1.1.
496492
497493
Return:
498494
Tensor containing AUC score (float)
@@ -504,51 +500,38 @@ def auc(
504500
>>> auc(x, y)
505501
tensor(4.)
506502
"""
507-
direction = 1.
508-
509-
if reorder:
510-
rank_zero_warn("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1"
511-
" Note that when `reorder` is True, the unstable algorithm of torch.argsort is"
512-
" used internally to sort 'x' which may in some cases cause inaccuracies"
513-
" in the result.",
514-
DeprecationWarning)
515-
# can't use lexsort here since it is not implemented for torch
516-
order = torch.argsort(x)
517-
x, y = x[order], y[order]
503+
dx = x[1:] - x[:-1]
504+
if (dx < 0).any():
505+
if (dx <= 0).all():
506+
direction = -1.
507+
else:
508+
raise ValueError(f"The 'x' array is neither increasing or decreasing: {x}. Reorder is not supported.")
518509
else:
519-
dx = x[1:] - x[:-1]
520-
if (dx < 0).any():
521-
if (dx, 0).all():
522-
direction = -1.
523-
else:
524-
# TODO: Update message on removing reorder
525-
raise ValueError("Reorder is not turned on, and the 'x' array is"
526-
f" neither increasing or decreasing: {x}")
527-
510+
direction = 1.
528511
return direction * torch.trapz(y, x)
529512

530513

531-
def auc_decorator(reorder: bool = True) -> Callable:
514+
def auc_decorator() -> Callable:
532515
def wrapper(func_to_decorate: Callable) -> Callable:
533516
@wraps(func_to_decorate)
534517
def new_func(*args, **kwargs) -> torch.Tensor:
535518
x, y = func_to_decorate(*args, **kwargs)[:2]
536519

537-
return auc(x, y, reorder=reorder)
520+
return auc(x, y)
538521

539522
return new_func
540523

541524
return wrapper
542525

543526

544-
def multiclass_auc_decorator(reorder: bool = True) -> Callable:
527+
def multiclass_auc_decorator() -> Callable:
545528
def wrapper(func_to_decorate: Callable) -> Callable:
546529
@wraps(func_to_decorate)
547530
def new_func(*args, **kwargs) -> torch.Tensor:
548531
results = []
549532
for class_result in func_to_decorate(*args, **kwargs):
550533
x, y = class_result[:2]
551-
results.append(auc(x, y, reorder=reorder))
534+
results.append(auc(x, y))
552535

553536
return torch.stack(results)
554537

@@ -587,7 +570,7 @@ def auroc(
587570
' target tensor contains value different from 0 and 1.'
588571
' Use `multiclass_auroc` for multi class classification.')
589572

590-
@auc_decorator(reorder=True)
573+
@auc_decorator()
591574
def _auroc(pred, target, sample_weight, pos_label):
592575
return __roc(pred, target, sample_weight, pos_label)
593576

@@ -640,7 +623,7 @@ def multiclass_auroc(
640623
f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
641624
f" the number of classes passed in 'num_classes' ({num_classes}).")
642625

643-
@multiclass_auc_decorator(reorder=False)
626+
@multiclass_auc_decorator()
644627
def _multiclass_auroc(pred, target, sample_weight, num_classes):
645628
return __multiclass_roc(pred, target, sample_weight, num_classes)
646629

tests/test_deprecated.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,3 @@ def test_dataloader(self):
135135

136136
def test_end(self, outputs):
137137
return {'test_loss': torch.tensor(0.7)}
138-
139-
140-
def test_reorder_remove_in_v1_1():
141-
with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'):
142-
_ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True)

0 commit comments

Comments
 (0)