22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5+ using System ;
56using Microsoft . ML . Data ;
7+ using Microsoft . ML . Runtime ;
68
79namespace Microsoft . ML . AutoML
810{
911 internal class RankingMetricsAgent : IMetricsAgent < RankingMetrics >
1012 {
1113 private readonly MLContext _mlContext ;
1214 private readonly RankingMetric _optimizingMetric ;
13- private readonly int _dcgTruncationLevel ;
15+ private readonly uint _dcgTruncationLevel ;
1416
15- public RankingMetricsAgent ( MLContext mlContext , RankingMetric metric , int optimizationMetricTruncationLevel )
17+ public RankingMetricsAgent ( MLContext mlContext , RankingMetric metric , uint optimizationMetricTruncationLevel )
1618 {
1719 _mlContext = mlContext ;
1820 _optimizingMetric = metric ;
1921
22+ if ( optimizationMetricTruncationLevel <= 0 )
23+ throw _mlContext . ExceptUserArg ( nameof ( optimizationMetricTruncationLevel ) , "DCG Truncation Level must be greater than 0" ) ;
24+
2025 // We want to make sure we always have at least 10 results. Getting extra results adds no measurable performance
2126 // impact, so err on the side of more.
22- _dcgTruncationLevel = System . Math . Max ( 10 , 2 * optimizationMetricTruncationLevel ) ;
27+ _dcgTruncationLevel = optimizationMetricTruncationLevel ;
2328 }
2429
2530 // Optimizing metric used: NDCG@10 and DCG@10
@@ -33,11 +38,9 @@ public double GetScore(RankingMetrics metrics)
3338 switch ( _optimizingMetric )
3439 {
3540 case RankingMetric . Ndcg :
36- return ( metrics . NormalizedDiscountedCumulativeGains . Count >= 10 ) ? metrics . NormalizedDiscountedCumulativeGains [ 9 ] :
37- metrics . NormalizedDiscountedCumulativeGains [ metrics . NormalizedDiscountedCumulativeGains . Count - 1 ] ;
41+ return metrics . NormalizedDiscountedCumulativeGains [ Math . Min ( metrics . NormalizedDiscountedCumulativeGains . Count , ( int ) _dcgTruncationLevel ) - 1 ] ;
3842 case RankingMetric . Dcg :
39- return ( metrics . DiscountedCumulativeGains . Count >= 10 ) ? metrics . DiscountedCumulativeGains [ 9 ] :
40- metrics . DiscountedCumulativeGains [ metrics . DiscountedCumulativeGains . Count - 1 ] ;
43+ return metrics . DiscountedCumulativeGains [ Math . Min ( metrics . DiscountedCumulativeGains . Count , ( int ) _dcgTruncationLevel ) - 1 ] ;
4144 default :
4245 throw MetricsAgentUtil . BuildMetricNotSupportedException ( _optimizingMetric ) ;
4346 }
@@ -66,7 +69,7 @@ public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string
6669 {
6770 var rankingEvalOptions = new RankingEvaluatorOptions
6871 {
69- DcgTruncationLevel = _dcgTruncationLevel
72+ DcgTruncationLevel = Math . Max ( 10 , 2 * ( int ) _dcgTruncationLevel )
7073 } ;
7174
7275 return _mlContext . Ranking . Evaluate ( data , rankingEvalOptions , labelColumn , groupIdColumn ) ;
0 commit comments