@@ -806,6 +806,8 @@ def classification_report_imbalanced(
     sample_weight=None,
     digits=2,
     alpha=0.1,
+    output_dict=False,
+    zero_division="warn",
 ):
     """Build a classification report based on metrics used with imbalanced
     dataset
@@ -816,38 +818,59 @@ def classification_report_imbalanced(
     mean, and index balanced accuracy of the
     geometric mean.
 
+    Read more in the :ref:`User Guide <classification_report>`.
+
     Parameters
     ----------
-    y_true : ndarray, shape (n_samples,)
+    y_true : 1d array-like, or label indicator array / sparse matrix
         Ground truth (correct) target values.
 
-    y_pred : ndarray, shape (n_samples,)
+    y_pred : 1d array-like, or label indicator array / sparse matrix
         Estimated targets as returned by a classifier.
 
-    labels : list, optional
-        The set of labels to include when ``average != 'binary'``, and their
-        order if ``average is None``. Labels present in the data can be
-        excluded, for example to calculate a multiclass average ignoring a
-        majority negative class, while labels not present in the data will
-        result in 0 components in a macro average.
+    labels : array-like of shape (n_labels,), default=None
+        Optional list of label indices to include in the report.
 
-    target_names : list of strings, optional
+    target_names : list of str of shape (n_labels,), default=None
         Optional display names matching the labels (same order).
 
-    sample_weight : ndarray, shape (n_samples,)
+    sample_weight : array-like of shape (n_samples,), default=None
         Sample weights.
 
-    digits : int, optional (default=2)
-        Number of digits for formatting output floating point values
+    digits : int, default=2
+        Number of digits for formatting output floating point values.
+        When ``output_dict`` is ``True``, this will be ignored and the
+        returned values will not be rounded.
 
-    alpha : float, optional (default=0.1)
+    alpha : float, default=0.1
         Weighting factor.
 
+    output_dict : bool, default=False
+        If True, return output as dict.
+
+        .. versionadded:: 0.7
+
+    zero_division : "warn" or {0, 1}, default="warn"
+        Sets the value to return when there is a zero division. If set to
+        "warn", this acts as 0, but warnings are also raised.
+
+        .. versionadded:: 0.7
+
     Returns
     -------
-    report : string
+    report : string / dict
         Text summary of the precision, recall, specificity, geometric mean,
         and index balanced accuracy.
+        Dictionary returned if output_dict is True. Dictionary has the
+        following structure::
+
+            {'label 1': {'pre':0.5,
+                         'rec':1.0,
+                         ...
+                        },
+             'label 2': { ... },
+             ...
+            }
 
     Examples
     --------
@@ -883,7 +906,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
     last_line_heading = "avg / total"
 
     if target_names is None:
-        target_names = ["%s" % l for l in labels]
+        target_names = [f"{label}" for label in labels]
     name_width = max(len(cn) for cn in target_names)
     width = max(name_width, len(last_line_heading), digits)
 
@@ -905,6 +928,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
         labels=labels,
         average=None,
         sample_weight=sample_weight,
+        zero_division=zero_division
     )
     # Specificity
     specificity = specificity_score(
@@ -934,33 +958,50 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
         sample_weight=sample_weight,
     )
 
+    report_dict = {}
     for i, label in enumerate(labels):
+        report_dict_label = {}
         values = [target_names[i]]
-        for v in (
-            precision[i],
-            recall[i],
-            specificity[i],
-            f1[i],
-            geo_mean[i],
-            iba[i],
+        for score_name, score_value in zip(
+            headers[1:-1],
+            [
+                precision[i],
+                recall[i],
+                specificity[i],
+                f1[i],
+                geo_mean[i],
+                iba[i],
+            ]
         ):
-            values += ["{0:0.{1}f}".format(v, digits)]
-        values += ["{}".format(support[i])]
+            values += ["{0:0.{1}f}".format(score_value, digits)]
+            report_dict_label[score_name] = score_value
+        values += [f"{support[i]}"]
+        report_dict_label[headers[-1]] = support[i]
         report += fmt % tuple(values)
 
+        report_dict[label] = report_dict_label
+
     report += "\n"
 
     # compute averages
     values = [last_line_heading]
-    for v in (
-        np.average(precision, weights=support),
-        np.average(recall, weights=support),
-        np.average(specificity, weights=support),
-        np.average(f1, weights=support),
-        np.average(geo_mean, weights=support),
-        np.average(iba, weights=support),
+    for score_name, score_value in zip(
+        headers[1:-1],
+        [
+            np.average(precision, weights=support),
+            np.average(recall, weights=support),
+            np.average(specificity, weights=support),
+            np.average(f1, weights=support),
+            np.average(geo_mean, weights=support),
+            np.average(iba, weights=support),
+        ]
     ):
-        values += ["{0:0.{1}f}".format(v, digits)]
-    values += ["{}".format(np.sum(support))]
+        values += ["{0:0.{1}f}".format(score_value, digits)]
+        report_dict[f"avg_{score_name}"] = score_value
+    values += [f"{np.sum(support)}"]
     report += fmt % tuple(values)
+    report_dict["total_support"] = np.sum(support)
+
+    if output_dict:
+        return report_dict
     return report
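
For reference, a minimal usage sketch of the two new parameters. It assumes an imbalanced-learn build that already contains this change; the per-label key names such as `'pre'` and `'rec'` follow the report's column headers, since the dict reuses `headers[1:-1]` and `headers[-1]` as keys.

```python
from imblearn.metrics import classification_report_imbalanced

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2]

# Unchanged default: a formatted text report.
print(classification_report_imbalanced(y_true, y_pred))

# New: a nested dict keyed by label, plus "avg_<metric>" and "total_support"
# entries; zero_division=0 returns 0.0 instead of warning when a metric is
# undefined.
report = classification_report_imbalanced(
    y_true, y_pred, output_dict=True, zero_division=0
)
print(report[1]["rec"])         # per-label recall (key names follow the headers)
print(report["total_support"])  # total number of samples
```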