@@ -828,6 +828,93 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
828828 return result ;
829829 }
830830
831+ /// <summary>
832+ /// Evaluates scored binary classification data and generates precision recall curve data.
833+ /// </summary>
834+ /// <param name="data">The scored data.</param>
835+ /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
836+ /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
837+ /// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
838+ /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
839+ /// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
840+ /// <returns>The evaluation results for these calibrated outputs.</returns>
841+ public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve (
842+ IDataView data ,
843+ string label ,
844+ string score ,
845+ string probability ,
846+ string predictedLabel ,
847+ out List < BinaryPrecisionRecallDataPoint > prCurve )
848+ {
849+ Host . CheckValue ( data , nameof ( data ) ) ;
850+ Host . CheckNonEmpty ( label , nameof ( label ) ) ;
851+ Host . CheckNonEmpty ( score , nameof ( score ) ) ;
852+ Host . CheckNonEmpty ( probability , nameof ( probability ) ) ;
853+ Host . CheckNonEmpty ( predictedLabel , nameof ( predictedLabel ) ) ;
854+
855+ var roles = new RoleMappedData ( data , opt : false ,
856+ RoleMappedSchema . ColumnRole . Label . Bind ( label ) ,
857+ RoleMappedSchema . CreatePair ( AnnotationUtils . Const . ScoreValueKind . Score , score ) ,
858+ RoleMappedSchema . CreatePair ( AnnotationUtils . Const . ScoreValueKind . Probability , probability ) ,
859+ RoleMappedSchema . CreatePair ( AnnotationUtils . Const . ScoreValueKind . PredictedLabel , predictedLabel ) ) ;
860+
861+ var resultDict = ( ( IEvaluator ) this ) . Evaluate ( roles ) ;
862+ Host . Assert ( resultDict . ContainsKey ( MetricKinds . PrCurve ) ) ;
863+ var prCurveView = resultDict [ MetricKinds . PrCurve ] ;
864+ Host . Assert ( resultDict . ContainsKey ( MetricKinds . OverallMetrics ) ) ;
865+ var overall = resultDict [ MetricKinds . OverallMetrics ] ;
866+
867+ var prCurveResult = new List < BinaryPrecisionRecallDataPoint > ( ) ;
868+ using ( var cursor = prCurveView . GetRowCursorForAllColumns ( ) )
869+ {
870+ GetPrecisionRecallDataPointGetters ( prCurveView , cursor ,
871+ out ValueGetter < float > thresholdGetter ,
872+ out ValueGetter < double > precisionGetter ,
873+ out ValueGetter < double > recallGetter ,
874+ out ValueGetter < double > fprGetter ) ;
875+
876+ while ( cursor . MoveNext ( ) )
877+ {
878+ prCurveResult . Add ( new BinaryPrecisionRecallDataPoint ( thresholdGetter , precisionGetter , recallGetter , fprGetter ) ) ;
879+ }
880+ }
881+ prCurve = prCurveResult ;
882+
883+ CalibratedBinaryClassificationMetrics result ;
884+ using ( var cursor = overall . GetRowCursorForAllColumns ( ) )
885+ {
886+ var moved = cursor . MoveNext ( ) ;
887+ Host . Assert ( moved ) ;
888+ result = new CalibratedBinaryClassificationMetrics ( Host , cursor ) ;
889+ moved = cursor . MoveNext ( ) ;
890+ Host . Assert ( ! moved ) ;
891+ }
892+
893+ return result ;
894+ }
895+
896+ private void GetPrecisionRecallDataPointGetters ( IDataView prCurveView ,
897+ DataViewRowCursor cursor ,
898+ out ValueGetter < float > thresholdGetter ,
899+ out ValueGetter < double > precisionGetter ,
900+ out ValueGetter < double > recallGetter ,
901+ out ValueGetter < double > fprGetter )
902+ {
903+ var thresholdColumn = prCurveView . Schema . GetColumnOrNull ( BinaryClassifierEvaluator . Threshold ) ;
904+ var precisionColumn = prCurveView . Schema . GetColumnOrNull ( BinaryClassifierEvaluator . Precision ) ;
905+ var recallColumn = prCurveView . Schema . GetColumnOrNull ( BinaryClassifierEvaluator . Recall ) ;
906+ var fprColumn = prCurveView . Schema . GetColumnOrNull ( BinaryClassifierEvaluator . FalsePositiveRate ) ;
907+ Host . Assert ( thresholdColumn != null ) ;
908+ Host . Assert ( precisionColumn != null ) ;
909+ Host . Assert ( recallColumn != null ) ;
910+ Host . Assert ( fprColumn != null ) ;
911+
912+ thresholdGetter = cursor . GetGetter < float > ( ( DataViewSchema . Column ) thresholdColumn ) ;
913+ precisionGetter = cursor . GetGetter < double > ( ( DataViewSchema . Column ) precisionColumn ) ;
914+ recallGetter = cursor . GetGetter < double > ( ( DataViewSchema . Column ) recallColumn ) ;
915+ fprGetter = cursor . GetGetter < double > ( ( DataViewSchema . Column ) fprColumn ) ;
916+ }
917+
831918 /// <summary>
832919 /// Evaluates scored binary classification data, without probability-based metrics.
833920 /// </summary>
@@ -864,6 +951,69 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
864951 }
865952 return result ;
866953 }
954+
955+ /// <summary>
956+ /// Evaluates scored binary classification data, without probability-based metrics
957+ /// and generates precision recall curve data.
958+ /// </summary>
959+ /// <param name="data">The scored data.</param>
960+ /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
961+ /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
962+ /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
963+ /// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
964+ /// <returns>The evaluation results for these uncalibrated outputs.</returns>
965+ /// <seealso cref="Evaluate(IDataView, string, string, string)"/>
966+ public BinaryClassificationMetrics EvaluateWithPRCurve (
967+ IDataView data ,
968+ string label ,
969+ string score ,
970+ string predictedLabel ,
971+ out List < BinaryPrecisionRecallDataPoint > prCurve )
972+ {
973+ Host . CheckValue ( data , nameof ( data ) ) ;
974+ Host . CheckNonEmpty ( label , nameof ( label ) ) ;
975+ Host . CheckNonEmpty ( score , nameof ( score ) ) ;
976+ Host . CheckNonEmpty ( predictedLabel , nameof ( predictedLabel ) ) ;
977+
978+ var roles = new RoleMappedData ( data , opt : false ,
979+ RoleMappedSchema . ColumnRole . Label . Bind ( label ) ,
980+ RoleMappedSchema . CreatePair ( AnnotationUtils . Const . ScoreValueKind . Score , score ) ,
981+ RoleMappedSchema . CreatePair ( AnnotationUtils . Const . ScoreValueKind . PredictedLabel , predictedLabel ) ) ;
982+
983+ var resultDict = ( ( IEvaluator ) this ) . Evaluate ( roles ) ;
984+ Host . Assert ( resultDict . ContainsKey ( MetricKinds . PrCurve ) ) ;
985+ var prCurveView = resultDict [ MetricKinds . PrCurve ] ;
986+ Host . Assert ( resultDict . ContainsKey ( MetricKinds . OverallMetrics ) ) ;
987+ var overall = resultDict [ MetricKinds . OverallMetrics ] ;
988+
989+ var prCurveResult = new List < BinaryPrecisionRecallDataPoint > ( ) ;
990+ using ( var cursor = prCurveView . GetRowCursorForAllColumns ( ) )
991+ {
992+ GetPrecisionRecallDataPointGetters ( prCurveView , cursor ,
993+ out ValueGetter < float > thresholdGetter ,
994+ out ValueGetter < double > precisionGetter ,
995+ out ValueGetter < double > recallGetter ,
996+ out ValueGetter < double > fprGetter ) ;
997+
998+ while ( cursor . MoveNext ( ) )
999+ {
1000+ prCurveResult . Add ( new BinaryPrecisionRecallDataPoint ( thresholdGetter , precisionGetter , recallGetter , fprGetter ) ) ;
1001+ }
1002+ }
1003+ prCurve = prCurveResult ;
1004+
1005+ BinaryClassificationMetrics result ;
1006+ using ( var cursor = overall . GetRowCursorForAllColumns ( ) )
1007+ {
1008+ var moved = cursor . MoveNext ( ) ;
1009+ Host . Assert ( moved ) ;
1010+ result = new BinaryClassificationMetrics ( Host , cursor ) ;
1011+ moved = cursor . MoveNext ( ) ;
1012+ Host . Assert ( ! moved ) ;
1013+ }
1014+
1015+ return result ;
1016+ }
8671017 }
8681018
8691019 internal sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase
0 commit comments