|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license. |
3 | 3 | // See the LICENSE file in the project root for more information. |
4 | 4 |
|
| 5 | +using System; |
5 | 6 | using System.Collections.Generic; |
6 | 7 | using System.Collections.Immutable; |
| 8 | +using System.Linq; |
7 | 9 | using Microsoft.ML.Runtime; |
8 | 10 |
|
9 | 11 | namespace Microsoft.ML.Data |
@@ -71,16 +73,22 @@ public sealed class MulticlassClassificationMetrics |
/// <summary>
/// The micro-averaged accuracy of the classifier.
/// </summary>
public double MicroAccuracy { get; }
|
/// <summary>
/// Convenience accessor equivalent to "TopKAccuracyForAllK[TopKPredictionCount - 1]".
/// If <see cref="TopKPredictionCount"/> is positive, this is the relative number of examples
/// where the true label is one of the top K predicted labels by the predictor.
/// </summary>
public double TopKAccuracy
{
    get
    {
        // The last entry of TopKAccuracyForAllK corresponds to K = TopKPredictionCount.
        // When the list was never populated (top-K metrics not requested) or is empty,
        // report 0, matching the previous default of the expression-bodied form.
        var accuracies = TopKAccuracyForAllK;
        if (accuracies == null || accuracies.Count == 0)
            return 0;
        return accuracies[accuracies.Count - 1];
    }
}
78 | 81 |
|
/// <summary>
/// The K used when computing <see cref="TopKAccuracy"/> and <see cref="TopKAccuracyForAllK"/>.
/// Top-K metrics are only populated when this value is positive.
/// </summary>
public int TopKPredictionCount { get; }
83 | 86 |
|
/// <summary>
/// The top-K accuracy for every K from 1 up to the value of <see cref="TopKPredictionCount"/>,
/// in increasing order of K.
/// </summary>
public IReadOnlyList<double> TopKAccuracyForAllK { get; }
| 91 | + |
84 | 92 | /// <summary> |
85 | 93 | /// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier |
86 | 94 | /// with respect to how much the predicted probabilities diverge from the true class label. Lower |
@@ -115,29 +123,30 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult, |
115 | 123 | LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss); |
116 | 124 | LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction); |
117 | 125 | TopKPredictionCount = topKPredictionCount; |
| 126 | + |
118 | 127 | if (topKPredictionCount > 0) |
119 | | - TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy); |
| 128 | + TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray(); |
120 | 129 |
|
121 | 130 | var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss); |
122 | 131 | PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray(); |
123 | 132 | ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix, binary: false, perClassLogLoss.Length); |
124 | 133 | } |
125 | 134 |
|
/// <summary>
/// Initializes the metrics from precomputed values (no confusion matrix).
/// </summary>
/// <param name="accuracyMicro">Micro-averaged accuracy.</param>
/// <param name="accuracyMacro">Macro-averaged accuracy.</param>
/// <param name="logLoss">Log-loss of the classifier.</param>
/// <param name="logLossReduction">Log-loss reduction relative to the prior.</param>
/// <param name="topKPredictionCount">The K for top-K accuracy; non-positive when top-K metrics were not computed.</param>
/// <param name="topKAccuracies">Top-K accuracy for each K from 1 to <paramref name="topKPredictionCount"/>, or null when not computed.</param>
/// <param name="perClassLogLoss">Per-class log-loss values.</param>
internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
    int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss)
{
    MicroAccuracy = accuracyMicro;
    MacroAccuracy = accuracyMacro;
    LogLoss = logLoss;
    LogLossReduction = logLossReduction;
    TopKPredictionCount = topKPredictionCount;
    // Take an immutable snapshot so later mutation of the caller's array cannot change the
    // reported metrics, and so this path stores the same kind of ImmutableArray as the
    // evaluator-based constructor. A null array (top-K not computed) stays null, which
    // TopKAccuracy already handles.
    TopKAccuracyForAllK = topKAccuracies == null ? null : (IReadOnlyList<double>)topKAccuracies.ToImmutableArray();
    PerClassLogLoss = perClassLogLoss.ToImmutableArray();
}
137 | 146 |
|
/// <summary>
/// Initializes the metrics from precomputed values, including a confusion matrix.
/// All scalar metrics are delegated to the constructor without a confusion matrix.
/// </summary>
internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
    int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
    : this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracies, perClassLogLoss)
{
    // The only extra state this overload carries.
    ConfusionMatrix = confusionMatrix;
}
|
0 commit comments