diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
index e587dcb4c3..e994629ba1 100644
--- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs
+++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch)
if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");
- if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet)
+ bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null;
+ if (doEarlyStop && !HasValidSet)
throw ch.Except("Cannot perform early stopping without a validation set (valid).");
if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
@@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
return new BestStepRegressionGradientWrapper();
}
- private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
+ private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
{
- if (FastTreeTrainerOptions.EarlyStoppingRule == null)
+ if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null)
return false;
ch.AssertValue(ValidTest);
@@ -128,13 +129,16 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter
var trainingResult = TrainTest.ComputeTests().First();
ch.Assert(trainingResult.FinalValue >= 0);
- // Create early stopping rule.
+ // Create early stopping rule if it's null.
if (earlyStoppingRule == null)
{
- earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
- ch.Assert(earlyStoppingRule != null);
+ if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null)
+ earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
}
+ // Early stopping rule cannot be null!
+ ch.Assert(earlyStoppingRule != null);
+
bool isBestCandidate;
bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
(float)trainingResult.FinalValue, out isBestCandidate);
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index 2814ca859a..6162ca06f8 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch)
}
}
- private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration)
+ private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration)
{
bestIteration = Ensemble.NumTrees;
return false;
@@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch)
#endif
#endif
- IEarlyStoppingCriterion earlyStoppingRule = null;
+ EarlyStoppingRuleBase earlyStoppingRule = null;
int bestIteration = 0;
int emptyTrees = 0;
using (var pch = Host.StartProgressChannel("FastTree training"))
diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
index 6d2f2efded..767c706ece 100644
--- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs
@@ -621,9 +621,29 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
///
/// Early stopping rule. (Validation set (/valid) is required).
///
- [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "")]
+ [BestFriend]
+ [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "")]
[TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")]
- public IEarlyStoppingCriterionFactory EarlyStoppingRule;
+ internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory;
+
+ ///
+ /// The underlying state of <see cref="EarlyStoppingRule"/> and <see cref="EarlyStoppingRuleFactory"/>.
+ ///
+ private EarlyStoppingRuleBase _earlyStoppingRuleBase;
+
+ ///
+ /// Early stopping rule used to terminate training process once meeting a specified criterion. Possible choices are
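+ /// <see cref="EarlyStoppingRuleBase"/>'s implementations such as <see cref="TolerantEarlyStoppingRule"/> and <see cref="GeneralityLossRule"/>.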
+ ///
+ public EarlyStoppingRuleBase EarlyStoppingRule
+ {
+ get { return _earlyStoppingRuleBase; } // Last rule assigned through this property; null when none was assigned.
+ set
+ {
+ _earlyStoppingRuleBase = value; // Keep the rule object so the getter returns exactly what the caller assigned.
+ EarlyStoppingRuleFactory = value?.BuildFactory(); // Null clears the factory (a null factory means "no early stopping") instead of throwing NullReferenceException.
+ }
+ }
///
/// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3).
diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
index 6db1145e7c..9279772227 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs
@@ -156,8 +156,12 @@ private protected override void CheckOptions(IChannel ch)
Dataset.DatasetSkeleton.LabelGainMap = gains;
}
- ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
- "earlyStoppingMetrics should be 1 or 3.");
+ bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
+ FastTreeTrainerOptions.EnablePruning;
+
+ if (doEarlyStop)
+ ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3,
+ nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3.");
base.CheckOptions(ch);
}
diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
index ac7ca46fbc..e02d5ec257 100644
--- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs
@@ -105,8 +105,12 @@ private protected override void CheckOptions(IChannel ch)
base.CheckOptions(ch);
- ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
- "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
+ bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
+ FastTreeTrainerOptions.EnablePruning;
+
+ if (doEarlyStop)
+ ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2,
+ nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
private static SchemaShape.Column MakeLabelColumn(string labelColumn)
diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
index 58510cec75..f68e67476f 100644
--- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
+++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs
@@ -112,12 +112,16 @@ private protected override void CheckOptions(IChannel ch)
// REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
// a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
// we just leave the existing regression checks, though with a warning.
-
if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
- ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
- "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
+ bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
+ FastTreeTrainerOptions.EnablePruning;
+
+ // Please do not remove it! See comment above.
+ if (doEarlyStop)
+ ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2,
+ nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm).");
}
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs
index e01a6b25c7..e0241379d6 100644
--- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs
+++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs
@@ -8,24 +8,30 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers.FastTree;
-[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")]
-[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")]
-[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")]
-[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")]
-[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")]
-
-[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))]
-[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))]
-[assembly: EntryPointModule(typeof(LPEarlyStoppingCriterion))]
-[assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))]
-[assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))]
+[assembly: LoadableClass(typeof(TolerantEarlyStoppingRule), typeof(TolerantEarlyStoppingRule.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")]
+[assembly: LoadableClass(typeof(GeneralityLossRule), typeof(GeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")]
+[assembly: LoadableClass(typeof(LowProgressRule), typeof(LowProgressRule.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")]
+[assembly: LoadableClass(typeof(GeneralityToProgressRatioRule), typeof(GeneralityToProgressRatioRule.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")]
+[assembly: LoadableClass(typeof(ConsecutiveGeneralityLossRule), typeof(ConsecutiveGeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")]
+
+[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule))]
+[assembly: EntryPointModule(typeof(GeneralityLossRule))]
+[assembly: EntryPointModule(typeof(LowProgressRule))]
+[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule))]
+[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule))]
+
+[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule.Options))]
+[assembly: EntryPointModule(typeof(GeneralityLossRule.Options))]
+[assembly: EntryPointModule(typeof(LowProgressRule.Options))]
+[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule.Options))]
+[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule.Options))]
namespace Microsoft.ML.Trainers.FastTree
{
internal delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter);
// These criteria will be used in FastTree and NeuralNets.
- public abstract class IEarlyStoppingCriterion
+ public abstract class EarlyStoppingRuleBase
{
///
/// Check if the learning should stop or not.
@@ -35,24 +41,36 @@ public abstract class IEarlyStoppingCriterion
/// True if the current result is the best ever.
/// If true, the learning should stop.
public abstract bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate);
+
+ ///
+ /// Having a private protected parameterless constructor prevents derivations of <see cref="EarlyStoppingRuleBase"/> from being
+ /// implemented by the public.
+ ///
+ private protected EarlyStoppingRuleBase() { }
+
+ ///
+ /// Create <see cref="IEarlyStoppingCriterionFactory"/> for supporting legacy infra built upon <see cref="IComponentFactory"/>.
+ ///
+ internal abstract IEarlyStoppingCriterionFactory BuildFactory();
}
+ [BestFriend]
[TlcModule.ComponentKind("EarlyStoppingCriterion")]
- public interface IEarlyStoppingCriterionFactory : IComponentFactory
+ internal interface IEarlyStoppingCriterionFactory : IComponentFactory
{
- new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter);
+ new EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter);
}
- public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion
- where TOptions : EarlyStoppingCriterion.OptionsBase
+ public abstract class EarlyStoppingRule : EarlyStoppingRuleBase
{
- public abstract class OptionsBase { }
-
private float _bestScore;
- protected readonly TOptions EarlyStoppingCriterionOptions;
- protected readonly bool LowerIsBetter;
- protected float BestScore {
+ ///
+ /// It's <see langword="true"/> if the selected stopping metric should be as low as possible, and <see langword="false"/> otherwise.
+ ///
+ private protected bool LowerIsBetter { get; }
+
+ private protected float BestScore {
get { return _bestScore; }
set
{
@@ -61,19 +79,26 @@ protected float BestScore {
}
}
- internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter)
+ private protected EarlyStoppingRule(bool lowerIsBetter) : base()
{
- EarlyStoppingCriterionOptions = options;
LowerIsBetter = lowerIsBetter;
_bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity;
}
+ ///
+ /// Lazy constructor. It doesn't initialize anything because at runtime, <see cref="EarlyStoppingRule(bool)"/> will be
+ /// called inside the training process to initialize needed fields.
+ ///
+ private protected EarlyStoppingRule() : base()
+ {
+ }
+
///
/// Check if the given score is the best ever. The best score will be stored at this._bestScore.
///
/// The latest score
/// True if the given score is the best ever.
- protected bool CheckBestScore(float score)
+ private protected bool CheckBestScore(float score)
{
bool isBestEver = ((score > BestScore) != LowerIsBetter);
if (isBestEver)
@@ -83,27 +108,50 @@ protected bool CheckBestScore(float score)
}
}
- public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion
+ public sealed class TolerantEarlyStoppingRule : EarlyStoppingRule
{
+ [BestFriend]
[TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")]
- public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
+ internal sealed class Options : IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")]
[TlcModule.Range(Min = 0.0f)]
public float Threshold = 0.01f;
- public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
+ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
- return new TolerantEarlyStoppingCriterion(this, lowerIsBetter);
+ return new TolerantEarlyStoppingRule(this, lowerIsBetter);
}
}
- public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter)
- : base(options, lowerIsBetter)
+ ///
+ /// The upper bound of the indicated metric on validation set.
+ ///
+ public float Threshold { get; }
+
+ ///
+ /// Create a rule which may terminate the training process if validation score exceeds <paramref name="threshold"/> compared with
+ /// the best historical validation score.
+ ///
+ /// The maximum difference allowed between the (current) validation score and its best historical value.
+ public TolerantEarlyStoppingRule(float threshold = 0.01f)
+ : base()
{
- Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
+ Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative.");
+ Threshold = threshold;
}
+ // Used in command line tool to construct loadable class.
+ private TolerantEarlyStoppingRule(Options options, bool lowerIsBetter)
+ : base(lowerIsBetter)
+ {
+ Contracts.CheckUserArg(options.Threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
+ Threshold = options.Threshold;
+ }
+
+ ///
+ /// See <see cref="EarlyStoppingRuleBase.CheckScore"/>.
+ ///
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
{
Contracts.Assert(validationScore >= 0);
@@ -111,19 +159,22 @@ public override bool CheckScore(float validationScore, float trainingScore, out
isBestCandidate = CheckBestScore(validationScore);
if (LowerIsBetter)
- return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold);
+ return (validationScore - BestScore > Threshold);
else
- return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold);
+ return (BestScore - validationScore > Threshold);
}
+
+ internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold };
}
// For the detail of the following rules, see the following paper.
// Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons."
// Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009.
- public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion
+ public abstract class MovingWindowRule : EarlyStoppingRule
{
- public class Options : OptionsBase
+ [BestFriend]
+ internal class Options
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
@@ -134,15 +185,39 @@ public class Options : OptionsBase
public int WindowSize = 5;
}
- protected Queue PastScores;
+ ///
+ /// A threshold in range [0, 1].
+ ///
+ public float Threshold { get; }
+
+ ///
+ /// The number of historical validation scores considered when determining if the training process should stop.
+ ///
+ public int WindowSize { get; }
+
+ // Hide this because it's a runtime value.
+ private protected Queue PastScores;
+
+ private protected MovingWindowRule(float threshold, int windowSize)
+ : base()
+ {
+ Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
+ Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive.");
+
+ Threshold = threshold;
+ WindowSize = windowSize;
+ PastScores = new Queue(windowSize);
+ }
- private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter)
- : base(args, lowerIsBetter)
+ private protected MovingWindowRule(bool lowerIsBetter, float threshold, int windowSize)
+ : base(lowerIsBetter)
{
- Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
- Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive.");
+ Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
+ Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive.");
- PastScores = new Queue(EarlyStoppingCriterionOptions.WindowSize);
+ Threshold = threshold;
+ WindowSize = windowSize;
+ PastScores = new Queue(windowSize);
}
///
@@ -177,7 +252,7 @@ private float GetRecentBest(IEnumerable recentScores)
return recentBestScore;
}
- protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage)
+ private protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage)
{
if (PastScores.Count >= windowSize)
{
@@ -200,28 +275,53 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe
///
/// Loss of Generality (GL).
///
- public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion
+ public sealed class GeneralityLossRule : EarlyStoppingRule
{
+ [BestFriend]
[TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL",
Desc = "Stop in case of loss of generality.")]
- public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
+ internal sealed class Options : IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")]
[TlcModule.Range(Min = 0.0f, Max = 1.0f)]
public float Threshold = 0.01f;
- public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
+ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
- return new GLEarlyStoppingCriterion(this, lowerIsBetter);
+ return new GeneralityLossRule(this, lowerIsBetter);
}
}
- public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter)
- : base(options, lowerIsBetter)
+ ///
+ /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation
+ /// score and its best historical value.
+ ///
+ public float Threshold { get; }
+
+ ///
+ /// Create a rule which may terminate the training process in case of loss of generality. The loss of generality means
+ /// the specified score on the validation set starts getting worse (for example, a loss starts increasing).
+ ///
+ /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation
+ /// score and its best historical value.
+ public GeneralityLossRule(float threshold = 0.01f) :
+ base()
+ {
+ Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
+ Threshold = threshold;
+ }
+
+ // Used in command line tool to construct loadable class.
+ private GeneralityLossRule(Options options, bool lowerIsBetter)
+ : base(lowerIsBetter)
{
- Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
+ Contracts.CheckUserArg(0 <= options.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
+ Threshold = options.Threshold;
}
+ ///
+ /// See <see cref="EarlyStoppingRuleBase.CheckScore"/>.
+ ///
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
{
Contracts.Assert(validationScore >= 0);
@@ -229,30 +329,52 @@ public override bool CheckScore(float validationScore, float trainingScore, out
isBestCandidate = CheckBestScore(validationScore);
if (LowerIsBetter)
- return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore);
+ return (validationScore > (1 + Threshold) * BestScore);
else
- return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore);
+ return (validationScore < (1 - Threshold) * BestScore);
}
- }
+
+ internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold };
+}
///
/// Low Progress (LP).
/// This rule fires when the improvements on the score stall.
///
- public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
+ public sealed class LowProgressRule : MovingWindowRule
{
+ [BestFriend]
[TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")]
- public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
+ internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory
{
- public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
+ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
- return new LPEarlyStoppingCriterion(this, lowerIsBetter);
+ return new LowProgressRule(this, lowerIsBetter);
}
}
- public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
- : base(options, lowerIsBetter) { }
+ ///
+ /// Create a rule which may terminate the training process when the improvements in terms of validation score is slow.
+ /// It will terminate the training process if the average of the recent validation scores
+ /// is worse than the best historical validation score.
+ ///
+ /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) averaged validation
+ /// score and its best historical value.
+ /// See <see cref="MovingWindowRule.WindowSize"/>.
+ public LowProgressRule(float threshold = 0.01f, int windowSize = 5)
+ : base(threshold, windowSize)
+ {
+ }
+ // Used in command line tool to construct loadable class.
+ private LowProgressRule(Options options, bool lowerIsBetter)
+ : base(lowerIsBetter, options.Threshold, options.WindowSize)
+ {
+ }
+
+ ///
+ /// See <see cref="EarlyStoppingRuleBase.CheckScore"/>.
+ ///
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
{
Contracts.Assert(validationScore >= 0);
@@ -262,35 +384,54 @@ public override bool CheckScore(float validationScore, float trainingScore, out
float recentBest;
float recentAverage;
- if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
+ if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage))
{
if (LowerIsBetter)
- return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest);
+ return (recentAverage <= (1 + Threshold) * recentBest);
else
- return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest);
+ return (recentAverage >= (1 - Threshold) * recentBest);
}
return false;
}
- }
+
+ internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize };
+}
///
/// Generality to Progress Ratio (PQ).
///
- public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion
+ public sealed class GeneralityToProgressRatioRule : MovingWindowRule
{
+ [BestFriend]
[TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")]
- public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory
+ internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory
{
- public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
+ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
- return new PQEarlyStoppingCriterion(this, lowerIsBetter);
+ return new GeneralityToProgressRatioRule(this, lowerIsBetter);
}
}
- public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter)
- : base(options, lowerIsBetter) { }
+ ///
+ /// Create a rule which may terminate the training process when generality-to-progress ratio exceeds <paramref name="threshold"/>.
+ ///
+ /// The maximum ratio gap (in percentage such as 0.01 for 1% and 0.5 for 50%).
+ /// See <see cref="MovingWindowRule.WindowSize"/>.
+ public GeneralityToProgressRatioRule(float threshold = 0.01f, int windowSize = 5)
+ : base(threshold, windowSize)
+ {
+ }
+
+ // Used in command line tool to construct loadable class.
+ private GeneralityToProgressRatioRule(Options options, bool lowerIsBetter)
+ : base(lowerIsBetter, options.Threshold, options.WindowSize)
+ {
+ }
+ ///
+ /// See <see cref="EarlyStoppingRuleBase.CheckScore"/>.
+ ///
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
{
Contracts.Assert(validationScore >= 0);
@@ -300,48 +441,72 @@ public override bool CheckScore(float validationScore, float trainingScore, out
float recentBest;
float recentAverage;
- if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage))
+ if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage))
{
if (LowerIsBetter)
- return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
+ return (validationScore * recentBest >= (1 + Threshold) * BestScore * recentAverage);
else
- return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage);
+ return (validationScore * recentBest <= (1 - Threshold) * BestScore * recentAverage);
}
return false;
}
+
+ [BestFriend]
+ internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize };
}
///
/// Consecutive Loss in Generality (UP).
///
- public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion
+ public sealed class ConsecutiveGeneralityLossRule : EarlyStoppingRule
{
+ [BestFriend]
[TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP",
Desc = "Stops in case of consecutive loss in generality.")]
- public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory
+ internal sealed class Options : IEarlyStoppingCriterionFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")]
[TlcModule.Range(Inf = 0)]
public int WindowSize = 5;
- public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter)
+ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter)
{
- return new UPEarlyStoppingCriterion(this, lowerIsBetter);
+ return new ConsecutiveGeneralityLossRule(this, lowerIsBetter);
}
}
+ ///
+ /// The number of historical validation scores considered when determining if the training process should stop.
+ ///
+ public int WindowSize { get; }
+
private int _count;
private float _prevScore;
- public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter)
- : base(options, lowerIsBetter)
+ ///
+ /// Creates a rule which terminates the training process if the validation score is not improved in <see cref="WindowSize"/> consecutive iterations.
+ ///
+ /// Number of training iterations allowed to have no improvement.
+ public ConsecutiveGeneralityLossRule(int windowSize = 5)
+ : base()
{
- Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive");
+ Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive");
+ WindowSize = windowSize;
+ _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity;
+ }
+ private ConsecutiveGeneralityLossRule(Options options, bool lowerIsBetter)
+ : base(lowerIsBetter)
+ {
+ Contracts.CheckUserArg(options.WindowSize > 0, nameof(options.WindowSize), "Must be positive");
+ WindowSize = options.WindowSize;
_prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity;
}
+ ///
+ /// See <see cref="EarlyStoppingRuleBase.CheckScore"/>.
+ ///
public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate)
{
Contracts.Assert(validationScore >= 0);
@@ -351,7 +516,10 @@ public override bool CheckScore(float validationScore, float trainingScore, out
_count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0;
_prevScore = validationScore;
- return (_count >= EarlyStoppingCriterionOptions.WindowSize);
+ return (_count >= WindowSize);
}
+
+ [BestFriend]
+ internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { WindowSize = WindowSize };
}
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
index c1ee89342f..214e6d4f2a 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs
@@ -10,18 +10,18 @@ namespace Microsoft.ML.RunTests
{
public sealed class TestEarlyStoppingCriteria
{
- private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
+ private EarlyStoppingRuleBase CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter)
{
var env = new MLContext()
.AddStandardComponents();
- var sub = new SubComponent(name, args);
+ var sub = new SubComponent(name, args);
return sub.CreateInstance(env, lowerIsBetter);
}
[Fact]
public void TolerantEarlyStoppingCriterionTest()
{
- IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false);
+ EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false);
bool isBestCandidate;
bool shouldStop;
@@ -46,7 +46,7 @@ public void TolerantEarlyStoppingCriterionTest()
[Fact]
public void GLEarlyStoppingCriterionTest()
{
- IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false);
+ EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false);
bool isBestCandidate;
bool shouldStop;
@@ -71,7 +71,7 @@ public void GLEarlyStoppingCriterionTest()
[Fact]
public void LPEarlyStoppingCriterionTest()
{
- IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false);
+ EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false);
bool isBestCandidate;
bool shouldStop;
@@ -107,7 +107,7 @@ public void LPEarlyStoppingCriterionTest()
[Fact]
public void PQEarlyStoppingCriterionTest()
{
- IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false);
+ EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false);
bool isBestCandidate;
bool shouldStop;
@@ -144,7 +144,7 @@ public void PQEarlyStoppingCriterionTest()
public void UPEarlyStoppingCriterionTest()
{
const int windowSize = 8;
- IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("up", "w=8", false);
+ EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("up", "w=8", false);
bool isBestCandidate;
bool shouldStop;
diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs
index cb0cacfc7e..9953cd5fc7 100644
--- a/test/Microsoft.ML.Functional.Tests/Validation.cs
+++ b/test/Microsoft.ML.Functional.Tests/Validation.cs
@@ -84,7 +84,7 @@ public void TrainWithValidationSet()
var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
NumberOfTrees = 2,
EarlyStoppingMetric = EarlyStoppingMetric.L2Norm,
- EarlyStoppingRule = new GLEarlyStoppingCriterion.Options()
+ EarlyStoppingRule = new GeneralityLossRule()
})
.Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData);