@@ -482,17 +482,13 @@ def __multiclass_roc(
482482def auc (
483483 x : torch .Tensor ,
484484 y : torch .Tensor ,
485- reorder : bool = True
486485) -> torch .Tensor :
487486 """
488487 Computes Area Under the Curve (AUC) using the trapezoidal rule
489488
490489 Args:
491490 x: x-coordinates
492491 y: y-coordinates
493- reorder: reorder coordinates, so they are increasing. The unstable algorithm of torch.argsort is
494- used internally to sort `x` which may in some cases cause inaccuracies in the result.
495- WARNING: Deprecated and will be removed in v1.1.
496492
497493 Return:
498494 Tensor containing AUC score (float)
@@ -504,51 +500,38 @@ def auc(
504500 >>> auc(x, y)
505501 tensor(4.)
506502 """
507- direction = 1.
508-
509- if reorder :
510- rank_zero_warn ("The `reorder` parameter to `auc` has been deprecated and will be removed in v1.1"
511- " Note that when `reorder` is True, the unstable algorithm of torch.argsort is"
512- " used internally to sort 'x' which may in some cases cause inaccuracies"
513- " in the result." ,
514- DeprecationWarning )
515- # can't use lexsort here since it is not implemented for torch
516- order = torch .argsort (x )
517- x , y = x [order ], y [order ]
503+ dx = x [1 :] - x [:- 1 ]
504+ if (dx < 0 ).any ():
505+ if (dx <= 0 ).all ():
506+ direction = - 1.
507+ else :
508+ raise ValueError (f"The 'x' array is neither increasing or decreasing: { x } . Reorder is not supported." )
518509 else :
519- dx = x [1 :] - x [:- 1 ]
520- if (dx < 0 ).any ():
521- if (dx , 0 ).all ():
522- direction = - 1.
523- else :
524- # TODO: Update message on removing reorder
525- raise ValueError ("Reorder is not turned on, and the 'x' array is"
526- f" neither increasing or decreasing: { x } " )
527-
510+ direction = 1.
528511 return direction * torch .trapz (y , x )
529512
530513
531- def auc_decorator (reorder : bool = True ) -> Callable :
514+ def auc_decorator () -> Callable :
532515 def wrapper (func_to_decorate : Callable ) -> Callable :
533516 @wraps (func_to_decorate )
534517 def new_func (* args , ** kwargs ) -> torch .Tensor :
535518 x , y = func_to_decorate (* args , ** kwargs )[:2 ]
536519
537- return auc (x , y , reorder = reorder )
520+ return auc (x , y )
538521
539522 return new_func
540523
541524 return wrapper
542525
543526
544- def multiclass_auc_decorator (reorder : bool = True ) -> Callable :
527+ def multiclass_auc_decorator () -> Callable :
545528 def wrapper (func_to_decorate : Callable ) -> Callable :
546529 @wraps (func_to_decorate )
547530 def new_func (* args , ** kwargs ) -> torch .Tensor :
548531 results = []
549532 for class_result in func_to_decorate (* args , ** kwargs ):
550533 x , y = class_result [:2 ]
551- results .append (auc (x , y , reorder = reorder ))
534+ results .append (auc (x , y ))
552535
553536 return torch .stack (results )
554537
@@ -587,7 +570,7 @@ def auroc(
587570 ' target tensor contains value different from 0 and 1.'
588571 ' Use `multiclass_auroc` for multi class classification.' )
589572
590- @auc_decorator (reorder = True )
573+ @auc_decorator ()
591574 def _auroc (pred , target , sample_weight , pos_label ):
592575 return __roc (pred , target , sample_weight , pos_label )
593576
@@ -640,7 +623,7 @@ def multiclass_auroc(
640623 f"Number of classes deduced from 'pred' ({ pred .size (1 )} ) does not equal"
641624 f" the number of classes passed in 'num_classes' ({ num_classes } )." )
642625
643- @multiclass_auc_decorator (reorder = False )
626+ @multiclass_auc_decorator ()
644627 def _multiclass_auroc (pred , target , sample_weight , num_classes ):
645628 return __multiclass_roc (pred , target , sample_weight , num_classes )
646629
0 commit comments