diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index e587dcb4c3..e994629ba1 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch) if (FastTreeTrainerOptions.EnablePruning && !HasValidSet) throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid)."); - if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet) + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null; + if (doEarlyStop && !HasValidSet) throw ch.Except("Cannot perform early stopping without a validation set (valid)."); if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet)) @@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch) return new BestStepRegressionGradientWrapper(); } - private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration) + private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration) { - if (FastTreeTrainerOptions.EarlyStoppingRule == null) + if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null) return false; ch.AssertValue(ValidTest); @@ -128,13 +129,16 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter var trainingResult = TrainTest.ComputeTests().First(); ch.Assert(trainingResult.FinalValue >= 0); - // Create early stopping rule. + // Create early stopping rule if it's null. if (earlyStoppingRule == null) { - earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter); - ch.Assert(earlyStoppingRule != null); + if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null) + earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter); } + // Early stopping rule cannot be null! + ch.Assert(earlyStoppingRule != null); + bool isBestCandidate; bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue, (float)trainingResult.FinalValue, out isBestCandidate); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 2814ca859a..6162ca06f8 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch) } } - private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration) + private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration) { bestIteration = Ensemble.NumTrees; return false; @@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch) #endif #endif - IEarlyStoppingCriterion earlyStoppingRule = null; + EarlyStoppingRuleBase earlyStoppingRule = null; int bestIteration = 0; int emptyTrees = 0; using (var pch = Host.StartProgressChannel("FastTree training")) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 6d2f2efded..767c706ece 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -621,9 +621,29 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc /// /// Early stopping rule. (Validation set (/valid) is required). /// - [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "")] + [BestFriend] + [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "")] [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")] - public IEarlyStoppingCriterionFactory EarlyStoppingRule; + internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory; + + /// + /// The underlying state of and . + /// + private EarlyStoppingRuleBase _earlyStoppingRuleBase; + + /// + /// Early stopping rule used to terminate training process once meeting a specified criterion. Possible choices are + /// 's implementations such as and . + /// + public EarlyStoppingRuleBase EarlyStoppingRule + { + get { return _earlyStoppingRuleBase; } + set + { + _earlyStoppingRuleBase = value; + EarlyStoppingRuleFactory = _earlyStoppingRuleBase.BuildFactory(); + } + } /// /// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3). diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 6db1145e7c..9279772227 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -156,8 +156,12 @@ private protected override void CheckOptions(IChannel ch) Dataset.DatasetSkeleton.LabelGainMap = gains; } - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 3."); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EnablePruning; + + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3."); base.CheckOptions(ch); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index ac7ca46fbc..e02d5ec257 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -105,8 +105,12 @@ private protected override void CheckOptions(IChannel ch) base.CheckOptions(ch); - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EnablePruning; + + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 58510cec75..f68e67476f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -112,12 +112,16 @@ private protected override void CheckOptions(IChannel ch) // REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now // we just leave the existing regression checks, though with a warning. - if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0) ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution."); - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EnablePruning; + + // Please do not remove it! See comment above. + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm)."); } private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs index e01a6b25c7..e0241379d6 100644 --- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs @@ -8,24 +8,30 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Trainers.FastTree; -[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] -[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] -[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] -[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] -[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] - -[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(LPEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))] +[assembly: LoadableClass(typeof(TolerantEarlyStoppingRule), typeof(TolerantEarlyStoppingRule.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] +[assembly: LoadableClass(typeof(GeneralityLossRule), typeof(GeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] +[assembly: LoadableClass(typeof(LowProgressRule), typeof(LowProgressRule.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] +[assembly: LoadableClass(typeof(GeneralityToProgressRatioRule), typeof(GeneralityToProgressRatioRule.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] +[assembly: LoadableClass(typeof(ConsecutiveGeneralityLossRule), typeof(ConsecutiveGeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] + +[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule))] +[assembly: EntryPointModule(typeof(GeneralityLossRule))] +[assembly: EntryPointModule(typeof(LowProgressRule))] +[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule))] +[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule))] + +[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule.Options))] +[assembly: EntryPointModule(typeof(GeneralityLossRule.Options))] +[assembly: EntryPointModule(typeof(LowProgressRule.Options))] +[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule.Options))] +[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule.Options))] namespace Microsoft.ML.Trainers.FastTree { internal delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter); // These criteria will be used in FastTree and NeuralNets. - public abstract class IEarlyStoppingCriterion + public abstract class EarlyStoppingRuleBase { /// /// Check if the learning should stop or not. @@ -35,24 +41,36 @@ public abstract class IEarlyStoppingCriterion /// True if the current result is the best ever. /// If true, the learning should stop. public abstract bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate); + + /// + /// Having constructor without parameter prevents derivations of from being + /// implemented by the public. + /// + private protected EarlyStoppingRuleBase() { } + + /// + /// Create for supporting legacy infra built upon . + /// + internal abstract IEarlyStoppingCriterionFactory BuildFactory(); } + [BestFriend] [TlcModule.ComponentKind("EarlyStoppingCriterion")] - public interface IEarlyStoppingCriterionFactory : IComponentFactory + internal interface IEarlyStoppingCriterionFactory : IComponentFactory { - new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter); + new EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter); } - public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion - where TOptions : EarlyStoppingCriterion.OptionsBase + public abstract class EarlyStoppingRule : EarlyStoppingRuleBase { - public abstract class OptionsBase { } - private float _bestScore; - protected readonly TOptions EarlyStoppingCriterionOptions; - protected readonly bool LowerIsBetter; - protected float BestScore { + /// + /// It's if the selected stopping metric should be as low as possible, and otherwise. + /// + private protected bool LowerIsBetter { get; } + + private protected float BestScore { get { return _bestScore; } set { @@ -61,19 +79,26 @@ protected float BestScore { } } - internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter) + private protected EarlyStoppingRule(bool lowerIsBetter) : base() { - EarlyStoppingCriterionOptions = options; LowerIsBetter = lowerIsBetter; _bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } + /// + /// Lazy constructor. It doesn't initialize anything because in runtime, will be + /// called inside the training process to initialize needed fields. + /// + private protected EarlyStoppingRule() : base() + { + } + /// /// Check if the given score is the best ever. The best score will be stored at this._bestScore. /// /// The latest score /// True if the given score is the best ever. - protected bool CheckBestScore(float score) + private protected bool CheckBestScore(float score) { bool isBestEver = ((score > BestScore) != LowerIsBetter); if (isBestEver) @@ -83,27 +108,50 @@ protected bool CheckBestScore(float score) } } - public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class TolerantEarlyStoppingRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")] [TlcModule.Range(Min = 0.0f)] public float Threshold = 0.01f; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new TolerantEarlyStoppingCriterion(this, lowerIsBetter); + return new TolerantEarlyStoppingRule(this, lowerIsBetter); } } - public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + /// + /// The upper bound of the indicated metric on validation set. + /// + public float Threshold { get; } + + /// + /// Create a rule which may terminate the training process if validation score exceeds compared with + /// the best historical validation score. + /// + /// The maximum difference allowed between the (current) validation score and its best historical value. + public TolerantEarlyStoppingRule(float threshold = 0.01f) + : base() { - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); + Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative."); + Threshold = threshold; } + // Used in command line tool to construct lodable class. + private TolerantEarlyStoppingRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) + { + Contracts.CheckUserArg(options.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); + Threshold = options.Threshold; + } + + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -111,19 +159,22 @@ public override bool CheckScore(float validationScore, float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold); + return (validationScore - BestScore > Threshold); else - return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold); + return (BestScore - validationScore > Threshold); } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold }; } // For the detail of the following rules, see the following paper. // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons." // Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009. - public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion + public abstract class MovingWindowRule : EarlyStoppingRule { - public class Options : OptionsBase + [BestFriend] + internal class Options { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -134,15 +185,39 @@ public class Options : OptionsBase public int WindowSize = 5; } - protected Queue PastScores; + /// + /// A threshold in range [0, 1]. + /// + public float Threshold { get; } + + /// + /// The number of historical validation scores considered when determining if the training process should stop. + /// + public int WindowSize { get; } + + // Hide this because it's a runtime value. + private protected Queue PastScores; + + private protected MovingWindowRule(float threshold, int windowSize) + : base() + { + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive."); + + Threshold = threshold; + WindowSize = windowSize; + PastScores = new Queue(windowSize); + } - private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter) - : base(args, lowerIsBetter) + private protected MovingWindowRule(bool lowerIsBetter, float threshold, int windowSize) + : base(lowerIsBetter) { - Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive."); + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive."); - PastScores = new Queue(EarlyStoppingCriterionOptions.WindowSize); + Threshold = threshold; + WindowSize = windowSize; + PastScores = new Queue(windowSize); } /// @@ -177,7 +252,7 @@ private float GetRecentBest(IEnumerable recentScores) return recentBestScore; } - protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage) + private protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage) { if (PastScores.Count >= windowSize) { @@ -200,28 +275,53 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe /// /// Loss of Generality (GL). /// - public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class GeneralityLossRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL", Desc = "Stop in case of loss of generality.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] public float Threshold = 0.01f; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new GLEarlyStoppingCriterion(this, lowerIsBetter); + return new GeneralityLossRule(this, lowerIsBetter); } } - public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + /// + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation + /// score and its best historical value. + /// + public float Threshold { get; } + + /// + /// Create a rule which may terminate the training process in case of loss of generality. The loss of generality means + /// the specified score on validation start increaseing. + /// + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation + /// score and its best historical value. + public GeneralityLossRule(float threshold = 0.01f) : + base() + { + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Threshold = threshold; + } + + // Used in command line tool to construct lodable class. + private GeneralityLossRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) { - Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(0 <= options.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); + Threshold = options.Threshold; } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -229,30 +329,52 @@ public override bool CheckScore(float validationScore, float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore); + return (validationScore > (1 + Threshold) * BestScore); else - return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore); + return (validationScore < (1 - Threshold) * BestScore); } - } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold }; +} /// /// Low Progress (LP). /// This rule fires when the improvements on the score stall. /// - public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion + public sealed class LowProgressRule : MovingWindowRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")] - public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory + internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory { - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new LPEarlyStoppingCriterion(this, lowerIsBetter); + return new LowProgressRule(this, lowerIsBetter); } } - public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) { } + /// + /// Create a rule which may terminate the training process when the improvements in terms of validation score is slow. + /// It will terminate the training process if the average of the recent validation scores + /// is worse than the best historical validation score. + /// + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) averaged validation + /// score and its best historical value. + /// See . + public LowProgressRule(float threshold = 0.01f, int windowSize = 5) + : base(threshold, windowSize) + { + } + // Used in command line tool to construct lodable class. + private LowProgressRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter, options.Threshold, options.WindowSize) + { + } + + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -262,35 +384,54 @@ public override bool CheckScore(float validationScore, float trainingScore, out float recentBest; float recentAverage; - if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest); + return (recentAverage <= (1 + Threshold) * recentBest); else - return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest); + return (recentAverage >= (1 - Threshold) * recentBest); } return false; } - } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize }; +} /// /// Generality to Progress Ratio (PQ). /// - public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion + public sealed class GeneralityToProgressRatioRule : MovingWindowRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")] - public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory + internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory { - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new PQEarlyStoppingCriterion(this, lowerIsBetter); + return new GeneralityToProgressRatioRule(this, lowerIsBetter); } } - public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) { } + /// + /// Create a rule which may terminate the training process when generality-to-progress ratio exceeds . + /// + /// The maximum ratio gap (in percentage such as 0.01 for 1% and 0.5 for 50%). + /// See . + public GeneralityToProgressRatioRule(float threshold = 0.01f, int windowSize = 5) + : base(threshold, windowSize) + { + } + + // Used in command line tool to construct lodable class. + private GeneralityToProgressRatioRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter, options.Threshold, options.WindowSize) + { + } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -300,48 +441,72 @@ public override bool CheckScore(float validationScore, float trainingScore, out float recentBest; float recentAverage; - if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest >= (1 + Threshold) * BestScore * recentAverage); else - return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest <= (1 - Threshold) * BestScore * recentAverage); } return false; } + + [BestFriend] + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize }; } /// /// Consecutive Loss in Generality (UP). /// - public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class ConsecutiveGeneralityLossRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP", Desc = "Stops in case of consecutive loss in generality.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] [TlcModule.Range(Inf = 0)] public int WindowSize = 5; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new UPEarlyStoppingCriterion(this, lowerIsBetter); + return new ConsecutiveGeneralityLossRule(this, lowerIsBetter); } } + /// + /// The number of historical validation scores considered when determining if the training process should stop. + /// + public int WindowSize { get; } + private int _count; private float _prevScore; - public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + /// + /// Creates a rule which terminates the training process if the validation score is not improved in consecutive iterations. + /// + /// Number of training iterations allowed to have no improvement. + public ConsecutiveGeneralityLossRule(int windowSize = 5) + : base() { - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive"); + WindowSize = windowSize; + _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; + } + private ConsecutiveGeneralityLossRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) + { + Contracts.CheckUserArg(options.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); + WindowSize = options.WindowSize; _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -351,7 +516,10 @@ public override bool CheckScore(float validationScore, float trainingScore, out _count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0; _prevScore = validationScore; - return (_count >= EarlyStoppingCriterionOptions.WindowSize); + return (_count >= WindowSize); } + + [BestFriend] + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { WindowSize = WindowSize }; } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs index c1ee89342f..214e6d4f2a 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs @@ -10,18 +10,18 @@ namespace Microsoft.ML.RunTests { public sealed class TestEarlyStoppingCriteria { - private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter) + private EarlyStoppingRuleBase CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter) { var env = new MLContext() .AddStandardComponents(); - var sub = new SubComponent(name, args); + var sub = new SubComponent(name, args); return sub.CreateInstance(env, lowerIsBetter); } [Fact] public void TolerantEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false); bool isBestCandidate; bool shouldStop; @@ -46,7 +46,7 @@ public void TolerantEarlyStoppingCriterionTest() [Fact] public void GLEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false); bool isBestCandidate; bool shouldStop; @@ -71,7 +71,7 @@ public void GLEarlyStoppingCriterionTest() [Fact] public void LPEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false); bool isBestCandidate; bool shouldStop; @@ -107,7 +107,7 @@ public void LPEarlyStoppingCriterionTest() [Fact] public void PQEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false); bool isBestCandidate; bool shouldStop; @@ -144,7 +144,7 @@ public void PQEarlyStoppingCriterionTest() public void UPEarlyStoppingCriterionTest() { const int windowSize = 8; - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("up", "w=8", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("up", "w=8", false); bool isBestCandidate; bool shouldStop; diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs index cb0cacfc7e..9953cd5fc7 100644 --- a/test/Microsoft.ML.Functional.Tests/Validation.cs +++ b/test/Microsoft.ML.Functional.Tests/Validation.cs @@ -84,7 +84,7 @@ public void TrainWithValidationSet() var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options { NumberOfTrees = 2, EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, - EarlyStoppingRule = new GLEarlyStoppingCriterion.Options() + EarlyStoppingRule = new GeneralityLossRule() }) .Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData);