@@ -1348,20 +1348,49 @@ internal static class MetricWriter
13481348 /// is assigned the string representation of the weighted confusion table. Otherwise it is assigned null.</param>
13491349 /// <param name="binary">Indicates whether the confusion table is for binary classification.</param>
13501350 /// <param name="sample">Indicates how many classes to sample from the confusion table (-1 indicates no sampling)</param>
1351- public static string GetConfusionTable ( IHost host , IDataView confusionDataView , out string weightedConfusionTable , bool binary = true , int sample = - 1 )
public static string GetConfusionTableAsFormattedString(IHost host, IDataView confusionDataView, out string weightedConfusionTable, bool binary = true, int sample = -1)
{
    // Validate the arguments up front so failures are reported before any table is built.
    host.CheckValue(confusionDataView, nameof(confusionDataView));
    host.CheckParam(sample == -1 || sample >= 2, nameof(sample), "Should be -1 to indicate no sampling, or at least 2");

    // A weighted table is produced only when the data view has a Weight column.
    bool hasWeights = confusionDataView.Schema.GetColumnOrNull(MetricKinds.ColumnNames.Weight).HasValue;

    // The unweighted table is always built, and it is built first.
    string unweightedTable = GetConfusionTableAsString(
        GetConfusionMatrix(host, confusionDataView, binary, sample, false), false);

    // The out parameter is assigned only once the unweighted table has been produced;
    // it is null when the data view carries no weights.
    weightedConfusionTable = hasWeights
        ? GetConfusionTableAsString(GetConfusionMatrix(host, confusionDataView, binary, sample, true), true)
        : null;

    return unweightedTable;
}
1373+
1374+ public static ConfusionMatrix GetConfusionMatrix ( IHost host , IDataView confusionDataView , bool binary = true , int sample = - 1 , bool getWeighted = false )
1375+ {
1376+ host . CheckValue ( confusionDataView , nameof ( confusionDataView ) ) ;
1377+ host . CheckParam ( sample == - 1 || sample >= 2 , nameof ( sample ) , "Should be -1 to indicate no sampling, or at least 2" ) ;
1378+
1379+ // Check that there is a Weight column if the getWeighted parameter is set to true.
1380+ var weightColumn = confusionDataView . Schema . GetColumnOrNull ( MetricKinds . ColumnNames . Weight ) ;
1381+ if ( getWeighted )
1382+ host . CheckParam ( weightColumn . HasValue , nameof ( getWeighted ) , "There is no Weight column in the confusionMatrix data view." ) ;
1383+
1384+ // Get the Count column and the type of its SlotNames annotation.
1385+ var countColumn = confusionDataView . Schema [ MetricKinds . ColumnNames . Count ] ;
1386+ var type = countColumn . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . SlotNames ) ? . Type as VectorDataViewType ;
1387+ // The Count column must have text vector metadata of kind SlotNames with a known size.
1388+ host . Assert ( type != null && type . IsKnownSize && type . ItemType is TextDataViewType ) ;
1389+
1390+ // Get the class names
13621391 var labelNames = default ( VBuffer < ReadOnlyMemory < char > > ) ;
1363- confusionDataView . Schema [ countCol ] . Annotations . GetValue ( AnnotationUtils . Kinds . SlotNames , ref labelNames ) ;
1364- host . Check ( labelNames . IsDense , "Slot names vector must be dense" ) ;
1392+ countColumn . Annotations . GetValue ( AnnotationUtils . Kinds . SlotNames , ref labelNames ) ;
1393+ host . Assert ( labelNames . IsDense , "Slot names vector must be dense" ) ;
13651394
13661395 int numConfusionTableLabels = sample < 0 ? labelNames . Length : Math . Min ( labelNames . Length , sample ) ;
13671396
@@ -1387,32 +1416,32 @@ public static string GetConfusionTable(IHost host, IDataView confusionDataView,
13871416
13881417 double [ ] precisionSums ;
13891418 double [ ] recallSums ;
1390- var confusionTable = GetConfusionTableAsArray ( confusionDataView , countCol , labelNames . Length ,
1391- labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1419+ double [ ] [ ] confusionTable ;
13921420
1393- var predictedLabelNames = GetPredictedLabelNames ( in labelNames , labelIndexToConfIndexMap ) ;
1394- var confusionTableString = GetConfusionTableAsString ( confusionTable , recallSums , precisionSums ,
1395- predictedLabelNames ,
1396- sampled : numConfusionTableLabels < labelNames . Length , binary : binary ) ;
1421+ if ( getWeighted )
1422+ confusionTable = GetConfusionTableAsArray ( confusionDataView , weightColumn . Value . Index , labelNames . Length ,
1423+ labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1424+ else
1425+ confusionTable = GetConfusionTableAsArray ( confusionDataView , countColumn . Index , labelNames . Length ,
1426+ labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
13971427
1398- int weightIndex ;
1399- if ( confusionDataView . Schema . TryGetColumnIndex ( MetricKinds . ColumnNames . Weight , out weightIndex ) )
1428+ double [ ] precision = new double [ numConfusionTableLabels ] ;
1429+ double [ ] recall = new double [ numConfusionTableLabels ] ;
1430+ for ( int i = 0 ; i < numConfusionTableLabels ; i ++ )
14001431 {
1401- confusionTable = GetConfusionTableAsArray ( confusionDataView , weightIndex , labelNames . Length ,
1402- labelIndexToConfIndexMap , numConfusionTableLabels , out precisionSums , out recallSums ) ;
1403- weightedConfusionTable = GetConfusionTableAsString ( confusionTable , recallSums , precisionSums ,
1404- predictedLabelNames ,
1405- sampled : numConfusionTableLabels < labelNames . Length , prefix : "Weighted " , binary : binary ) ;
1432+ recall [ i ] = recallSums [ i ] > 0 ? confusionTable [ i ] [ i ] / recallSums [ i ] : 0 ;
1433+ precision [ i ] = precisionSums [ i ] > 0 ? confusionTable [ i ] [ i ] / precisionSums [ i ] : 0 ;
14061434 }
1407- else
1408- weightedConfusionTable = null ;
14091435
1410- return confusionTableString ;
1436+ var predictedLabelNames = GetPredictedLabelNames ( in labelNames , labelIndexToConfIndexMap ) ;
1437+ bool sampled = numConfusionTableLabels < labelNames . Length ;
1438+
1439+ return new ConfusionMatrix ( host , precision , recall , confusionTable , predictedLabelNames , sampled , binary ) ;
14111440 }
14121441
14131442 private static List < ReadOnlyMemory < char > > GetPredictedLabelNames ( in VBuffer < ReadOnlyMemory < char > > labelNames , int [ ] labelIndexToConfIndexMap )
14141443 {
1415- List < ReadOnlyMemory < char > > result = new List < ReadOnlyMemory < char > > ( ) ;
1444+ List < ReadOnlyMemory < char > > result = new List < ReadOnlyMemory < char > > ( ) ;
14161445 var values = labelNames . GetValues ( ) ;
14171446 for ( int i = 0 ; i < values . Length ; i ++ )
14181447 {
@@ -1553,13 +1582,13 @@ private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView dat
15531582 }
15541583
15551584 // Get a string representation of a confusion table.
1556- private static string GetConfusionTableAsString ( double [ ] [ ] confusionTable , double [ ] rowSums , double [ ] columnSums ,
1557- List < ReadOnlyMemory < char > > predictedLabelNames , string prefix = "" , bool sampled = false , bool binary = true )
1585+ internal static string GetConfusionTableAsString ( ConfusionMatrix confusionMatrix , bool isWeighted )
15581586 {
1559- int numLabels = Utils . Size ( confusionTable ) ;
1587+ string prefix = isWeighted ? "Weighted " : "" ;
1588+ int numLabels = confusionMatrix ? . Counts == null ? 0 : confusionMatrix . Counts . Count ;
15601589
15611590 int colWidth = numLabels == 2 ? 8 : 5 ;
1562- int maxNameLen = predictedLabelNames . Max ( name => name . Length ) ;
1591+ int maxNameLen = confusionMatrix . PredictedClassesIndicators . Max ( name => name . Length ) ;
15631592 // If the names are too long to fit in the column header, we back off to using class indices
15641593 // in the header. This will also require putting the indices in the row, but it's better than
15651594 // the alternative of having ambiguous abbreviated column headers, or having a table potentially
@@ -1572,7 +1601,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
15721601 {
15731602 // The row label will also include the index, so a user can easily match against the header.
15741603 // In such a case, a label like "Foo" would be presented as something like "5. Foo".
1575- rowDigitLen = Math . Max ( predictedLabelNames . Count - 1 , 0 ) . ToString ( ) . Length ;
1604+ rowDigitLen = Math . Max ( confusionMatrix . PredictedClassesIndicators . Count - 1 , 0 ) . ToString ( ) . Length ;
15761605 Contracts . Assert ( rowDigitLen >= 1 ) ;
15771606 rowLabelLen += rowDigitLen + 2 ;
15781607 }
@@ -1591,10 +1620,11 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
15911620 else
15921621 rowLabelFormat = string . Format ( "{{1,{0}}} ||" , paddingLen ) ;
15931622
1623+ var confusionTable = confusionMatrix . Counts ;
15941624 var sb = new StringBuilder ( ) ;
1595- if ( numLabels == 2 && binary )
1625+ if ( numLabels == 2 && confusionMatrix . IsBinary )
15961626 {
1597- var positiveCaps = predictedLabelNames [ 0 ] . ToString ( ) . ToUpper ( ) ;
1627+ var positiveCaps = confusionMatrix . PredictedClassesIndicators [ 0 ] . ToString ( ) . ToUpper ( ) ;
15981628
15991629 var numTruePos = confusionTable [ 0 ] [ 0 ] ;
16001630 var numFalseNeg = confusionTable [ 0 ] [ 1 ] ;
@@ -1607,7 +1637,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16071637
16081638 sb . AppendLine ( ) ;
16091639 sb . AppendFormat ( "{0}Confusion table" , prefix ) ;
1610- if ( sampled )
1640+ if ( confusionMatrix . IsSampled )
16111641 sb . AppendLine ( " (sampled)" ) ;
16121642 else
16131643 sb . AppendLine ( ) ;
@@ -1619,7 +1649,7 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16191649 sb . AppendFormat ( "PREDICTED {0}||" , pad ) ;
16201650 string format = string . Format ( " {{{0},{1}}} |" , useNumbersInHeader ? 0 : 1 , colWidth ) ;
16211651 for ( int i = 0 ; i < numLabels ; i ++ )
1622- sb . AppendFormat ( format , i , predictedLabelNames [ i ] ) ;
1652+ sb . AppendFormat ( format , i , confusionMatrix . PredictedClassesIndicators [ i ] ) ;
16231653 sb . AppendLine ( " Recall" ) ;
16241654 sb . AppendFormat ( "TRUTH {0}||" , pad ) ;
16251655 for ( int i = 0 ; i < numLabels ; i ++ )
@@ -1631,11 +1661,10 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16311661 string . IsNullOrWhiteSpace ( prefix ) ? "N0" : "F1" ) ;
16321662 for ( int i = 0 ; i < numLabels ; i ++ )
16331663 {
1634- sb . AppendFormat ( rowLabelFormat , i , predictedLabelNames [ i ] ) ;
1664+ sb . AppendFormat ( rowLabelFormat , i , confusionMatrix . PredictedClassesIndicators [ i ] ) ;
16351665 for ( int j = 0 ; j < numLabels ; j ++ )
16361666 sb . AppendFormat ( format2 , confusionTable [ i ] [ j ] ) ;
1637- Double recall = rowSums [ i ] > 0 ? confusionTable [ i ] [ i ] / rowSums [ i ] : 0 ;
1638- sb . AppendFormat ( " {0,5:F4}" , recall ) ;
1667+ sb . AppendFormat ( " {0,5:F4}" , confusionMatrix . PerClassRecall [ i ] ) ;
16391668 sb . AppendLine ( ) ;
16401669 }
16411670 sb . AppendFormat ( " {0}||" , pad ) ;
@@ -1645,10 +1674,8 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl
16451674 sb . AppendFormat ( "Precision {0}||" , pad ) ;
16461675 format = string . Format ( "{{0,{0}:N4}} |" , colWidth + 1 ) ;
16471676 for ( int i = 0 ; i < numLabels ; i ++ )
1648- {
1649- Double precision = columnSums [ i ] > 0 ? confusionTable [ i ] [ i ] / columnSums [ i ] : 0 ;
1650- sb . AppendFormat ( format , precision ) ;
1651- }
1677+ sb . AppendFormat ( format , confusionMatrix . PerClassPrecision [ i ] ) ;
1678+
16521679 sb . AppendLine ( ) ;
16531680 return sb . ToString ( ) ;
16541681 }
@@ -1701,7 +1728,7 @@ public static void PrintWarnings(IChannel ch, Dictionary<string, IDataView> metr
17011728 if ( metrics . TryGetValue ( MetricKinds . Warnings , out warnings ) )
17021729 {
17031730 var warningTextColumn = warnings . Schema . GetColumnOrNull ( MetricKinds . ColumnNames . WarningText ) ;
1704- if ( warningTextColumn != null && warningTextColumn . HasValue && warningTextColumn . Value . Type is TextDataViewType )
1731+ if ( warningTextColumn != null && warningTextColumn . HasValue && warningTextColumn . Value . Type is TextDataViewType )
17051732 {
17061733 using ( var cursor = warnings . GetRowCursor ( warnings . Schema [ MetricKinds . ColumnNames . WarningText ] ) )
17071734 {
0 commit comments