1515using Microsoft . ML . Internal . Utilities ;
1616using Microsoft . ML . Runtime ;
1717
18- [ assembly: LoadableClass ( typeof ( RankingEvaluator ) , typeof ( RankingEvaluator ) , typeof ( RankingEvaluator . Arguments ) , typeof ( SignatureEvaluator ) ,
18+ [ assembly: LoadableClass ( typeof ( RankingEvaluator ) , typeof ( RankingEvaluator ) , typeof ( RankingEvaluatorOptions ) , typeof ( SignatureEvaluator ) ,
1919 "Ranking Evaluator" , RankingEvaluator . LoadName , "Ranking" , "rank" ) ]
2020
2121[ assembly: LoadableClass ( typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator . Arguments ) , typeof ( SignatureMamlEvaluator ) ,
2626
2727namespace Microsoft . ML . Data
2828{
29- [ BestFriend ]
30- internal sealed class RankingEvaluator : EvaluatorBase < RankingEvaluator . Aggregator >
29+ /// <summary>
30+ /// Options to control the output of the RankingEvaluator
31+ /// </summary>
32+ public sealed class RankingEvaluatorOptions
3133 {
32- public sealed class Arguments
33- {
34- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum truncation level for computing (N)DCG" , ShortName = "t" ) ]
35- public int DcgTruncationLevel = 3 ;
34+ /// <summary>
35+ /// Maximum truncation level for computing (N)DCG.
36+ /// </summary>
37+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum truncation level for computing (N)DCG" , ShortName = "t" ) ]
38+ public int DcgTruncationLevel = 3 ;
3639
37- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Label relevance gains" , ShortName = "gains" ) ]
38- public string LabelGains = "0,3,7,15,31" ;
40+ /// <summary>
41+ /// Label relevance gains.
42+ /// </summary>
43+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Label relevance gains" , ShortName = "gains" ) ]
44+ public string LabelGains = "0,3,7,15,31" ;
3945
40- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Generate per-group (N)DCG" , ShortName = "ogs" ) ]
41- public bool OutputGroupSummary ;
42- }
46+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Generate per-group (N)DCG" , ShortName = "ogs" ) ]
47+ internal bool OutputGroupSummary ;
48+ }
4349
50+ [ BestFriend ]
51+ internal sealed class RankingEvaluator : EvaluatorBase < RankingEvaluator . Aggregator >
52+ {
4453 internal const string LoadName = "RankingEvaluator" ;
4554
4655 public const string Ndcg = "NDCG" ;
@@ -60,24 +69,25 @@ public sealed class Arguments
6069 private readonly bool _groupSummary ;
6170 private readonly Double [ ] _labelGains ;
6271
63- public RankingEvaluator ( IHostEnvironment env , Arguments args )
72+ public RankingEvaluator ( IHostEnvironment env , RankingEvaluatorOptions options )
6473 : base ( env , LoadName )
6574 {
6675 // REVIEW: What kind of checking should be applied to labelGains?
67- if ( args . DcgTruncationLevel <= 0 || args . DcgTruncationLevel > Aggregator . Counters . MaxTruncationLevel )
68- throw Host . ExceptUserArg ( nameof ( args . DcgTruncationLevel ) , "DCG Truncation Level must be between 1 and {0}" , Aggregator . Counters . MaxTruncationLevel ) ;
69- Host . CheckUserArg ( args . LabelGains != null , nameof ( args . LabelGains ) , "Label gains cannot be null" ) ;
76+ // add the setter to utils here
77+ if ( options . DcgTruncationLevel <= 0 )
78+ throw Host . ExceptUserArg ( nameof ( options . DcgTruncationLevel ) , "DCG Truncation Level must be greater than 0" ) ;
79+ Host . CheckUserArg ( options . LabelGains != null , nameof ( options . LabelGains ) , "Label gains cannot be null" ) ;
7080
71- _truncationLevel = args . DcgTruncationLevel ;
72- _groupSummary = args . OutputGroupSummary ;
81+ _truncationLevel = options . DcgTruncationLevel ;
82+ _groupSummary = options . OutputGroupSummary ;
7383
7484 var labelGains = new List < Double > ( ) ;
75- string [ ] gains = args . LabelGains . Split ( ',' ) ;
85+ string [ ] gains = options . LabelGains . Split ( ',' ) ;
7686 for ( int i = 0 ; i < gains . Length ; i ++ )
7787 {
7888 Double gain ;
7989 if ( ! Double . TryParse ( gains [ i ] , out gain ) )
80- throw Host . ExceptUserArg ( nameof ( args . LabelGains ) , "Label Gains must be of floating or integral type" , Aggregator . Counters . MaxTruncationLevel ) ;
90+ throw Host . ExceptUserArg ( nameof ( options . LabelGains ) , "Label Gains must be of floating or integral type" ) ;
8191 labelGains . Add ( gain ) ;
8292 }
8393 _labelGains = labelGains . ToArray ( ) ;
@@ -271,8 +281,6 @@ public sealed class Aggregator : AggregatorBase
271281 {
272282 public sealed class Counters
273283 {
274- public const int MaxTruncationLevel = 10 ;
275-
276284 public readonly int TruncationLevel ;
277285 private readonly List < Double [ ] > _groupNdcg ;
278286 private readonly List < Double [ ] > _groupDcg ;
@@ -287,6 +295,7 @@ public sealed class Counters
287295 private readonly List < short > _queryLabels ;
288296 private readonly List < Single > _queryOutputs ;
289297 private readonly Double [ ] _labelGains ;
298+ private readonly Double [ ] _discountMap ;
290299
291300 public bool GroupSummary { get { return _groupNdcg != null ; } }
292301
@@ -348,6 +357,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
348357 Contracts . AssertValue ( labelGains ) ;
349358
350359 TruncationLevel = truncationLevel ;
360+ _discountMap = RankingUtils . GetDiscountMap ( truncationLevel ) ;
361+
351362 _sumDcgAtN = new Double [ TruncationLevel ] ;
352363 _sumNdcgAtN = new Double [ TruncationLevel ] ;
353364
@@ -373,15 +384,15 @@ public void Update(short label, Single output)
373384
374385 public void UpdateGroup ( Single weight )
375386 {
376- RankingUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
387+ RankingUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _discountMap , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
377388 if ( _groupMaxDcg != null )
378389 {
379390 var maxDcg = new Double [ TruncationLevel ] ;
380391 Array . Copy ( _groupMaxDcgCur , maxDcg , TruncationLevel ) ;
381392 _groupMaxDcg . Add ( maxDcg ) ;
382393 }
383394
384- RankingUtils . QueryDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupDcgCur ) ;
395+ RankingUtils . QueryDcg ( _labelGains , TruncationLevel , _discountMap , _queryLabels , _queryOutputs , _groupDcgCur ) ;
385396 if ( _groupDcg != null )
386397 {
387398 var groupDcg = new Double [ TruncationLevel ] ;
@@ -684,17 +695,19 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
684695
685696 private readonly Bindings _bindings ;
686697 private readonly int _truncationLevel ;
698+ private readonly Double [ ] _discountMap ;
687699 private readonly Double [ ] _labelGains ;
688700
689701 public Transform ( IHostEnvironment env , IDataView input , string labelCol , string scoreCol , string groupCol ,
690702 int truncationLevel , Double [ ] labelGains )
691703 : base ( env , input , labelCol , scoreCol , groupCol , RegistrationName )
692704 {
693- Host . CheckParam ( 0 < truncationLevel && truncationLevel < 100 , nameof ( truncationLevel ) ,
694- "Truncation level must be between 1 and 99 " ) ;
705+ Host . CheckParam ( 0 < truncationLevel , nameof ( truncationLevel ) ,
706+ "Truncation level must be greater than 0 " ) ;
695707 Host . CheckValue ( labelGains , nameof ( labelGains ) ) ;
696708
697709 _truncationLevel = truncationLevel ;
710+ _discountMap = RankingUtils . GetDiscountMap ( _truncationLevel ) ;
698711 _labelGains = labelGains ;
699712 _bindings = new Bindings ( Host , Source . Schema , true , LabelCol , ScoreCol , GroupCol , _truncationLevel ) ;
700713 }
@@ -709,7 +722,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
709722 // double[]: _labelGains
710723
711724 _truncationLevel = ctx . Reader . ReadInt32 ( ) ;
712- Host . CheckDecode ( 0 < _truncationLevel && _truncationLevel < 100 ) ;
725+ Host . CheckDecode ( 0 < _truncationLevel ) ;
713726 _labelGains = ctx . Reader . ReadDoubleArray ( ) ;
714727 _bindings = new Bindings ( Host , input . Schema , false , LabelCol , ScoreCol , GroupCol , _truncationLevel ) ;
715728 }
@@ -725,7 +738,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
725738 // double[]: _labelGains
726739
727740 base . SaveModel ( ctx ) ;
728- Host . Assert ( 0 < _truncationLevel && _truncationLevel < 100 ) ;
741+ Host . Assert ( 0 < _truncationLevel ) ;
729742 ctx . Writer . Write ( _truncationLevel ) ;
730743 ctx . Writer . WriteDoubleArray ( _labelGains ) ;
731744 }
@@ -800,9 +813,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
800813 protected override void UpdateState ( RowCursorState state )
801814 {
802815 // Calculate the current group DCG, NDCG and MaxDcg.
803- RankingUtils . QueryMaxDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs ,
816+ RankingUtils . QueryMaxDcg ( _labelGains , _truncationLevel , _discountMap , state . QueryLabels , state . QueryOutputs ,
804817 state . MaxDcgCur ) ;
805- RankingUtils . QueryDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
818+ RankingUtils . QueryDcg ( _labelGains , _truncationLevel , _discountMap , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
806819 for ( int t = 0 ; t < _truncationLevel ; t ++ )
807820 {
808821 Double ndcg = state . MaxDcgCur [ t ] > 0 ? state . DcgCur [ t ] / state . MaxDcgCur [ t ] : 0 ;
@@ -823,7 +836,7 @@ public sealed class RowCursorState
823836
824837 public RowCursorState ( int truncationLevel )
825838 {
826- Contracts . Assert ( 0 < truncationLevel && truncationLevel < 100 ) ;
839+ Contracts . Assert ( 0 < truncationLevel ) ;
827840
828841 QueryLabels = new List < short > ( ) ;
829842 QueryOutputs = new List < Single > ( ) ;
@@ -867,12 +880,12 @@ public RankingMamlEvaluator(IHostEnvironment env, Arguments args)
867880 Host . CheckValue ( args , nameof ( args ) ) ;
868881 Utils . CheckOptionalUserDirectory ( args . GroupSummaryFilename , nameof ( args . GroupSummaryFilename ) ) ;
869882
870- var evalArgs = new RankingEvaluator . Arguments ( ) ;
871- evalArgs . DcgTruncationLevel = args . DcgTruncationLevel ;
872- evalArgs . LabelGains = args . LabelGains ;
873- evalArgs . OutputGroupSummary = ! string . IsNullOrEmpty ( args . GroupSummaryFilename ) ;
883+ var evalOpts = new RankingEvaluatorOptions ( ) ;
884+ evalOpts . DcgTruncationLevel = args . DcgTruncationLevel ;
885+ evalOpts . LabelGains = args . LabelGains ;
886+ evalOpts . OutputGroupSummary = ! string . IsNullOrEmpty ( args . GroupSummaryFilename ) ;
874887
875- _evaluator = new RankingEvaluator ( Host , evalArgs ) ;
888+ _evaluator = new RankingEvaluator ( Host , evalOpts ) ;
876889 _groupSummaryFilename = args . GroupSummaryFilename ;
877890 _groupIdCol = args . GroupIdColumn ;
878891 }
@@ -946,30 +959,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
946959
947960 internal static class RankingUtils
948961 {
949- private static volatile Double [ ] _discountMap ;
950- public static Double [ ] DiscountMap
962+ // Truncation levels are typically no greater than 100, so we maintain a shared discount map of size 100.
963+ // If a truncation level greater than 100 is required, we build a new (uncached) map of that size and return it.
964+ private const int FixedDiscountMapSize = 100 ;
965+ private static Double [ ] _discountMapFixed ;
966+
967+ private static Double [ ] GetDiscountMapCore ( int truncationLevel )
951968 {
952- get
969+ var discountMap = new Double [ truncationLevel ] ;
970+
971+ for ( int i = 0 ; i < discountMap . Length ; i ++ )
972+ discountMap [ i ] = 1 / Math . Log ( 2 + i ) ;
973+
974+ return discountMap ;
975+ }
976+
977+ public static Double [ ] GetDiscountMap ( int truncationLevel )
978+ {
979+ var discountMap = _discountMapFixed ;
980+ if ( discountMap == null )
953981 {
954- double [ ] result = _discountMap ;
955- if ( result == null )
956- {
957- var discountMap = new Double [ 100 ] ; //Hard to believe anyone would set truncation Level higher than 100
958- for ( int i = 0 ; i < discountMap . Length ; i ++ )
959- {
960- discountMap [ i ] = 1 / Math . Log ( 2 + i ) ;
961- }
962- Interlocked . CompareExchange ( ref _discountMap , discountMap , null ) ;
963- result = _discountMap ;
964- }
965- return result ;
982+ discountMap = GetDiscountMapCore ( FixedDiscountMapSize ) ;
983+ Interlocked . CompareExchange ( ref _discountMapFixed , discountMap , null ) ;
984+ discountMap = _discountMapFixed ;
966985 }
986+
987+ if ( truncationLevel <= discountMap . Length )
988+ return discountMap ;
989+
990+ return GetDiscountMapCore ( truncationLevel ) ;
967991 }
968992
969993 /// <summary>
970994 /// Calculates the natural-logarithm-based max DCG at all truncation levels from 1 to truncationLevel.
971995 /// </summary>
972- public static void QueryMaxDcg ( Double [ ] labelGains , int truncationLevel ,
996+ public static void QueryMaxDcg ( Double [ ] labelGains , int truncationLevel , Double [ ] discountMap ,
973997 List < short > queryLabels , List < Single > queryOutputs , Double [ ] groupMaxDcgCur )
974998 {
975999 Contracts . Assert ( Utils . Size ( groupMaxDcgCur ) == truncationLevel ) ;
@@ -994,21 +1018,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
9941018 while ( labelCounts [ topLabel ] == 0 )
9951019 topLabel -- ;
9961020
997- groupMaxDcgCur [ 0 ] = labelGains [ topLabel ] * DiscountMap [ 0 ] ;
1021+ groupMaxDcgCur [ 0 ] = labelGains [ topLabel ] * discountMap [ 0 ] ;
9981022 labelCounts [ topLabel ] -- ;
9991023 for ( int t = 1 ; t < maxTrunc ; t ++ )
10001024 {
10011025 while ( labelCounts [ topLabel ] == 0 )
10021026 topLabel -- ;
1003- groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] + labelGains [ topLabel ] * DiscountMap [ t ] ;
1027+ groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] + labelGains [ topLabel ] * discountMap [ t ] ;
10041028 labelCounts [ topLabel ] -- ;
10051029 }
10061030 for ( int t = maxTrunc ; t < truncationLevel ; t ++ )
10071031 groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] ;
10081032 }
10091033 }
10101034
1011- public static void QueryDcg ( Double [ ] labelGains , int truncationLevel ,
1035+ public static void QueryDcg ( Double [ ] labelGains , int truncationLevel , Double [ ] discountMap ,
10121036 List < short > queryLabels , List < Single > queryOutputs , Double [ ] groupDcgCur )
10131037 {
10141038 // calculate the permutation
@@ -1021,7 +1045,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel,
10211045 Double dcg = 0 ;
10221046 for ( int t = 0 ; t < count ; ++ t )
10231047 {
1024- dcg = dcg + labelGains [ queryLabels [ permutation [ t ] ] ] * DiscountMap [ t ] ;
1048+ dcg = dcg + labelGains [ queryLabels [ permutation [ t ] ] ] * discountMap [ t ] ;
10251049 groupDcgCur [ t ] = dcg ;
10261050 }
10271051 for ( int t = count ; t < truncationLevel ; ++ t )
0 commit comments