Skip to content

Commit 813cbb5

Browse files
committed
Improve early-stopping metric field
1 parent fdca895 commit 813cbb5

File tree

2 files changed

+142
-3
lines changed

2 files changed

+142
-3
lines changed

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,81 @@ internal interface IFastTreeTrainerFactory : IComponentFactory<ITrainer>
2020
{
2121
}
2222

23+
/// <summary>
24+
/// Stopping measurements for classification and regression.
25+
/// </summary>
26+
public enum EarlyStoppingMetric
27+
{
28+
/// <summary>
29+
/// L1-norm of gradient.
30+
/// </summary>
31+
L1Norm = 1,
32+
/// <summary>
33+
/// L2-norm of gradient.
34+
/// </summary>
35+
L2Norm = 2
36+
};
37+
38+
/// <summary>
39+
/// Stopping measurements for ranking.
40+
/// </summary>
41+
public enum EarlyStoppingRankingMetric
42+
{
43+
/// <summary>
44+
/// NDCG@1
45+
/// </summary>
46+
NdcgAt1 = 1,
47+
/// <summary>
48+
/// NDCG@3
49+
/// </summary>
50+
NdcgAt3 = 3
51+
}
52+
2353
/// <include file='doc.xml' path='doc/members/member[@name="FastTree"]/*' />
2454
public sealed partial class FastTreeBinaryClassificationTrainer
2555
{
2656
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
2757
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
2858
{
59+
2960
/// <summary>
3061
/// Option for using derivatives optimized for unbalanced sets.
3162
/// </summary>
3263
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Option for using derivatives optimized for unbalanced sets", ShortName = "us")]
3364
[TGUI(Label = "Optimize for unbalanced")]
3465
public bool UnbalancedSets = false;
3566

67+
/// <summary>
68+
/// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
69+
/// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
70+
/// </summary>
71+
// Disable 649 because Visual Studio can't detect its assignment via property.
72+
#pragma warning disable 649
73+
private EarlyStoppingMetric _earlyStoppingMetric;
74+
#pragma warning restore 649
75+
76+
/// <summary>
77+
/// Early stopping metrics.
78+
/// </summary>
79+
public EarlyStoppingMetric EarlyStoppingMetric
80+
{
81+
get { return _earlyStoppingMetric; }
82+
83+
set
84+
{
85+
// Update the state of the user-facing stopping metric.
86+
_earlyStoppingMetric = value;
87+
// Set up internal property according to its public value.
88+
EarlyStoppingMetrics = (int)_earlyStoppingMetric;
89+
}
90+
}
91+
92+
public Options()
93+
{
94+
// Use L1 by default.
95+
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
96+
}
97+
3698
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
3799
}
38100
}
@@ -42,9 +104,31 @@ public sealed partial class FastTreeRegressionTrainer
42104
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
43105
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
44106
{
107+
/// <summary>
108+
/// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
109+
/// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
110+
/// </summary>
111+
private EarlyStoppingMetric _earlyStoppingMetric;
112+
113+
/// <summary>
114+
/// Early stopping metrics.
115+
/// </summary>
116+
public EarlyStoppingMetric EarlyStoppingMetric
117+
{
118+
get { return _earlyStoppingMetric; }
119+
120+
set
121+
{
122+
// Update the state of the user-facing stopping metric.
123+
_earlyStoppingMetric = value;
124+
// Set up internal property according to its public value.
125+
EarlyStoppingMetrics = (int)_earlyStoppingMetric;
126+
}
127+
}
128+
45129
public Options()
46130
{
47-
EarlyStoppingMetrics = 1; // Use L1 by default.
131+
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
48132
}
49133

50134
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeRegressionTrainer(env, this);
@@ -64,6 +148,36 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
64148
"and intermediate values are compound Poisson loss.")]
65149
public Double Index = 1.5;
66150

151+
/// <summary>
152+
/// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
153+
/// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
154+
/// </summary>
155+
// Disable 649 because Visual Studio can't detect its assignment via property.
156+
#pragma warning disable 649
157+
private EarlyStoppingMetric _earlyStoppingMetric;
158+
#pragma warning restore 649
159+
160+
/// <summary>
161+
/// Early stopping metrics.
162+
/// </summary>
163+
public EarlyStoppingMetric EarlyStoppingMetric
164+
{
165+
get { return _earlyStoppingMetric; }
166+
167+
set
168+
{
169+
// Update the state of the user-facing stopping metric.
170+
_earlyStoppingMetric = value;
171+
// Set up internal property according to its public value.
172+
EarlyStoppingMetrics = (int)_earlyStoppingMetric;
173+
}
174+
}
175+
176+
public Options()
177+
{
178+
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm; // Use L1 by default.
179+
}
180+
67181
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeTweedieTrainer(env, this);
68182
}
69183
}
@@ -113,9 +227,34 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
113227
[TGUI(NotGui = true)]
114228
internal bool NormalizeQueryLambdas;
115229

230+
/// <summary>
231+
/// internal state of <see cref="EarlyStoppingMetric"/>. It should be always synced with
232+
/// <see cref="BoostedTreeOptions.EarlyStoppingMetrics"/>.
233+
/// </summary>
234+
// Disable 649 because Visual Studio can't detect its assignment via property.
235+
#pragma warning disable 649
236+
private EarlyStoppingRankingMetric _earlyStoppingMetric;
237+
#pragma warning restore 649
238+
239+
/// <summary>
240+
/// Early stopping metrics.
241+
/// </summary>
242+
public EarlyStoppingRankingMetric EarlyStoppingMetric
243+
{
244+
get { return _earlyStoppingMetric; }
245+
246+
set
247+
{
248+
// Update the state of the user-facing stopping metric.
249+
_earlyStoppingMetric = value;
250+
// Set up internal property according to its public value.
251+
EarlyStoppingMetrics = (int)_earlyStoppingMetric;
252+
}
253+
}
254+
116255
public Options()
117256
{
118-
EarlyStoppingMetrics = 1;
257+
EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use L1 by default.
119258
}
120259

121260
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this);

test/Microsoft.ML.Functional.Tests/Validation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void TrainWithValidationSet()
8383
// Train the model with a validation set.
8484
var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
8585
NumberOfTrees = 2,
86-
EarlyStoppingMetrics = 2,
86+
EarlyStoppingMetric = EarlyStoppingMetric.L2Norm,
8787
EarlyStoppingRule = new GLEarlyStoppingCriterion.Options()
8888
})
8989
.Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData);

0 commit comments

Comments
 (0)