@@ -20,19 +20,81 @@ internal interface IFastTreeTrainerFactory : IComponentFactory<ITrainer>
2020 {
2121 }
2222
23+ /// <summary>
24+ /// Stopping measurements for classification and regression.
25+ /// </summary>
26+ public enum EarlyStoppingMetric
27+ {
28+ /// <summary>
29+ /// L1-norm of gradient.
30+ /// </summary>
31+ L1Norm = 1 ,
32+ /// <summary>
33+ /// L2-norm of gradient.
34+ /// </summary>
35+ L2Norm = 2
36+ } ;
37+
38+ /// <summary>
39+ /// Stopping measurements for ranking.
40+ /// </summary>
41+ public enum EarlyStoppingRankingMetric
42+ {
43+ /// <summary>
44+ /// NDCG@1
45+ /// </summary>
46+ NdcgAt1 = 1 ,
47+ /// <summary>
48+ /// NDCG@3
49+ /// </summary>
50+ NdcgAt3 = 3
51+ }
52+
2353 /// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
2454 public sealed partial class FastTreeBinaryClassificationTrainer
2555 {
2656 [ TlcModule . Component ( Name = LoadNameValue , FriendlyName = UserNameValue , Desc = Summary ) ]
2757 public sealed class Options : BoostedTreeOptions , IFastTreeTrainerFactory
2858 {
59+
2960 /// <summary>
3061 /// Option for using derivatives optimized for unbalanced sets.
3162 /// </summary>
3263 [ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Option for using derivatives optimized for unbalanced sets" , ShortName = "us" ) ]
3364 [ TGUI ( Label = "Optimize for unbalanced" ) ]
3465 public bool UnbalancedSets = false ;
3566
67+ /// <summary>
68+ /// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
69+ /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
70+ /// </summary>
71+ // Disable 649 because Visual Studio can't detect its assignment via property.
72+ #pragma warning disable 649
73+ private EarlyStoppingMetric _earlyStoppingMetric ;
74+ #pragma warning restore 649
75+
76+ /// <summary>
77+ /// Early stopping metrics.
78+ /// </summary>
79+ public EarlyStoppingMetric EarlyStoppingMetric
80+ {
81+ get { return _earlyStoppingMetric ; }
82+
83+ set
84+ {
85+ // Update the state of the user-facing stopping metric.
86+ _earlyStoppingMetric = value ;
87+ // Set up internal property according to its public value.
88+ EarlyStoppingMetrics = ( int ) _earlyStoppingMetric ;
89+ }
90+ }
91+
92+ public Options ( )
93+ {
94+ // Use L1 by default.
95+ EarlyStoppingMetric = EarlyStoppingMetric . L1Norm ;
96+ }
97+
3698 ITrainer IComponentFactory < ITrainer > . CreateComponent ( IHostEnvironment env ) => new FastTreeBinaryClassificationTrainer ( env , this ) ;
3799 }
38100 }
@@ -42,9 +104,31 @@ public sealed partial class FastTreeRegressionTrainer
42104 [ TlcModule . Component ( Name = LoadNameValue , FriendlyName = UserNameValue , Desc = Summary ) ]
43105 public sealed class Options : BoostedTreeOptions , IFastTreeTrainerFactory
44106 {
107+ /// <summary>
108+ /// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
109+ /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
110+ /// </summary>
111+ private EarlyStoppingMetric _earlyStoppingMetric ;
112+
113+ /// <summary>
114+ /// Early stopping metrics.
115+ /// </summary>
116+ public EarlyStoppingMetric EarlyStoppingMetric
117+ {
118+ get { return _earlyStoppingMetric ; }
119+
120+ set
121+ {
122+ // Update the state of the user-facing stopping metric.
123+ _earlyStoppingMetric = value ;
124+ // Set up internal property according to its public value.
125+ EarlyStoppingMetrics = ( int ) _earlyStoppingMetric ;
126+ }
127+ }
128+
45129 public Options ( )
46130 {
47- EarlyStoppingMetrics = 1 ; // Use L1 by default.
131+ EarlyStoppingMetric = EarlyStoppingMetric . L1Norm ; // Use L1 by default.
48132 }
49133
50134 ITrainer IComponentFactory < ITrainer > . CreateComponent ( IHostEnvironment env ) => new FastTreeRegressionTrainer ( env , this ) ;
@@ -64,6 +148,36 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
64148 "and intermediate values are compound Poisson loss." ) ]
65149 public Double Index = 1.5 ;
66150
151+ /// <summary>
152+ /// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
153+ /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
154+ /// </summary>
155+ // Disable 649 because Visual Studio can't detect its assignment via property.
156+ #pragma warning disable 649
157+ private EarlyStoppingMetric _earlyStoppingMetric ;
158+ #pragma warning restore 649
159+
160+ /// <summary>
161+ /// Early stopping metrics.
162+ /// </summary>
163+ public EarlyStoppingMetric EarlyStoppingMetric
164+ {
165+ get { return _earlyStoppingMetric ; }
166+
167+ set
168+ {
169+ // Update the state of the user-facing stopping metric.
170+ _earlyStoppingMetric = value ;
171+ // Set up internal property according to its public value.
172+ EarlyStoppingMetrics = ( int ) _earlyStoppingMetric ;
173+ }
174+ }
175+
176+ public Options ( )
177+ {
178+ EarlyStoppingMetric = EarlyStoppingMetric . L1Norm ; // Use L1 by default.
179+ }
180+
67181 ITrainer IComponentFactory < ITrainer > . CreateComponent ( IHostEnvironment env ) => new FastTreeTweedieTrainer ( env , this ) ;
68182 }
69183 }
@@ -113,9 +227,34 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
113227 [ TGUI ( NotGui = true ) ]
114228 internal bool NormalizeQueryLambdas ;
115229
230+ /// <summary>
231+ /// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
232+ /// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
233+ /// </summary>
234+ // Disable 649 because Visual Studio can't detect its assignment via property.
235+ #pragma warning disable 649
236+ private EarlyStoppingRankingMetric _earlyStoppingMetric ;
237+ #pragma warning restore 649
238+
239+ /// <summary>
240+ /// Early stopping metrics.
241+ /// </summary>
242+ public EarlyStoppingRankingMetric EarlyStoppingMetric
243+ {
244+ get { return _earlyStoppingMetric ; }
245+
246+ set
247+ {
248+ // Update the state of the user-facing stopping metric.
249+ _earlyStoppingMetric = value ;
250+ // Set up internal property according to its public value.
251+ EarlyStoppingMetrics = ( int ) _earlyStoppingMetric ;
252+ }
253+ }
254+
116255 public Options ( )
117256 {
118- EarlyStoppingMetrics = 1 ;
257+ EarlyStoppingMetric = EarlyStoppingRankingMetric . NdcgAt1 ; // Use L1 by default.
119258 }
120259
121260 ITrainer IComponentFactory < ITrainer > . CreateComponent ( IHostEnvironment env ) => new FastTreeRankingTrainer ( env , this ) ;
0 commit comments