1515# License: MIT
1616
1717import functools
18+ import numbers
1819import warnings
1920from inspect import signature
2021
2627from sklearn .utils .multiclass import unique_labels
2728from sklearn .utils .validation import check_consistent_length , column_or_1d
2829
30+ from ..utils ._param_validation import Interval , StrOptions , validate_params
2931
32+
33+ @validate_params (
34+ {
35+ "y_true" : ["array-like" ],
36+ "y_pred" : ["array-like" ],
37+ "labels" : ["array-like" , None ],
38+ "pos_label" : [str , numbers .Integral , None ],
39+ "average" : [
40+ None ,
41+ StrOptions ({"binary" , "micro" , "macro" , "weighted" , "samples" }),
42+ ],
43+ "warn_for" : ["array-like" ],
44+ "sample_weight" : ["array-like" , None ],
45+ }
46+ )
3047def sensitivity_specificity_support (
3148 y_true ,
3249 y_pred ,
@@ -57,13 +74,13 @@ def sensitivity_specificity_support(
5774
5875 Parameters
5976 ----------
60- y_true : ndarray of shape (n_samples,)
77+ y_true : array-like of shape (n_samples,)
6178 Ground truth (correct) target values.
6279
63- y_pred : ndarray of shape (n_samples,)
80+ y_pred : array-like of shape (n_samples,)
6481 Estimated targets as returned by a classifier.
6582
66- labels : list , default=None
83+ labels : array-like , default=None
6784 The set of labels to include when ``average != 'binary'``, and their
6885 order if ``average is None``. Labels present in the data can be
6986 excluded, for example to calculate a multiclass average ignoring a
@@ -72,8 +89,11 @@ def sensitivity_specificity_support(
7289 labels are column indices. By default, all labels in ``y_true`` and
7390 ``y_pred`` are used in sorted order.
7491
75- pos_label : str or int , default=1
92+ pos_label : str, int or None , default=1
7693 The class to report if ``average='binary'`` and the data is binary.
94+ If ``pos_label is None`` and in binary classification, this function
95+ returns the average sensitivity and specificity if ``average``
96+ is one of ``'weighted'``.
7797 If the data are multiclass, this will be ignored;
7898 setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
7999 scores for that label only.
@@ -105,7 +125,7 @@ def sensitivity_specificity_support(
105125 This determines which warnings will be made in the case that this
106126 function is being used to return only one of its metrics.
107127
108- sample_weight : ndarray of shape (n_samples,), default=None
128+ sample_weight : array-like of shape (n_samples,), default=None
109129 Sample weights.
110130
111131 Returns
@@ -274,6 +294,19 @@ def sensitivity_specificity_support(
274294 return sensitivity , specificity , true_sum
275295
276296
297+ @validate_params (
298+ {
299+ "y_true" : ["array-like" ],
300+ "y_pred" : ["array-like" ],
301+ "labels" : ["array-like" , None ],
302+ "pos_label" : [str , numbers .Integral , None ],
303+ "average" : [
304+ None ,
305+ StrOptions ({"binary" , "micro" , "macro" , "weighted" , "samples" }),
306+ ],
307+ "sample_weight" : ["array-like" , None ],
308+ }
309+ )
277310def sensitivity_score (
278311 y_true ,
279312 y_pred ,
@@ -295,21 +328,23 @@ def sensitivity_score(
295328
296329 Parameters
297330 ----------
298- y_true : ndarray of shape (n_samples,)
331+ y_true : array-like of shape (n_samples,)
299332 Ground truth (correct) target values.
300333
301- y_pred : ndarray of shape (n_samples,)
334+ y_pred : array-like of shape (n_samples,)
302335 Estimated targets as returned by a classifier.
303336
304- labels : list , default=None
337+ labels : array-like , default=None
305338 The set of labels to include when ``average != 'binary'``, and their
306339 order if ``average is None``. Labels present in the data can be
307340 excluded, for example to calculate a multiclass average ignoring a
308341 majority negative class, while labels not present in the data will
309342 result in 0 components in a macro average.
310343
311- pos_label : str or int , default=1
344+ pos_label : str, int or None , default=1
312345 The class to report if ``average='binary'`` and the data is binary.
346+ If ``pos_label is None`` and in binary classification, this function
347+ returns the average sensitivity if ``average`` is one of ``'weighted'``.
313348 If the data are multiclass, this will be ignored;
314349 setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
315350 scores for that label only.
@@ -337,7 +372,7 @@ def sensitivity_score(
337372 meaningful for multilabel classification where this differs from
338373 :func:`accuracy_score`).
339374
340- sample_weight : ndarray of shape (n_samples,), default=None
375+ sample_weight : array-like of shape (n_samples,), default=None
341376 Sample weights.
342377
343378 Returns
@@ -374,6 +409,19 @@ def sensitivity_score(
374409 return s
375410
376411
412+ @validate_params (
413+ {
414+ "y_true" : ["array-like" ],
415+ "y_pred" : ["array-like" ],
416+ "labels" : ["array-like" , None ],
417+ "pos_label" : [str , numbers .Integral , None ],
418+ "average" : [
419+ None ,
420+ StrOptions ({"binary" , "micro" , "macro" , "weighted" , "samples" }),
421+ ],
422+ "sample_weight" : ["array-like" , None ],
423+ }
424+ )
377425def specificity_score (
378426 y_true ,
379427 y_pred ,
@@ -395,21 +443,23 @@ def specificity_score(
395443
396444 Parameters
397445 ----------
398- y_true : ndarray of shape (n_samples,)
446+ y_true : array-like of shape (n_samples,)
399447 Ground truth (correct) target values.
400448
401- y_pred : ndarray of shape (n_samples,)
449+ y_pred : array-like of shape (n_samples,)
402450 Estimated targets as returned by a classifier.
403451
404- labels : list , default=None
452+ labels : array-like , default=None
405453 The set of labels to include when ``average != 'binary'``, and their
406454 order if ``average is None``. Labels present in the data can be
407455 excluded, for example to calculate a multiclass average ignoring a
408456 majority negative class, while labels not present in the data will
409457 result in 0 components in a macro average.
410458
411- pos_label : str or int , default=1
459+ pos_label : str, int or None , default=1
412460 The class to report if ``average='binary'`` and the data is binary.
461+ If ``pos_label is None`` and in binary classification, this function
462+ returns the average specificity if ``average`` is one of ``'weighted'``.
413463 If the data are multiclass, this will be ignored;
414464 setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
415465 scores for that label only.
@@ -437,7 +487,7 @@ def specificity_score(
437487 meaningful for multilabel classification where this differs from
438488 :func:`accuracy_score`).
439489
440- sample_weight : ndarray of shape (n_samples,), default=None
490+ sample_weight : array-like of shape (n_samples,), default=None
441491 Sample weights.
442492
443493 Returns
@@ -474,6 +524,22 @@ def specificity_score(
474524 return s
475525
476526
527+ @validate_params (
528+ {
529+ "y_true" : ["array-like" ],
530+ "y_pred" : ["array-like" ],
531+ "labels" : ["array-like" , None ],
532+ "pos_label" : [str , numbers .Integral , None ],
533+ "average" : [
534+ None ,
535+ StrOptions (
536+ {"binary" , "micro" , "macro" , "weighted" , "samples" , "multiclass" }
537+ ),
538+ ],
539+ "sample_weight" : ["array-like" , None ],
540+ "correction" : [Interval (numbers .Real , 0 , None , closed = "left" )],
541+ }
542+ )
477543def geometric_mean_score (
478544 y_true ,
479545 y_pred ,
@@ -507,21 +573,24 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
507573
508574 Parameters
509575 ----------
510- y_true : ndarray of shape (n_samples,)
576+ y_true : array-like of shape (n_samples,)
511577 Ground truth (correct) target values.
512578
513- y_pred : ndarray of shape (n_samples,)
579+ y_pred : array-like of shape (n_samples,)
514580 Estimated targets as returned by a classifier.
515581
516- labels : list , default=None
582+ labels : array-like , default=None
517583 The set of labels to include when ``average != 'binary'``, and their
518584 order if ``average is None``. Labels present in the data can be
519585 excluded, for example to calculate a multiclass average ignoring a
520586 majority negative class, while labels not present in the data will
521587 result in 0 components in a macro average.
522588
523- pos_label : str or int , default=1
589+ pos_label : str, int or None , default=1
524590 The class to report if ``average='binary'`` and the data is binary.
591+ If ``pos_label is None`` and in binary classification, this function
592+ returns the average geometric mean if ``average`` is one of
593+ ``'weighted'``.
525594 If the data are multiclass, this will be ignored;
526595 setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
527596 scores for that label only.
@@ -539,6 +608,8 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
539608 ``'macro'``:
540609 Calculate metrics for each label, and find their unweighted
541610 mean. This does not take label imbalance into account.
611+ ``'multiclass'``:
612+ No average is taken.
542613 ``'weighted'``:
543614 Calculate metrics for each label, and find their average, weighted
544615 by support (the number of true instances for each label). This
@@ -549,7 +620,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
549620 meaningful for multilabel classification where this differs from
550621 :func:`accuracy_score`).
551622
552- sample_weight : ndarray of shape (n_samples,), default=None
623+ sample_weight : array-like of shape (n_samples,), default=None
553624 Sample weights.
554625
555626 correction : float, default=0.0
@@ -658,6 +729,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
658729 return gmean
659730
660731
732+ @validate_params ({"alpha" : [numbers .Real ], "squared" : ["boolean" ]})
661733def make_index_balanced_accuracy (* , alpha = 0.1 , squared = True ):
662734 """Balance any scoring function using the index balanced accuracy.
663735
@@ -763,6 +835,22 @@ def compute_score(*args, **kwargs):
763835 return decorate
764836
765837
838+ @validate_params (
839+ {
840+ "y_true" : ["array-like" ],
841+ "y_pred" : ["array-like" ],
842+ "labels" : ["array-like" , None ],
843+ "target_names" : ["array-like" , None ],
844+ "sample_weight" : ["array-like" , None ],
845+ "digits" : [Interval (numbers .Integral , 0 , None , closed = "left" )],
846+ "alpha" : [numbers .Real ],
847+ "output_dict" : ["boolean" ],
848+ "zero_division" : [
849+ StrOptions ({"warn" }),
850+ Interval (numbers .Integral , 0 , 1 , closed = "both" ),
851+ ],
852+ }
853+ )
766854def classification_report_imbalanced (
767855 y_true ,
768856 y_pred ,
@@ -970,6 +1058,13 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
9701058 return report
9711059
9721060
1061+ @validate_params (
1062+ {
1063+ "y_true" : ["array-like" ],
1064+ "y_pred" : ["array-like" ],
1065+ "sample_weight" : ["array-like" , None ],
1066+ }
1067+ )
9731068def macro_averaged_mean_absolute_error (y_true , y_pred , * , sample_weight = None ):
9741069 """Compute Macro-Averaged MAE for imbalanced ordinal classification.
9751070
0 commit comments