Skip to content

Commit ed6885d

Browse files
committed
Changes based on PR comments.
1 parent 70989b6 commit ed6885d

File tree

5 files changed

+19
-15
lines changed

5 files changed

+19
-15
lines changed

src/Microsoft.ML.AutoML/API/RankingExperiment.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public sealed class RankingExperimentSettings : ExperimentSettings
3333
/// <value>
3434
/// The default value is 10.
3535
/// </value>
36-
public int OptimizationMetricTruncationLevel { get; set; }
36+
public uint OptimizationMetricTruncationLevel { get; set; }
3737

3838
public RankingExperimentSettings()
3939
{
@@ -80,7 +80,7 @@ public static class RankingExperimentResultExtensions
8080
/// <param name="metric">Metric to consider when selecting the best run.</param>
8181
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
8282
/// <returns>The best experiment run.</returns>
83-
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int optimizationMetricTruncationLevel = 10)
83+
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
8484
{
8585
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
8686
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
@@ -94,7 +94,7 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
9494
/// <param name="metric">Metric to consider when selecting the best run.</param>
9595
/// <param name="optimizationMetricTruncationLevel">Maximum truncation level for computing (N)DCG. Defaults to 10.</param>
9696
/// <returns>The best experiment run.</returns>
97-
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, int optimizationMetricTruncationLevel = 10)
97+
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, uint optimizationMetricTruncationLevel = 10)
9898
{
9999
var metricsAgent = new RankingMetricsAgent(null, metric, optimizationMetricTruncationLevel);
100100
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/RankingMetricsAgent.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,29 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using Microsoft.ML.Data;
7+
using Microsoft.ML.Runtime;
68

79
namespace Microsoft.ML.AutoML
810
{
911
internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
1012
{
1113
private readonly MLContext _mlContext;
1214
private readonly RankingMetric _optimizingMetric;
13-
private readonly int _dcgTruncationLevel;
15+
private readonly uint _dcgTruncationLevel;
1416

15-
public RankingMetricsAgent(MLContext mlContext, RankingMetric metric, int optimizationMetricTruncationLevel)
17+
public RankingMetricsAgent(MLContext mlContext, RankingMetric metric, uint optimizationMetricTruncationLevel)
1618
{
1719
_mlContext = mlContext;
1820
_optimizingMetric = metric;
1921

22+
if (optimizationMetricTruncationLevel <= 0)
23+
throw _mlContext.ExceptUserArg(nameof(optimizationMetricTruncationLevel), "DCG Truncation Level must be greater than 0");
24+
2025
// We want to make sure we always have at least 10 results. Getting extra results adds no measurable performance
2126
// impact, so err on the side of more.
22-
_dcgTruncationLevel = System.Math.Max(10, 2 * optimizationMetricTruncationLevel);
27+
_dcgTruncationLevel = optimizationMetricTruncationLevel;
2328
}
2429

2530
// Optimizing metric used: NDCG@10 and DCG@10
@@ -33,11 +38,9 @@ public double GetScore(RankingMetrics metrics)
3338
switch (_optimizingMetric)
3439
{
3540
case RankingMetric.Ndcg:
36-
return (metrics.NormalizedDiscountedCumulativeGains.Count >= 10) ? metrics.NormalizedDiscountedCumulativeGains[9] :
37-
metrics.NormalizedDiscountedCumulativeGains[metrics.NormalizedDiscountedCumulativeGains.Count - 1];
41+
return metrics.NormalizedDiscountedCumulativeGains[Math.Min(metrics.NormalizedDiscountedCumulativeGains.Count, (int)_dcgTruncationLevel) - 1];
3842
case RankingMetric.Dcg:
39-
return (metrics.DiscountedCumulativeGains.Count >= 10) ? metrics.DiscountedCumulativeGains[9] :
40-
metrics.DiscountedCumulativeGains[metrics.DiscountedCumulativeGains.Count-1];
43+
return metrics.DiscountedCumulativeGains[Math.Min(metrics.DiscountedCumulativeGains.Count, (int)_dcgTruncationLevel) - 1];
4144
default:
4245
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
4346
}
@@ -66,7 +69,7 @@ public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string
6669
{
6770
var rankingEvalOptions = new RankingEvaluatorOptions
6871
{
69-
DcgTruncationLevel = _dcgTruncationLevel
72+
DcgTruncationLevel = Math.Max(10, 2 * (int)_dcgTruncationLevel)
7073
};
7174

7275
return _mlContext.Ranking.Evaluate(data, rankingEvalOptions, labelColumn, groupIdColumn);

src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<
3535
}
3636

3737
public static RunDetail<RankingMetrics> GetBestRun(IEnumerable<RunDetail<RankingMetrics>> results,
38-
RankingMetric metric, int dcgTruncationLevel)
38+
RankingMetric metric, uint dcgTruncationLevel)
3939
{
4040
var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
4141

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ public void AutoFitRankingTest()
176176
var settings = new RankingExperimentSettings()
177177
{
178178
MaxExperimentTimeInSeconds = 5,
179-
OptimizationMetricTruncationLevel = 5
179+
OptimizationMetricTruncationLevel = 3
180180
};
181181
var experiment = mlContext.Auto()
182182
.CreateRankingExperiment(settings);
@@ -203,6 +203,7 @@ public void AutoFitRankingTest()
203203
for (int i = 0; i < experimentResults.Length; i++)
204204
{
205205
RunDetail<RankingMetrics> bestRun = experimentResults[i].BestRun;
206+
// The user requested 3, but we always return at least 10.
206207
Assert.Equal(10, bestRun.ValidationMetrics.DiscountedCumulativeGains.Count);
207208
Assert.Equal(10, bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Count);
208209
Assert.True(experimentResults[i].RunDetails.Count() > 0);

test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ private static double GetScore(RegressionMetrics metrics, RegressionMetric metri
179179
return new RegressionMetricsAgent(null, metric).GetScore(metrics);
180180
}
181181

182-
private static double GetScore(RankingMetrics metrics, RankingMetric metric, int dcgTruncationLevel)
182+
private static double GetScore(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
183183
{
184184
return new RankingMetricsAgent(null, metric, dcgTruncationLevel).GetScore(metrics);
185185
}
@@ -202,7 +202,7 @@ private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric m
202202
return IsPerfectModel(metricsAgent, metrics);
203203
}
204204

205-
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, int dcgTruncationLevel)
205+
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
206206
{
207207
var metricsAgent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
208208
return IsPerfectModel(metricsAgent, metrics);

0 commit comments

Comments
 (0)