From 7bdff400eba97c14eac8954c9f72598711ad8b41 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 5 Mar 2019 10:32:10 -0800 Subject: [PATCH 1/5] Make some changes. --- .../Training/EarlyStoppingCriteria.cs | 141 +++++++++++------- .../UnitTests/TestEntryPoints.cs | 2 +- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs index e01a6b25c7..5cb628dbed 100644 --- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs @@ -43,14 +43,10 @@ public interface IEarlyStoppingCriterionFactory : IComponentFactory : IEarlyStoppingCriterion - where TOptions : EarlyStoppingCriterion.OptionsBase + public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion { - public abstract class OptionsBase { } - private float _bestScore; - protected readonly TOptions EarlyStoppingCriterionOptions; protected readonly bool LowerIsBetter; protected float BestScore { get { return _bestScore; } @@ -61,9 +57,8 @@ protected float BestScore { } } - internal EarlyStoppingCriterion(TOptions options, bool lowerIsBetter) + internal EarlyStoppingCriterion(bool lowerIsBetter) { - EarlyStoppingCriterionOptions = options; LowerIsBetter = lowerIsBetter; _bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } @@ -83,10 +78,10 @@ protected bool CheckBestScore(float score) } } - public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + public sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")] [TlcModule.Range(Min = 0.0f)] @@ -94,14 +89,23 @@ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new TolerantEarlyStoppingCriterion(this, lowerIsBetter); + return new TolerantEarlyStoppingCriterion(Threshold, lowerIsBetter); } } - public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + public float Threshold { get; } + + public TolerantEarlyStoppingCriterion(float threshold, bool lowerIsBetter = true) + : base(lowerIsBetter) + { + Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative."); + Threshold = threshold; + } + + [BestFriend] + internal TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + : this(options.Threshold, lowerIsBetter) { - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); } public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) @@ -111,9 +115,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore - BestScore > EarlyStoppingCriterionOptions.Threshold); + return (validationScore - BestScore > Threshold); else - return (BestScore - validationScore > EarlyStoppingCriterionOptions.Threshold); + return (BestScore - validationScore > Threshold); } } @@ -121,9 +125,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons." // Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009. - public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion + public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion { - public class Options : OptionsBase + public class Options { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -134,15 +138,20 @@ public class Options : OptionsBase public int WindowSize = 5; } + public float Threshold { get; } + public int WindowSize { get; } + protected Queue PastScores; - private protected MovingWindowEarlyStoppingCriterion(Options args, bool lowerIsBetter) - : base(args, lowerIsBetter) + private protected MovingWindowEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + : base(lowerIsBetter) { - Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && args.Threshold <= 1, nameof(args.Threshold), "Must be in range [0,1]."); - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(args.WindowSize), "Must be positive."); + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive."); - PastScores = new Queue(EarlyStoppingCriterionOptions.WindowSize); + Threshold = threshold; + WindowSize = windowSize; + PastScores = new Queue(windowSize); } /// @@ -200,11 +209,11 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe /// /// Loss of Generality (GL). /// - public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL", Desc = "Stop in case of loss of generality.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + public sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -212,14 +221,23 @@ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new GLEarlyStoppingCriterion(this, lowerIsBetter); + return new GLEarlyStoppingCriterion(lowerIsBetter, Threshold); } } - public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + public float Threshold { get; } + + public GLEarlyStoppingCriterion(bool lowerIsBetter = true, float threshold = 0.01f) : + base(lowerIsBetter) + { + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Threshold = threshold; + } + + [BestFriend] + internal GLEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + : this(lowerIsBetter, options.Threshold) { - Contracts.CheckUserArg(0 <= EarlyStoppingCriterionOptions.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); } public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) @@ -229,9 +247,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out isBestCandidate = CheckBestScore(validationScore); if (LowerIsBetter) - return (validationScore > (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore); + return (validationScore > (1 + Threshold) * BestScore); else - return (validationScore < (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore); + return (validationScore < (1 - Threshold) * BestScore); } } @@ -246,12 +264,20 @@ public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio { public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new LPEarlyStoppingCriterion(this, lowerIsBetter); + return new LPEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize); } } - public LPEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) { } + public LPEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + : base(lowerIsBetter, threshold, windowSize) + { + } + + [BestFriend] + internal LPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + : this(lowerIsBetter, options.Threshold, options.WindowSize) + { + } public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { @@ -262,12 +288,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out float recentBest; float recentAverage; - if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (recentAverage <= (1 + EarlyStoppingCriterionOptions.Threshold) * recentBest); + return (recentAverage <= (1 + Threshold) * recentBest); else - return (recentAverage >= (1 - EarlyStoppingCriterionOptions.Threshold) * recentBest); + return (recentAverage >= (1 - Threshold) * recentBest); } return false; @@ -284,12 +310,20 @@ public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterio { public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new PQEarlyStoppingCriterion(this, lowerIsBetter); + return new PQEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize); } } - public PQEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) { } + public PQEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + : base(lowerIsBetter, threshold, windowSize) + { + } + + [BestFriend] + internal PQEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + : this(lowerIsBetter, options.Threshold, options.WindowSize) + { + } public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { @@ -300,12 +334,12 @@ public override bool CheckScore(float validationScore, float trainingScore, out float recentBest; float recentAverage; - if (CheckRecentScores(trainingScore, EarlyStoppingCriterionOptions.WindowSize, out recentBest, out recentAverage)) + if (CheckRecentScores(trainingScore, WindowSize, out recentBest, out recentAverage)) { if (LowerIsBetter) - return (validationScore * recentBest >= (1 + EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest >= (1 + Threshold) * BestScore * recentAverage); else - return (validationScore * recentBest <= (1 - EarlyStoppingCriterionOptions.Threshold) * BestScore * recentAverage); + return (validationScore * recentBest <= (1 - Threshold) * BestScore * recentAverage); } return false; @@ -315,11 +349,11 @@ public override bool CheckScore(float validationScore, float trainingScore, out /// /// Consecutive Loss in Generality (UP). /// - public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion { [TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP", Desc = "Stops in case of consecutive loss in generality.")] - public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory + public sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] [TlcModule.Range(Inf = 0)] @@ -327,21 +361,28 @@ public sealed class Options : OptionsBase, IEarlyStoppingCriterionFactory public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new UPEarlyStoppingCriterion(this, lowerIsBetter); + return new UPEarlyStoppingCriterion(lowerIsBetter, WindowSize); } } + public int WindowSize { get; } private int _count; private float _prevScore; - public UPEarlyStoppingCriterion(Options options, bool lowerIsBetter) - : base(options, lowerIsBetter) + public UPEarlyStoppingCriterion(bool lowerIsBetter, int windowSize = 5) + : base(lowerIsBetter) { - Contracts.CheckUserArg(EarlyStoppingCriterionOptions.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); - + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive"); + WindowSize = windowSize; _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } + [BestFriend] + internal UPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + : this(lowerIsBetter, options.WindowSize) + { + } + public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -351,7 +392,7 @@ public override bool CheckScore(float validationScore, float trainingScore, out _count = ((validationScore < _prevScore) != LowerIsBetter) ? _count + 1 : 0; _prevScore = validationScore; - return (_count >= EarlyStoppingCriterionOptions.WindowSize); + return (_count >= WindowSize); } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 0046e4a469..f06690fa2c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -4974,7 +4974,7 @@ public void TestCrossValidationMacroWithNonDefaultNames() 'MaximumNumberOfLineSearchSteps': 0, 'MinimumStepSize': 0.0, 'OptimizationAlgorithm': 'GradientDescent', - 'EarlyStoppingRule': null, + 'EarlyStoppingRuleFactory': null, 'EarlyStoppingMetrics': 1, 'EnablePruning': false, 'UseTolerantPruning': false, From 8f4422f788500031047e3de8d2c7716247fee9a7 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 5 Mar 2019 10:32:22 -0800 Subject: [PATCH 2/5] First version of new early stopping rule. Generate missing entry points --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 20 +- src/Microsoft.ML.FastTree/FastTree.cs | 4 +- .../FastTreeArguments.cs | 5 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 9 +- .../FastTreeRegression.cs | 9 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 11 +- .../Training/EarlyStoppingCriteria.cs | 187 ++++++++++++------ .../Common/EntryPoints/core_manifest.json | 16 +- .../UnitTests/TestEarlyStoppingCriteria.cs | 14 +- .../Validation.cs | 2 +- 10 files changed, 190 insertions(+), 87 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index e587dcb4c3..56f2285a70 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch) if (FastTreeTrainerOptions.EnablePruning && !HasValidSet) throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid)."); - if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet) + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || FastTreeTrainerOptions.EarlyStoppingRule != null; + if (doEarlyStop && !HasValidSet) throw ch.Except("Cannot perform early stopping without a validation set (valid)."); if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet)) @@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch) return new BestStepRegressionGradientWrapper(); } - private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration) + private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration) { - if (FastTreeTrainerOptions.EarlyStoppingRule == null) + if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null && FastTreeTrainerOptions.EarlyStoppingRule == null) return false; ch.AssertValue(ValidTest); @@ -128,13 +129,20 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter var trainingResult = TrainTest.ComputeTests().First(); ch.Assert(trainingResult.FinalValue >= 0); - // Create early stopping rule. + // Create early stopping rule if it's null. if (earlyStoppingRule == null) { - earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter); - ch.Assert(earlyStoppingRule != null); + // There are two possible sources of stopping rules. One is the classical IComponentFactory and + // the other one is the rule passed in directly by user. + if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null) + earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter); + else if (FastTreeTrainerOptions.EarlyStoppingRule != null) + earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule; } + // Early stopping rule cannot be null! + ch.Assert(earlyStoppingRule != null); + bool isBestCandidate; bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue, (float)trainingResult.FinalValue, out isBestCandidate); diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 2814ca859a..6162ca06f8 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch) } } - private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration) + private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration) { bestIteration = Ensemble.NumTrees; return false; @@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch) #endif #endif - IEarlyStoppingCriterion earlyStoppingRule = null; + EarlyStoppingRuleBase earlyStoppingRule = null; int bestIteration = 0; int emptyTrees = 0; using (var pch = Host.StartProgressChannel("FastTree training")) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 6d2f2efded..1953ef121b 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -621,9 +621,12 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc /// /// Early stopping rule. (Validation set (/valid) is required). /// + [BestFriend] [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "")] [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")] - public IEarlyStoppingCriterionFactory EarlyStoppingRule; + internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory; + + public EarlyStoppingRuleBase EarlyStoppingRule; /// /// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3). diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index 6db1145e7c..efb0feabc8 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -156,8 +156,13 @@ private protected override void CheckOptions(IChannel ch) Dataset.DatasetSkeleton.LabelGainMap = gains; } - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 3."); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EarlyStoppingRule != null || + FastTreeTrainerOptions.EnablePruning; + + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3."); base.CheckOptions(ch); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index ac7ca46fbc..7a4d2f121c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -105,8 +105,13 @@ private protected override void CheckOptions(IChannel ch) base.CheckOptions(ch); - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EarlyStoppingRule != null || + FastTreeTrainerOptions.EnablePruning; + + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); } private static SchemaShape.Column MakeLabelColumn(string labelColumn) diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 58510cec75..fe6c0d0e72 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -112,12 +112,17 @@ private protected override void CheckOptions(IChannel ch) // REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now // we just leave the existing regression checks, though with a warning. - if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0) ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution."); - ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), - "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)"); + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || + FastTreeTrainerOptions.EarlyStoppingRule != null || + FastTreeTrainerOptions.EnablePruning; + + // Please do not remove it! See comment above. + if (doEarlyStop) + ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2, + nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm)."); } private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs index 5cb628dbed..6b96bd95c7 100644 --- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs @@ -8,24 +8,30 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Trainers.FastTree; -[assembly: LoadableClass(typeof(TolerantEarlyStoppingCriterion), typeof(TolerantEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] -[assembly: LoadableClass(typeof(GLEarlyStoppingCriterion), typeof(GLEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] -[assembly: LoadableClass(typeof(LPEarlyStoppingCriterion), typeof(LPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] -[assembly: LoadableClass(typeof(PQEarlyStoppingCriterion), typeof(PQEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] -[assembly: LoadableClass(typeof(UPEarlyStoppingCriterion), typeof(UPEarlyStoppingCriterion.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] - -[assembly: EntryPointModule(typeof(TolerantEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(GLEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(LPEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(PQEarlyStoppingCriterion))] -[assembly: EntryPointModule(typeof(UPEarlyStoppingCriterion))] +[assembly: LoadableClass(typeof(TolerantEarlyStoppingRule), typeof(TolerantEarlyStoppingRule.Options), typeof(SignatureEarlyStoppingCriterion), "Tolerant (TR)", "tr")] +[assembly: LoadableClass(typeof(GeneralityLossRule), typeof(GeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Loss of Generality (GL)", "gl")] +[assembly: LoadableClass(typeof(LowProgressRule), typeof(LowProgressRule.Options), typeof(SignatureEarlyStoppingCriterion), "Low Progress (LP)", "lp")] +[assembly: LoadableClass(typeof(GeneralityToProgressRatioRule), typeof(GeneralityToProgressRatioRule.Options), typeof(SignatureEarlyStoppingCriterion), "Generality to Progress Ratio (PQ)", "pq")] +[assembly: LoadableClass(typeof(ConsecutiveGeneralityLossRule), typeof(ConsecutiveGeneralityLossRule.Options), typeof(SignatureEarlyStoppingCriterion), "Consecutive Loss in Generality (UP)", "up")] + +[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule))] +[assembly: EntryPointModule(typeof(GeneralityLossRule))] +[assembly: EntryPointModule(typeof(LowProgressRule))] +[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule))] +[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule))] + +[assembly: EntryPointModule(typeof(TolerantEarlyStoppingRule.Options))] +[assembly: EntryPointModule(typeof(GeneralityLossRule.Options))] +[assembly: EntryPointModule(typeof(LowProgressRule.Options))] +[assembly: EntryPointModule(typeof(GeneralityToProgressRatioRule.Options))] +[assembly: EntryPointModule(typeof(ConsecutiveGeneralityLossRule.Options))] namespace Microsoft.ML.Trainers.FastTree { internal delegate void SignatureEarlyStoppingCriterion(bool lowerIsBetter); // These criteria will be used in FastTree and NeuralNets. - public abstract class IEarlyStoppingCriterion + public abstract class EarlyStoppingRuleBase { /// /// Check if the learning should stop or not. @@ -37,18 +43,23 @@ public abstract class IEarlyStoppingCriterion public abstract bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate); } + [BestFriend] [TlcModule.ComponentKind("EarlyStoppingCriterion")] - public interface IEarlyStoppingCriterionFactory : IComponentFactory + internal interface IEarlyStoppingCriterionFactory : IComponentFactory { - new IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter); + new EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter); } - public abstract class EarlyStoppingCriterion : IEarlyStoppingCriterion + public abstract class EarlyStoppingRule : EarlyStoppingRuleBase { private float _bestScore; - protected readonly bool LowerIsBetter; - protected float BestScore { + /// + /// It's if the selected stopping metric should be as low as possible, and otherwise. + /// + public bool LowerIsBetter { get; } + + private protected float BestScore { get { return _bestScore; } set { @@ -57,7 +68,7 @@ protected float BestScore { } } - internal EarlyStoppingCriterion(bool lowerIsBetter) + private protected EarlyStoppingRule(bool lowerIsBetter) { LowerIsBetter = lowerIsBetter; _bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; @@ -68,7 +79,7 @@ internal EarlyStoppingCriterion(bool lowerIsBetter) /// /// The latest score /// True if the given score is the best ever. - protected bool CheckBestScore(float score) + private protected bool CheckBestScore(float score) { bool isBestEver = ((score > BestScore) != LowerIsBetter); if (isBestEver) @@ -78,24 +89,34 @@ protected bool CheckBestScore(float score) } } - public sealed class TolerantEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class TolerantEarlyStoppingRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Tolerant (TR)", Name = "TR", Desc = "Stop if validation score exceeds threshold value.")] - public sealed class Options : IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance threshold. (Non negative value)", ShortName = "th")] [TlcModule.Range(Min = 0.0f)] public float Threshold = 0.01f; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new TolerantEarlyStoppingCriterion(Threshold, lowerIsBetter); + return new TolerantEarlyStoppingRule(lowerIsBetter, Threshold); } } + /// + /// The upper bound of the indicated metric on validation set. + /// public float Threshold { get; } - public TolerantEarlyStoppingCriterion(float threshold, bool lowerIsBetter = true) + /// + /// Create a rule which may terminate the training process if validation score exceeds compared with + /// the best historical validation score. + /// + /// Its meaning is identical to . + /// The maximum difference allowed between the (current) validation score and its best historical value. + public TolerantEarlyStoppingRule(bool lowerIsBetter = true, float threshold = 0.01f) : base(lowerIsBetter) { Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative."); @@ -103,11 +124,15 @@ public TolerantEarlyStoppingCriterion(float threshold, bool lowerIsBetter = true } [BestFriend] - internal TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) - : this(options.Threshold, lowerIsBetter) + // Used in command line tool to construct lodable class. + internal TolerantEarlyStoppingRule(Options options, bool lowerIsBetter = true) + : this(lowerIsBetter, options.Threshold) { } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -125,9 +150,10 @@ public override bool CheckScore(float validationScore, float trainingScore, out // Lodwich, Aleksander, Yves Rangoni, and Thomas Breuel. "Evaluation of robustness and performance of early stopping rules with multi layer perceptrons." // Neural Networks, 2009. IJCNN 2009. International Joint Conference on. IEEE, 2009. - public abstract class MovingWindowEarlyStoppingCriterion : EarlyStoppingCriterion + public abstract class MovingWindowRule : EarlyStoppingRule { - public class Options + [BestFriend] + internal class Options { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] @@ -138,12 +164,20 @@ public class Options public int WindowSize = 5; } + /// + /// A threshold in range [0, 1]. + /// public float Threshold { get; } + + /// + /// The number of historical validation scores considered when determining if the training process should stop. + /// public int WindowSize { get; } - protected Queue PastScores; + // Hide this because it's a runtime value. + private protected Queue PastScores; - private protected MovingWindowEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + private protected MovingWindowRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) : base(lowerIsBetter) { Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); @@ -186,7 +220,7 @@ private float GetRecentBest(IEnumerable recentScores) return recentBestScore; } - protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage) + private protected bool CheckRecentScores(float score, int windowSize, out float recentBest, out float recentAverage) { if (PastScores.Count >= windowSize) { @@ -209,25 +243,37 @@ protected bool CheckRecentScores(float score, int windowSize, out float recentBe /// /// Loss of Generality (GL). /// - public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class GeneralityLossRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Loss of Generality (GL)", Name = "GL", Desc = "Stop in case of loss of generality.")] - public sealed class Options : IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "Threshold in range [0,1].", ShortName = "th")] [TlcModule.Range(Min = 0.0f, Max = 1.0f)] public float Threshold = 0.01f; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new GLEarlyStoppingCriterion(lowerIsBetter, Threshold); + return new GeneralityLossRule(lowerIsBetter, Threshold); } } + /// + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation + /// score and its best historical value. + /// public float Threshold { get; } - public GLEarlyStoppingCriterion(bool lowerIsBetter = true, float threshold = 0.01f) : + /// + /// Create a rule which may terminate the training process in case of loss of generality. The loss of generality means + /// the specified score on validation start increaseing. + /// + /// Its meaning is identical to . + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation + /// score and its best historical value. + public GeneralityLossRule(bool lowerIsBetter = true, float threshold = 0.01f) : base(lowerIsBetter) { Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); @@ -235,11 +281,15 @@ public GLEarlyStoppingCriterion(bool lowerIsBetter = true, float threshold = 0.0 } [BestFriend] - internal GLEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + // Used in command line tool to construct lodable class. + internal GeneralityLossRule(Options options, bool lowerIsBetter = true) : this(lowerIsBetter, options.Threshold) { } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -257,28 +307,42 @@ public override bool CheckScore(float validationScore, float trainingScore, out /// Low Progress (LP). /// This rule fires when the improvements on the score stall. /// - public sealed class LPEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion + public sealed class LowProgressRule : MovingWindowRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Low Progress (LP)", Name = "LP", Desc = "Stops in case of low progress.")] - public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory + internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory { - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new LPEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize); + return new LowProgressRule(lowerIsBetter, Threshold, WindowSize); } } - public LPEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + /// + /// Create a rule which may terminate the training process when the improvements in terms of validation score is slow. + /// It will terminate the training process if the average of the recent validation scores + /// is worse than the best historical validation score. + /// + /// Its meaning is identical to . + /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) averaged validation + /// score and its best historical value. + /// See . + public LowProgressRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) : base(lowerIsBetter, threshold, windowSize) { } [BestFriend] - internal LPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + // Used in command line tool to construct lodable class. + internal LowProgressRule(Options options, bool lowerIsBetter = true) : this(lowerIsBetter, options.Threshold, options.WindowSize) { } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -303,28 +367,33 @@ public override bool CheckScore(float validationScore, float trainingScore, out /// /// Generality to Progress Ratio (PQ). /// - public sealed class PQEarlyStoppingCriterion : MovingWindowEarlyStoppingCriterion + public sealed class GeneralityToProgressRatioRule : MovingWindowRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Generality to Progress Ratio (PQ)", Name = "PQ", Desc = "Stops in case of generality to progress ration exceeds threshold.")] - public new sealed class Options : MovingWindowEarlyStoppingCriterion.Options, IEarlyStoppingCriterionFactory + internal new sealed class Options : MovingWindowRule.Options, IEarlyStoppingCriterionFactory { - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new PQEarlyStoppingCriterion(lowerIsBetter, Threshold, WindowSize); + return new GeneralityToProgressRatioRule(lowerIsBetter, Threshold, WindowSize); } } - public PQEarlyStoppingCriterion(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + public GeneralityToProgressRatioRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) : base(lowerIsBetter, threshold, windowSize) { } [BestFriend] - internal PQEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + // Used in command line tool to construct lodable class. + internal GeneralityToProgressRatioRule(Options options, bool lowerIsBetter = true) : this(lowerIsBetter, options.Threshold, options.WindowSize) { } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); @@ -349,27 +418,32 @@ public override bool CheckScore(float validationScore, float trainingScore, out /// /// Consecutive Loss in Generality (UP). /// - public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion + public sealed class ConsecutiveGeneralityLossRule : EarlyStoppingRule { + [BestFriend] [TlcModule.Component(FriendlyName = "Consecutive Loss in Generality (UP)", Name = "UP", Desc = "Stops in case of consecutive loss in generality.")] - public sealed class Options : IEarlyStoppingCriterionFactory + internal sealed class Options : IEarlyStoppingCriterionFactory { [Argument(ArgumentType.AtMostOnce, HelpText = "The window size.", ShortName = "w")] [TlcModule.Range(Inf = 0)] public int WindowSize = 5; - public IEarlyStoppingCriterion CreateComponent(IHostEnvironment env, bool lowerIsBetter) + public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new UPEarlyStoppingCriterion(lowerIsBetter, WindowSize); + return new ConsecutiveGeneralityLossRule(lowerIsBetter, WindowSize); } } + /// + /// The number of historical validation scores considered when determining if the training process should stop. + /// public int WindowSize { get; } + private int _count; private float _prevScore; - public UPEarlyStoppingCriterion(bool lowerIsBetter, int windowSize = 5) + public ConsecutiveGeneralityLossRule(bool lowerIsBetter, int windowSize = 5) : base(lowerIsBetter) { Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive"); @@ -378,11 +452,14 @@ public UPEarlyStoppingCriterion(bool lowerIsBetter, int windowSize = 5) } [BestFriend] - internal UPEarlyStoppingCriterion(Options options, bool lowerIsBetter = true) + internal ConsecutiveGeneralityLossRule(Options options, bool lowerIsBetter = true) : this(lowerIsBetter, options.WindowSize) { } + /// + /// See . + /// public override bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate) { Contracts.Assert(validationScore >= 0); diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 6caab60446..3569207c3b 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -6598,7 +6598,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -7587,7 +7587,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -8474,7 +8474,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -9370,7 +9370,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -25308,7 +25308,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -26279,7 +26279,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -27148,7 +27148,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -28026,7 +28026,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRule", + "Name": "EarlyStoppingRuleFactory", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs index c1ee89342f..214e6d4f2a 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEarlyStoppingCriteria.cs @@ -10,18 +10,18 @@ namespace Microsoft.ML.RunTests { public sealed class TestEarlyStoppingCriteria { - private IEarlyStoppingCriterion CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter) + private EarlyStoppingRuleBase CreateEarlyStoppingCriterion(string name, string args, bool lowerIsBetter) { var env = new MLContext() .AddStandardComponents(); - var sub = new SubComponent(name, args); + var sub = new SubComponent(name, args); return sub.CreateInstance(env, lowerIsBetter); } [Fact] public void TolerantEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("tr", "th=0.01", false); bool isBestCandidate; bool shouldStop; @@ -46,7 +46,7 @@ public void TolerantEarlyStoppingCriterionTest() [Fact] public void GLEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("gl", "th=0.01", false); bool isBestCandidate; bool shouldStop; @@ -71,7 +71,7 @@ public void GLEarlyStoppingCriterionTest() [Fact] public void LPEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("lp", "th=0.01 w=5", false); bool isBestCandidate; bool shouldStop; @@ -107,7 +107,7 @@ public void LPEarlyStoppingCriterionTest() [Fact] public void PQEarlyStoppingCriterionTest() { - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("pq", "th=0.01 w=5", false); bool isBestCandidate; bool shouldStop; @@ -144,7 +144,7 @@ public void PQEarlyStoppingCriterionTest() public void UPEarlyStoppingCriterionTest() { const int windowSize = 8; - IEarlyStoppingCriterion cr = CreateEarlyStoppingCriterion("up", "w=8", false); + EarlyStoppingRuleBase cr = CreateEarlyStoppingCriterion("up", "w=8", false); bool isBestCandidate; bool shouldStop; diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs index cb0cacfc7e..9953cd5fc7 100644 --- a/test/Microsoft.ML.Functional.Tests/Validation.cs +++ b/test/Microsoft.ML.Functional.Tests/Validation.cs @@ -84,7 +84,7 @@ public void TrainWithValidationSet() var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options { NumberOfTrees = 2, EarlyStoppingMetric = EarlyStoppingMetric.L2Norm, - EarlyStoppingRule = new GLEarlyStoppingCriterion.Options() + EarlyStoppingRule = new GeneralityLossRule() }) .Fit(trainData: preprocessedTrainData, validationData: preprocessedValidData); From 8a8e2f1faeccf06242ff697897cae08183021dfd Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 5 Mar 2019 12:54:38 -0800 Subject: [PATCH 3/5] Hide lowerIsBetter argument because it's determined by FastTree --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 8 +- .../FastTreeArguments.cs | 12 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 1 - .../FastTreeRegression.cs | 1 - src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 1 - .../Training/EarlyStoppingCriteria.cs | 118 ++++++++++++------ 6 files changed, 94 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 56f2285a70..e994629ba1 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -49,7 +49,7 @@ private protected override void CheckOptions(IChannel ch) if (FastTreeTrainerOptions.EnablePruning && !HasValidSet) throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid)."); - bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || FastTreeTrainerOptions.EarlyStoppingRule != null; + bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null; if (doEarlyStop && !HasValidSet) throw ch.Except("Cannot perform early stopping without a validation set (valid)."); @@ -116,7 +116,7 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch) private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration) { - if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null && FastTreeTrainerOptions.EarlyStoppingRule == null) + if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null) return false; ch.AssertValue(ValidTest); @@ -132,12 +132,8 @@ private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBas // Create early stopping rule if it's null. if (earlyStoppingRule == null) { - // There are two possible sources of stopping rules. One is the classical IComponentFactory and - // the other one is the rule passed in directly by user. if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null) earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter); - else if (FastTreeTrainerOptions.EarlyStoppingRule != null) - earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule; } // Early stopping rule cannot be null! diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 1953ef121b..00250ffb9c 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -626,7 +626,17 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")] internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory; - public EarlyStoppingRuleBase EarlyStoppingRule; + private EarlyStoppingRuleBase _earlyStoppingRuleBase; + + public EarlyStoppingRuleBase EarlyStoppingRule + { + get { return _earlyStoppingRuleBase; } + set + { + _earlyStoppingRuleBase = value; + EarlyStoppingRuleFactory = _earlyStoppingRuleBase.BuildFactory(); + } + } /// /// Early stopping metrics. (For regression, 1: L1, 2:L2; for ranking, 1:NDCG@1, 3:NDCG@3). diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index efb0feabc8..9279772227 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -157,7 +157,6 @@ private protected override void CheckOptions(IChannel ch) } bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || - FastTreeTrainerOptions.EarlyStoppingRule != null || FastTreeTrainerOptions.EnablePruning; if (doEarlyStop) diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 7a4d2f121c..e02d5ec257 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -106,7 +106,6 @@ private protected override void CheckOptions(IChannel ch) base.CheckOptions(ch); bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || - FastTreeTrainerOptions.EarlyStoppingRule != null || FastTreeTrainerOptions.EnablePruning; if (doEarlyStop) diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index fe6c0d0e72..f68e67476f 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -116,7 +116,6 @@ private protected override void CheckOptions(IChannel ch) ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution."); bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null || - FastTreeTrainerOptions.EarlyStoppingRule != null || FastTreeTrainerOptions.EnablePruning; // Please do not remove it! See comment above. diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs index 6b96bd95c7..ba53359831 100644 --- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs @@ -41,6 +41,11 @@ public abstract class EarlyStoppingRuleBase /// True if the current result is the best ever. /// If true, the learning should stop. public abstract bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate); + + /// + /// Create for supporting legacy infra built upon . + /// + internal abstract IEarlyStoppingCriterionFactory BuildFactory(); } [BestFriend] @@ -57,7 +62,7 @@ public abstract class EarlyStoppingRule : EarlyStoppingRuleBase /// /// It's if the selected stopping metric should be as low as possible, and otherwise. /// - public bool LowerIsBetter { get; } + private protected bool LowerIsBetter { get; } private protected float BestScore { get { return _bestScore; } @@ -74,6 +79,14 @@ private protected EarlyStoppingRule(bool lowerIsBetter) _bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } + /// + /// Lazy constructor. It doesn't initialize anything because in runtime, will be + /// called inside the training process to initialize needed fields. + /// + private protected EarlyStoppingRule() + { + } + /// /// Check if the given score is the best ever. The best score will be stored at this._bestScore. /// @@ -101,7 +114,7 @@ internal sealed class Options : IEarlyStoppingCriterionFactory public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new TolerantEarlyStoppingRule(lowerIsBetter, Threshold); + return new TolerantEarlyStoppingRule(this, lowerIsBetter); } } @@ -114,20 +127,20 @@ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsB /// Create a rule which may terminate the training process if validation score exceeds compared with /// the best historical validation score. /// - /// Its meaning is identical to . /// The maximum difference allowed between the (current) validation score and its best historical value. - public TolerantEarlyStoppingRule(bool lowerIsBetter = true, float threshold = 0.01f) - : base(lowerIsBetter) + public TolerantEarlyStoppingRule(float threshold = 0.01f) + : base() { Contracts.CheckUserArg(threshold >= 0, nameof(threshold), "Must be non-negative."); Threshold = threshold; } - [BestFriend] // Used in command line tool to construct lodable class. - internal TolerantEarlyStoppingRule(Options options, bool lowerIsBetter = true) - : this(lowerIsBetter, options.Threshold) + private TolerantEarlyStoppingRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) { + Contracts.CheckUserArg(options.Threshold >= 0, nameof(options.Threshold), "Must be non-negative."); + Threshold = options.Threshold; } /// @@ -144,6 +157,8 @@ public override bool CheckScore(float validationScore, float trainingScore, out else return (BestScore - validationScore > Threshold); } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold }; } // For the detail of the following rules, see the following paper. @@ -177,7 +192,18 @@ internal class Options // Hide this because it's a runtime value. private protected Queue PastScores; - private protected MovingWindowRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) + private protected MovingWindowRule(float threshold, int windowSize) + : base() + { + Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); + Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive."); + + Threshold = threshold; + WindowSize = windowSize; + PastScores = new Queue(windowSize); + } + + private protected MovingWindowRule(bool lowerIsBetter, float threshold, int windowSize) : base(lowerIsBetter) { Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); @@ -256,7 +282,7 @@ internal sealed class Options : IEarlyStoppingCriterionFactory public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new GeneralityLossRule(lowerIsBetter, Threshold); + return new GeneralityLossRule(this, lowerIsBetter); } } @@ -270,21 +296,21 @@ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsB /// Create a rule which may terminate the training process in case of loss of generality. The loss of generality means /// the specified score on validation start increaseing. /// - /// Its meaning is identical to . /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) validation /// score and its best historical value. - public GeneralityLossRule(bool lowerIsBetter = true, float threshold = 0.01f) : - base(lowerIsBetter) + public GeneralityLossRule(float threshold = 0.01f) : + base() { Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(threshold), "Must be in range [0,1]."); Threshold = threshold; } - [BestFriend] // Used in command line tool to construct lodable class. - internal GeneralityLossRule(Options options, bool lowerIsBetter = true) - : this(lowerIsBetter, options.Threshold) + private GeneralityLossRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) { + Contracts.CheckUserArg(0 <= options.Threshold && options.Threshold <= 1, nameof(options.Threshold), "Must be in range [0,1]."); + Threshold = options.Threshold; } /// @@ -301,7 +327,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out else return (validationScore < (1 - Threshold) * BestScore); } - } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold }; +} /// /// Low Progress (LP). @@ -315,7 +343,7 @@ public sealed class LowProgressRule : MovingWindowRule { public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new LowProgressRule(lowerIsBetter, Threshold, WindowSize); + return new LowProgressRule(this, lowerIsBetter); } } @@ -324,19 +352,17 @@ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsB /// It will terminate the training process if the average of the recent validation scores /// is worse than the best historical validation score. /// - /// Its meaning is identical to . /// The maximum gap (in percentage such as 0.01 for 1% and 0.5 for 50%) between the (current) averaged validation /// score and its best historical value. /// See . - public LowProgressRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) - : base(lowerIsBetter, threshold, windowSize) + public LowProgressRule(float threshold = 0.01f, int windowSize = 5) + : base(threshold, windowSize) { } - [BestFriend] // Used in command line tool to construct lodable class. - internal LowProgressRule(Options options, bool lowerIsBetter = true) - : this(lowerIsBetter, options.Threshold, options.WindowSize) + private LowProgressRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter, options.Threshold, options.WindowSize) { } @@ -362,7 +388,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out return false; } - } + + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize }; +} /// /// Generality to Progress Ratio (PQ). @@ -375,19 +403,23 @@ public sealed class GeneralityToProgressRatioRule : MovingWindowRule { public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new GeneralityToProgressRatioRule(lowerIsBetter, Threshold, WindowSize); + return new GeneralityToProgressRatioRule(this, lowerIsBetter); } } - public GeneralityToProgressRatioRule(bool lowerIsBetter, float threshold = 0.01f, int windowSize = 5) - : base(lowerIsBetter, threshold, windowSize) + /// + /// Create a rule which may terminate the training process when generality-to-progress ratio exceeds . + /// + /// The maximum ratio gap (in percentage such as 0.01 for 1% and 0.5 for 50%). + /// See . + public GeneralityToProgressRatioRule(float threshold = 0.01f, int windowSize = 5) + : base(threshold, windowSize) { } - [BestFriend] // Used in command line tool to construct lodable class. - internal GeneralityToProgressRatioRule(Options options, bool lowerIsBetter = true) - : this(lowerIsBetter, options.Threshold, options.WindowSize) + private GeneralityToProgressRatioRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter, options.Threshold, options.WindowSize) { } @@ -413,6 +445,9 @@ public override bool CheckScore(float validationScore, float trainingScore, out return false; } + + [BestFriend] + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { Threshold = Threshold, WindowSize = WindowSize }; } /// @@ -431,7 +466,7 @@ internal sealed class Options : IEarlyStoppingCriterionFactory public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsBetter) { - return new ConsecutiveGeneralityLossRule(lowerIsBetter, WindowSize); + return new ConsecutiveGeneralityLossRule(this, lowerIsBetter); } } @@ -443,18 +478,24 @@ public EarlyStoppingRuleBase CreateComponent(IHostEnvironment env, bool lowerIsB private int _count; private float _prevScore; - public ConsecutiveGeneralityLossRule(bool lowerIsBetter, int windowSize = 5) - : base(lowerIsBetter) + /// + /// Creates a rule which terminates the training process if the validation score is not improved in consecutive iterations. + /// + /// Number of training iterations allowed to have no improvement. + public ConsecutiveGeneralityLossRule(int windowSize = 5) + : base() { Contracts.CheckUserArg(windowSize > 0, nameof(windowSize), "Must be positive"); WindowSize = windowSize; _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } - [BestFriend] - internal ConsecutiveGeneralityLossRule(Options options, bool lowerIsBetter = true) - : this(lowerIsBetter, options.WindowSize) + private ConsecutiveGeneralityLossRule(Options options, bool lowerIsBetter) + : base(lowerIsBetter) { + Contracts.CheckUserArg(options.WindowSize > 0, nameof(options.WindowSize), "Must be positive"); + WindowSize = options.WindowSize; + _prevScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; } /// @@ -471,5 +512,8 @@ public override bool CheckScore(float validationScore, float trainingScore, out return (_count >= WindowSize); } + + [BestFriend] + internal override IEarlyStoppingCriterionFactory BuildFactory() => new Options() { WindowSize = WindowSize }; } } From 5b071bfa48129e3d88d9589519fd4870f50e3955 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 6 Mar 2019 09:23:24 -0800 Subject: [PATCH 4/5] Address comments --- src/Microsoft.ML.FastTree/FastTreeArguments.cs | 9 ++++++++- .../Training/EarlyStoppingCriteria.cs | 10 ++++++++-- .../Common/EntryPoints/core_manifest.json | 16 ++++++++-------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 00250ffb9c..767c706ece 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -622,12 +622,19 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc /// Early stopping rule. (Validation set (/valid) is required). /// [BestFriend] - [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "")] + [Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "")] [TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")] internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory; + /// + /// The underlying state of and . + /// private EarlyStoppingRuleBase _earlyStoppingRuleBase; + /// + /// Early stopping rule used to terminate training process once meeting a specified criterion. Possible choices are + /// 's implementations such as and . + /// public EarlyStoppingRuleBase EarlyStoppingRule { get { return _earlyStoppingRuleBase; } diff --git a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs index ba53359831..e0241379d6 100644 --- a/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs +++ b/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs @@ -42,6 +42,12 @@ public abstract class EarlyStoppingRuleBase /// If true, the learning should stop. public abstract bool CheckScore(float validationScore, float trainingScore, out bool isBestCandidate); + /// + /// Having constructor without parameter prevents derivations of from being + /// implemented by the public. + /// + private protected EarlyStoppingRuleBase() { } + /// /// Create for supporting legacy infra built upon . /// @@ -73,7 +79,7 @@ private protected float BestScore { } } - private protected EarlyStoppingRule(bool lowerIsBetter) + private protected EarlyStoppingRule(bool lowerIsBetter) : base() { LowerIsBetter = lowerIsBetter; _bestScore = LowerIsBetter ? float.PositiveInfinity : float.NegativeInfinity; @@ -83,7 +89,7 @@ private protected EarlyStoppingRule(bool lowerIsBetter) /// Lazy constructor. It doesn't initialize anything because in runtime, will be /// called inside the training process to initialize needed fields. /// - private protected EarlyStoppingRule() + private protected EarlyStoppingRule() : base() { } diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3569207c3b..6caab60446 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -6598,7 +6598,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -7587,7 +7587,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -8474,7 +8474,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -9370,7 +9370,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -25308,7 +25308,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -26279,7 +26279,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -27148,7 +27148,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" @@ -28026,7 +28026,7 @@ "Default": "GradientDescent" }, { - "Name": "EarlyStoppingRuleFactory", + "Name": "EarlyStoppingRule", "Type": { "Kind": "Component", "ComponentKind": "EarlyStoppingCriterion" From e42285c9a4af5ab281cd195a545e95dd3fe4029a Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 6 Mar 2019 10:12:11 -0800 Subject: [PATCH 5/5] Fix test --- test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index f06690fa2c..0046e4a469 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -4974,7 +4974,7 @@ public void TestCrossValidationMacroWithNonDefaultNames() 'MaximumNumberOfLineSearchSteps': 0, 'MinimumStepSize': 0.0, 'OptimizationAlgorithm': 'GradientDescent', - 'EarlyStoppingRuleFactory': null, + 'EarlyStoppingRule': null, 'EarlyStoppingMetrics': 1, 'EnablePruning': false, 'UseTolerantPruning': false,