From 054ed9c0ae2991b70512d8399176f1477a66bada Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 23 Oct 2018 16:25:41 -0700 Subject: [PATCH 1/8] adding multiclass and ranking extensions for LightGBM. Adding tests, and refactoring catalog and pigsty statics --- .../Training/TrainerUtils.cs | 5 +- src/Microsoft.ML.FastTree/FastTree.cs | 4 +- src/Microsoft.ML.FastTree/GamTrainer.cs | 4 +- .../OlsLinearRegression.cs | 2 +- .../KMeansPlusPlusTrainer.cs | 2 +- .../LightGbmBinaryTrainer.cs | 19 +++- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 67 +++++++++-- .../LightGbmMulticlassTrainer.cs | 38 +++++-- .../LightGbmRankingTrainer.cs | 43 +++++-- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 107 ++++++++++++++++-- .../LightGbmTrainerBase.cs | 4 +- src/Microsoft.ML.PCA/PcaTrainer.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 6 +- .../LogisticRegression/LbfgsPredictorBase.cs | 2 +- .../Standard/Online/OnlineLinear.cs | 2 +- .../Standard/SdcaMultiClass.cs | 4 +- .../Standard/SdcaRegression.cs | 4 +- .../Standard/SdcaStatic.cs | 10 +- .../Training.cs | 86 ++++++++++++++ 19 files changed, 341 insertions(+), 70 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index dff5748635..fb87eac884 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -377,9 +377,10 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) /// The for the weight column. /// /// name of the weight column - public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) + /// whether the column is implicitely, or explicitely defined + public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit) { - if (weightColumn == null) + if (weightColumn == null || !isExplicit) return null; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 1f34a3f522..b7feecb4fe 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -93,7 +93,7 @@ public abstract class FastTreeTrainerBase : /// private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null )) { Args = new TArgs(); @@ -127,7 +127,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l /// Legacy constructor that is used when invoking the classes deriving from this, through maml. 
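The net effect of the new isExplicit flag on MakeR4ScalarWeightColumn is that a weight column that was only filled in by default (implicit) is now treated the same as no weight column at all. The two call patterns this change uses, shown as a minimal C# sketch that mirrors the calls in the diff (illustrative only, not an additional change):

    // String-overload constructors: any non-null weight column name counts as explicit.
    var weight = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null);
    // Arguments-based constructors: defer to the Optional<string> metadata on the arguments object.
    var weightFromArgs = TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit);
    // Both calls return null for an implicit weight column, so no weight role gets mapped at training time.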
/// private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index ba70fb34c7..f9d96183c5 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -134,7 +134,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn !=null)) { Args = new TArgs(); @@ -154,7 +154,7 @@ private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape. private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Contracts.CheckValue(env, nameof(env)); Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index cc6c0a1361..24aab18377 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -85,7 +85,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st /// internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index 83c4bd79e1..d0adbcca29 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -112,7 +112,7 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) } private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), 
TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 397ba234a4..0cef514af3 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -110,22 +110,29 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, + public LightGbmBinaryTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, string weightColumn = null, int? numLeaves = null, int? minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : this(env, new LightGbmArguments { + LabelColumn = labelColumn ?? DefaultColumnNames.Label, + FeatureColumn = featureColumn ?? DefaultColumnNames.Features, + WeightColumn = weightColumn ?? Optional.Implicit(DefaultColumnNames.Weight), + NumLeaves = numLeaves ?? default, + MinDataPerLeaf = minDataPerLeaf ?? default, + LearningRate = learningRate ?? default, + NumBoostRound = numBoostRound + }) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - if (advancedSettings != null) CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - // override with the directly provided values + // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. Args.NumBoostRound = numBoostRound; Args.NumLeaves = numLeaves ?? Args.NumLeaves; Args.LearningRate = learningRate ?? Args.LearningRate; diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 54068a9fa6..605d35b0b1 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML /// /// Regression trainer estimators. /// - public static class LightGbmRegressionExtensions + public static class LightGbmExtensions { /// /// Predict a target using a decision tree regression model trained with the . @@ -40,13 +40,6 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionContext.Regressio var env = CatalogUtils.GetEnvironment(ctx); return new LightGbmRegressorTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } - } - - /// - /// Binary Classification trainer estimators. - /// - public static class LightGbmClassificationExtensions - { /// /// Predict a target using a decision tree binary classification model trained with the . 
@@ -75,5 +68,63 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.Bi return new LightGbmBinaryTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The weights column. + /// The groupId column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static LightGbmRankingTrainer LightGbm(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string groupId = DefaultColumnNames.GroupId, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmRankingTrainer(env, label, features, groupId, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmMulticlassTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 73d853878b..cbb71a502e 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -45,16 +45,40 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. 
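The ranking and multiclass catalog entry points added above mirror the existing regression and binary ones. A usage sketch, where the environment construction and column names are illustrative rather than part of this change:

    var env = new ConsoleEnvironment(seed: 0);

    var rankingCtx = new RankingContext(env);
    var ranker = rankingCtx.Trainers.LightGbm(label: "Label", features: "Features", groupId: "GroupId",
        numLeaves: 20, minDataPerLeaf: 10, learningRate: 0.2);

    var multiclassCtx = new MulticlassClassificationContext(env);
    var multiclass = multiclassCtx.Trainers.LightGbm(label: "Label", features: "Features", numBoostRound: 50);

Both calls return trainer estimators that are appended to a pipeline the same way the existing FastTree and SDCA catalog extensions are.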
- public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + public LightGbmMulticlassTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : this(env, new LightGbmArguments + { + LabelColumn = labelColumn ?? DefaultColumnNames.Label, + FeatureColumn = featureColumn ?? DefaultColumnNames.Features, + WeightColumn = weightColumn ?? Optional.Implicit(DefaultColumnNames.Weight), + NumLeaves = numLeaves ?? default, + MinDataPerLeaf = minDataPerLeaf ?? default, + LearningRate = learningRate ?? default, + NumBoostRound = numBoostRound + }) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - _numClass = -1; + if (advancedSettings != null) + CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); + + // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. + Args.NumBoostRound = numBoostRound; + Args.NumLeaves = numLeaves ?? Args.NumLeaves; + Args.LearningRate = learningRate ?? Args.LearningRate; + Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; } private FastTree.Internal.Ensemble GetBinaryEnsemble(int classID) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 9de052b53f..865061a242 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -91,16 +91,43 @@ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. - /// The name for the column containing the initial weight. + /// The name of the column containing the group ID. + /// The name of the column containing the initial weight. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. - public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + public LightGbmRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : this(env, new LightGbmArguments + { + LabelColumn = labelColumn ?? DefaultColumnNames.Label, + FeatureColumn = featureColumn ?? DefaultColumnNames.Features, + GroupIdColumn = groupIdColumn ?? 
DefaultColumnNames.GroupId, + WeightColumn = weightColumn ?? null, // Optional.Implicit(DefaultColumnNames.Weight), + NumLeaves = numLeaves ?? default, + MinDataPerLeaf = minDataPerLeaf ?? default, + LearningRate = learningRate ?? default, + NumBoostRound = numBoostRound + }) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); + if (advancedSettings != null) + CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); + + // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. + Args.NumBoostRound = numBoostRound; + Args.NumLeaves = numLeaves ?? Args.NumLeaves; + Args.LearningRate = learningRate ?? Args.LearningRate; + Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; } protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index ecd84ef2a4..ebb8b0eb25 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.StaticPipe.Runtime; using System; @@ -14,7 +15,7 @@ namespace Microsoft.ML.StaticPipe /// /// Regression trainer estimators. /// - public static partial class RegressionTrainers + public static class LightGbmTrainers { /// /// LightGbm extension method. @@ -49,7 +50,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -63,12 +64,6 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - /// - /// Binary Classification trainer estimators. - /// - public static partial class ClassificationTrainers { /// /// LightGbm extension method. @@ -98,7 +93,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,11 +109,99 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - internal static class LightGbmStaticsUtils { + /// + /// LightGbm extension method. + /// + /// The . + /// The label column. + /// The features colum. + /// The groupId column. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. 
+ /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted binary classification score (which will range + /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label. + public static Scalar LightGbm(this RankingContext.RankingTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + var trainer = new LightGbmRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + + /// + /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? 
learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new LightGbmMulticlassTrainer(env, labelName, featuresName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } - internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int? numLeaves, int? minDataPerLeaf, double? learningRate, diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index d4d7a1a51e..dc54981060 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -58,7 +58,7 @@ private sealed class CategoricalMetaData private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { Args = new LightGbmArguments(); @@ -81,7 +81,7 @@ private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaS } private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 11650e25d2..71bca1028f 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -107,7 +107,7 @@ internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, int rank = 20, int oversampling = 20, bool center = true, int? seed = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { // if the args are not null, we got here from maml, and the internal ctor. 
if (args != null) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 94f00e5505..c9ed4637d4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -63,7 +63,7 @@ public abstract class LinearTrainerBase : TrainerEstimator private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, string weightColumn = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { } @@ -1443,7 +1443,7 @@ public LinearClassificationTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -1474,7 +1474,7 @@ public LinearClassificationTrainer(IHostEnvironment env, internal LinearClassificationTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { _loss = args.LossFunction.CreateComponent(env); Loss = _loss; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 5f7a40c5d2..2733a55067 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -169,7 +169,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l int? memorySize = null, bool? 
enforceNoNegativity = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index f7e8462e29..a2af99e7a2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -83,7 +83,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat protected virtual bool NeedCalibration => false; protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights, args.InitialWeights != null)) { Contracts.CheckValue(args, nameof(args)); Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index a2cdddb881..73feffa848 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -67,7 +67,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -78,7 +78,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 3e92670cb2..e0818962a2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -72,7 +72,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, float? l1Threshold = null, int? 
maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -82,7 +82,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, } internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index 77c5930690..2315ac78dd 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. /// - public static class SdcaRegressionExtensions + public static class SdcaExtensions { /// /// Predict a target using a linear regression model trained with the SDCA trainer. @@ -70,10 +70,6 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, return rec.Score; } - } - - public static class SdcaBinaryClassificationExtensions - { /// /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss. @@ -198,10 +194,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( return rec.Output; } - } - - public static class SdcaMulticlassExtensions - { /// /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. 
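Folding SdcaRegressionExtensions, SdcaBinaryClassificationExtensions, and SdcaMulticlassExtensions into a single SdcaExtensions class (like the LightGbm and FastTree consolidations above) is source-compatible for callers: C# resolves these extension methods through the ctx.Trainers receiver type, not through the name of the containing static class. An illustrative call site, written as it would appear inside a MakeNewEstimator lambda, keeps compiling unchanged:

    var score = regressionCtx.Trainers.Sdca(r.label, r.features);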
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index be284fa437..5c3c65b0fd 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -754,6 +754,92 @@ public void FastTreeRanking() Assert.InRange(metrics.Ndcg[2], 36.5, 37); } + [Fact] + public void LightGBMRanking() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new RankingContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), + separator: '\t', hasHeader: true); + + LightGbmRankingPredictor pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) + .Append(r => (r.label, r.groupId, score: ctx.Trainers.LightGbm(r.label, r.features, r.groupId, onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score); + Assert.NotNull(metrics); + + Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3); + + Assert.InRange(metrics.Dcg[0], 1.4, 1.6); + Assert.InRange(metrics.Dcg[1], 1.4, 1.8); + Assert.InRange(metrics.Dcg[2], 1.4, 1.8); + + Assert.InRange(metrics.Ndcg[0], 36.5, 37); + Assert.InRange(metrics.Ndcg[1], 36.5, 37); + Assert.InRange(metrics.Ndcg[2], 36.5, 37); + } + + [Fact] + public void MultiClassLightGBM() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new MulticlassClassificationContext(env); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + OvaPredictor pred = null; + + // With a custom loss function we no longer get calibrated predictions. + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: ctx.Trainers.LightGbm( + r.label, + r.features, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + //int[] labelHistogram = default; + //int[][] featureHistogram = default; + //pred.GetLabelHistogram(ref labelHistogram, out int labelCount1); + //pred.GetFeatureHistogram(ref featureHistogram, out int labelCount2, out int featureCount); + //Assert.True(labelCount1 == 3 && labelCount1 == labelCount2 && labelCount1 <= labelHistogram.Length); + //for (int i = 0; i < labelCount1; i++) + // Assert.True(featureCount == 4 && (featureCount <= featureHistogram[i].Length)); + + var data = model.Read(dataSource); + + // Just output some data on the schema for fun. 
+ var schema = data.AsDynamic.Schema; + for (int c = 0; c < schema.ColumnCount; ++c) + Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.True(metrics.LogLoss > 0); + Assert.True(metrics.TopKAccuracy > 0); + } + [Fact] public void MultiClassNaiveBayesTrainer() { From c41eb6c25a157ff65f9352e0c983265b1d2263c5 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 23 Oct 2018 22:52:32 -0700 Subject: [PATCH 2/8] removing duplicate method from TrainUtils adding groupid to the trainer estimator base refactoring the catalog and static extensions for trees --- .../Training/TrainerEstimatorBase.cs | 15 ++++- .../Training/TrainerUtils.cs | 13 ++-- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 8 +-- src/Microsoft.ML.FastTree/GamTrainer.cs | 2 +- .../Microsoft.ML.FastTree.csproj | 4 +- ...tTreeCatalog.cs => TreeTrainersCatalog.cs} | 67 ++++++++++++++++--- ...astTreeStatic.cs => TreeTrainersStatic.cs} | 31 +++------ .../KMeansPlusPlusTrainer.cs | 2 +- .../LightGbmBinaryTrainer.cs | 19 ++---- .../LightGbmMulticlassTrainer.cs | 18 ++--- .../LightGbmRankingTrainer.cs | 18 ++--- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 1 + .../LightGbmTrainerBase.cs | 2 +- src/Microsoft.ML.PCA/PcaTrainer.cs | 2 +- .../Standard/LinearClassificationTrainer.cs | 6 +- .../MulticlassLogisticRegression.cs | 4 +- .../MultiClass/MultiClassNaiveBayesTrainer.cs | 4 +- .../Standard/Online/OnlineLinear.cs | 2 +- .../Standard/SdcaMultiClass.cs | 4 +- .../Standard/SdcaRegression.cs | 4 +- .../Training.cs | 7 -- 22 files changed, 133 insertions(+), 102 deletions(-) rename src/Microsoft.ML.FastTree/{FastTreeCatalog.cs => TreeTrainersCatalog.cs} (61%) rename src/Microsoft.ML.FastTree/{FastTreeStatic.cs => TreeTrainersStatic.cs} (92%) diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index ad074fd737..f4ea646999 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -41,6 +41,11 @@ public abstract class TrainerEstimatorBase : ITrainerEstim /// public readonly SchemaShape.Column WeightColumn; + /// + /// The optional groupID column that the ranking trainers expects. 
+ /// + public readonly SchemaShape.Column GroupIdColumn; + protected readonly IHost Host; /// @@ -50,17 +55,23 @@ public abstract class TrainerEstimatorBase : ITrainerEstim public abstract PredictionKind PredictionKind { get; } - public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) + public TrainerEstimatorBase(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null, + SchemaShape.Column groupId = null) { Contracts.CheckValue(host, nameof(host)); Host = host; Host.CheckValue(feature, nameof(feature)); Host.CheckValueOrNull(label); Host.CheckValueOrNull(weight); + Host.CheckValueOrNull(groupId); FeatureColumn = feature; LabelColumn = label; WeightColumn = weight; + GroupIdColumn = groupId; } public TTransformer Fit(IDataView input) => TrainTransformer(input); @@ -150,7 +161,7 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); private RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); + new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); IPredictor ITrainer.Train(TrainContext context) => Train(context); } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index fb87eac884..decdab3753 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) /// /// The for the label column for regression tasks. /// - /// name of the weight column - public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) - => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + /// name of the weight column + public static SchemaShape.Column MakeU4ScalarColumn(string columnName) + { + if (columnName == null) + return null; + + return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + } /// /// The for the feature column. 
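With GroupIdColumn threaded through TrainerEstimatorBase, a ranking trainer now hands its group column to the base class and gets the group role mapped automatically. A sketch of the pattern, mirroring the LightGbm base-class and MakeRoles changes in this patch (names are illustrative):

    // In a derived ranking trainer's constructor, pass the group id as the new optional column:
    //   : base(host, TrainerUtils.MakeR4VecFeature(featureColumn), label,
    //          TrainerUtils.MakeR4ScalarWeightColumn(weightColumn),
    //          TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
    // MakeRoles then carries it into the role mappings:
    var roles = new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name,
        group: GroupIdColumn?.Name, weight: WeightColumn?.Name);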
@@ -378,7 +383,7 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) /// /// name of the weight column /// whether the column is implicitely, or explicitely defined - public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit) + public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { if (weightColumn == null || !isExplicit) return null; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index b7feecb4fe..3568c1e9d0 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -93,7 +93,7 @@ public abstract class FastTreeTrainerBase : /// private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null )) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new TArgs(); diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index bd2e7afd5a..8873043e55 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -35,10 +35,10 @@ namespace Microsoft.ML.Runtime.FastTree public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> { - public const string LoadNameValue = "FastTreeTweedieRegression"; - public const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; - public const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression."; - public const string ShortName = "fttweedie"; + internal const string LoadNameValue = "FastTreeTweedieRegression"; + internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; + internal const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. 
This learner is a generalization of Poisson, compound Poisson, and gamma regression."; + internal const string ShortName = "fttweedie"; private TestHistory _firstTestSetHistory; private Test _trainRegressionTest; diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index f9d96183c5..1bc9c403a7 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -134,7 +134,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn !=null)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Args = new TArgs(); diff --git a/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj b/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj index a92b3c26d7..d54c0817c1 100644 --- a/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj +++ b/src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj @@ -26,8 +26,8 @@ - - + + diff --git a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs similarity index 61% rename from src/Microsoft.ML.FastTree/FastTreeCatalog.cs rename to src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index f956415a69..85537c72da 100644 --- a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML /// /// FastTree extension methods. /// - public static class FastTreeRegressionExtensions + public static class TreeExtensions { /// /// Predict a target using a decision tree regression model trained with the . @@ -40,10 +40,6 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); } - } - - public static class FastTreeBinaryClassificationExtensions - { /// /// Predict a target using a decision tree binary classification model trained with the . @@ -71,10 +67,6 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); } - } - - public static class FastTreeRankingExtensions - { /// /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . @@ -96,5 +88,62 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer var env = CatalogUtils.GetEnvironment(ctx); return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The optional weights column. + /// Algorithm advanced settings. 
+ public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new BinaryClassificationGamTrainer(env, label, features, weights, advancedSettings); + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The optional weights column. + /// Algorithm advanced settings. + public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new RegressionGamTrainer(env, label, features, weights, advancedSettings); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features colum. + /// The optional weights column. + /// Algorithm advanced settings. + public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeTweedieTrainer(env, label, features, weights, advancedSettings: advancedSettings); + } } } diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs similarity index 92% rename from src/Microsoft.ML.FastTree/FastTreeStatic.cs rename to src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 051d6feaf6..7c45c27c02 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe /// /// FastTree extension methods. /// - public static class FastTreeRegressionExtensions + public static class TreeRegressionExtensions { /// /// FastTree extension method. @@ -50,7 +50,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -64,10 +64,6 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - public static class FastTreeBinaryClassificationExtensions - { /// /// FastTree extension method. 
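The tree catalog above now also surfaces the GAM and Tweedie trainers alongside FastTree. A usage sketch for the new FastTreeTweedie entry point, with environment construction and column names being illustrative:

    var env = new ConsoleEnvironment(seed: 0);
    var regressionCtx = new RegressionContext(env);
    var tweedie = regressionCtx.Trainers.FastTreeTweedie(label: "Label", features: "Features");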
@@ -98,7 +94,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,10 +110,6 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - - public static class FastTreeRankingExtensions - { /// /// FastTree . @@ -148,7 +140,7 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => @@ -161,17 +153,14 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c return rec.Score; } - } - internal class FastTreeStaticsUtils - { internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - int numLeaves, - int numTrees, - int minDatapointsInLeafs, - double learningRate, - Delegate advancedSettings, - Delegate onFit) + int numLeaves, + int numTrees, + int minDatapointsInLeafs, + double learningRate, + Delegate advancedSettings, + Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index d0adbcca29..83c4bd79e1 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -112,7 +112,7 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args) } private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index 0cef514af3..397ba234a4 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -110,29 +110,22 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - public LightGbmBinaryTrainer(IHostEnvironment env, - string labelColumn, - string featureColumn, + public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, int? 
minDataPerLeaf = null, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : this(env, new LightGbmArguments { - LabelColumn = labelColumn ?? DefaultColumnNames.Label, - FeatureColumn = featureColumn ?? DefaultColumnNames.Features, - WeightColumn = weightColumn ?? Optional.Implicit(DefaultColumnNames.Weight), - NumLeaves = numLeaves ?? default, - MinDataPerLeaf = minDataPerLeaf ?? default, - LearningRate = learningRate ?? default, - NumBoostRound = numBoostRound - }) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + if (advancedSettings != null) CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. + // override with the directly provided values Args.NumBoostRound = numBoostRound; Args.NumLeaves = numLeaves ?? Args.NumLeaves; Args.LearningRate = learningRate ?? Args.LearningRate; diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index cbb71a502e..b515f352ad 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -60,25 +60,21 @@ public LightGbmMulticlassTrainer(IHostEnvironment env, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : this(env, new LightGbmArguments - { - LabelColumn = labelColumn ?? DefaultColumnNames.Label, - FeatureColumn = featureColumn ?? DefaultColumnNames.Features, - WeightColumn = weightColumn ?? Optional.Implicit(DefaultColumnNames.Weight), - NumLeaves = numLeaves ?? default, - MinDataPerLeaf = minDataPerLeaf ?? default, - LearningRate = learningRate ?? default, - NumBoostRound = numBoostRound - }) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + if (advancedSettings != null) CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. + // override with the directly provided values Args.NumBoostRound = numBoostRound; Args.NumLeaves = numLeaves ?? Args.NumLeaves; Args.LearningRate = learningRate ?? Args.LearningRate; Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; + + _numClass = -1; } private FastTree.Internal.Ensemble GetBinaryEnsemble(int classID) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index 865061a242..53393ff3ad 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -108,22 +108,16 @@ public LightGbmRankingTrainer(IHostEnvironment env, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : this(env, new LightGbmArguments - { - LabelColumn = labelColumn ?? 
DefaultColumnNames.Label, - FeatureColumn = featureColumn ?? DefaultColumnNames.Features, - GroupIdColumn = groupIdColumn ?? DefaultColumnNames.GroupId, - WeightColumn = weightColumn ?? null, // Optional.Implicit(DefaultColumnNames.Weight), - NumLeaves = numLeaves ?? default, - MinDataPerLeaf = minDataPerLeaf ?? default, - LearningRate = learningRate ?? default, - NumBoostRound = numBoostRound - }) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) { + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); + if (advancedSettings != null) CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - // override with the directly provided values, giving them priority over the Advanced args, in case they are assigned twice. + // override with the directly provided values Args.NumBoostRound = numBoostRound; Args.NumLeaves = numLeaves ?? Args.NumLeaves; Args.LearningRate = learningRate ?? Args.LearningRate; diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index ebb8b0eb25..ddf3bcc475 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -140,6 +140,7 @@ public static Scalar LightGbm(this RankingContext.RankingTrainers c Action onFit = null) { CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + Contracts.CheckValue(groupId, nameof(groupId)); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index dc54981060..15308f1326 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -58,7 +58,7 @@ private sealed class CategoricalMetaData private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new LightGbmArguments(); diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index 71bca1028f..11650e25d2 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -107,7 +107,7 @@ internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, int rank = 20, int oversampling = 20, bool center = true, int? 
seed = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), null, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { // if the args are not null, we got here from maml, and the internal ctor. if (args != null) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index c9ed4637d4..94f00e5505 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -63,7 +63,7 @@ public abstract class LinearTrainerBase : TrainerEstimator private protected LinearTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, string weightColumn = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { } @@ -1443,7 +1443,7 @@ public LinearClassificationTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -1474,7 +1474,7 @@ public LinearClassificationTrainer(IHostEnvironment env, internal LinearClassificationTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(env, args, TrainerUtils.MakeBoolScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { _loss = args.LossFunction.CreateComponent(env); Loss = _loss; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 281ec01769..626b054ee9 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -90,7 +90,7 @@ public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, int memorySize = Arguments.Defaults.MemorySize, bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weightColumn, advancedSettings, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -103,7 +103,7 @@ public 
MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, /// Initializes a new instance of /// internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { ShowTrainingStats = Args.ShowTrainingStats; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 03241f9a59..a130c3fb6f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -53,7 +53,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel /// The name of the feature column. public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn), - TrainerUtils.MakeU4ScalarLabel(labelColumn)) + TrainerUtils.MakeU4ScalarColumn(labelColumn)) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -64,7 +64,7 @@ public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, s /// internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { Host.CheckValue(args, nameof(args)); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index a2af99e7a2..f7e8462e29 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -83,7 +83,7 @@ public abstract class OnlineLinearTrainer : TrainerEstimat protected virtual bool NeedCalibration => false; protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights, args.InitialWeights != null)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights)) { Contracts.CheckValue(args, nameof(args)); Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 73feffa848..2eea592187 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -67,7 +67,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, float? l1Threshold = null, int? 
maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -78,7 +78,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index e0818962a2..3e92670cb2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -72,7 +72,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -82,7 +82,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, } internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn, weightColumn != null)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 5c3c65b0fd..fc3043866e 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -820,13 +820,6 @@ public void MultiClassLightGBM() Assert.Null(pred); var model = pipe.Fit(dataSource); Assert.NotNull(pred); - //int[] labelHistogram = default; - //int[][] featureHistogram = default; - //pred.GetLabelHistogram(ref labelHistogram, out int labelCount1); - //pred.GetFeatureHistogram(ref featureHistogram, out int labelCount2, out int featureCount); - //Assert.True(labelCount1 == 3 && labelCount1 == labelCount2 && labelCount1 <= labelHistogram.Length); - //for (int i = 0; i < labelCount1; i++) - // Assert.True(featureCount == 4 && (featureCount <= featureHistogram[i].Length)); var data = model.Read(dataSource); From a5840b2dcd34f06ba0ab4bf62390f92064ea3ddd Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Wed, 24 Oct 2018 12:07:31 -0700 Subject: [PATCH 3/8] Addressing 
PR comments --- src/Microsoft.ML.Data/Training/TrainerUtils.cs | 2 +- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index decdab3753..3c1987206e 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -382,7 +382,7 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) /// The for the weight column. /// /// name of the weight column - /// whether the column is implicitely, or explicitely defined + /// whether the column is implicitly, or explicitly defined public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { if (weightColumn == null || !isExplicit) diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index ddf3bcc475..e527a0ed92 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.StaticPipe public static class LightGbmTrainers { /// - /// LightGbm extension method. + /// Predict a target using a tree regression model trained with the LightGbm trainer. /// /// The . /// The label column. @@ -66,7 +66,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c } /// - /// LightGbm extension method. + /// Predict a target using a tree binary classification model trained with the LightGbm trainer. /// /// The . /// The label column. @@ -157,7 +157,7 @@ public static Scalar LightGbm(this RankingContext.RankingTrainers c } /// - /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. + /// Predict a target using a tree multiclass classification model trained with the LightGbm trainer. /// /// The multiclass classification context trainer object. /// The label, or dependent variable. From ad51d71fa1777b9a5551f82c2857f99b7be3a7dc Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 26 Oct 2018 11:28:13 -0700 Subject: [PATCH 4/8] advanced arguments override the directly provided ones. 
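A minimal usage sketch of the new precedence (hypothetical example: the regressionContext variable and the literal column names are assumptions, while the LightGbm extension and its label/features/numLeaves/advancedSettings parameters are the ones touched by this patch). Because the advancedSettings delegate runs after the directly provided values are copied into the trainer's Args, the values it assigns win:

    // Hypothetical sketch: numLeaves is passed directly, but the value set
    // through advancedSettings takes priority under this change.
    var trainer = regressionContext.Trainers.LightGbm(
        label: "Label",
        features: "Features",
        numLeaves: 20,            // direct argument
        advancedSettings: args =>
        {
            args.NumLeaves = 64;  // overrides the direct value above
        });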
--- .../Training/TrainerEstimatorBase.cs | 44 +++++++++--- .../Training/TrainerUtils.cs | 57 ---------------- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 20 +++++- src/Microsoft.ML.FastTree/FastTree.cs | 53 ++++++--------- .../FastTreeClassification.cs | 16 +---- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 20 ++++-- .../FastTreeRegression.cs | 9 +-- src/Microsoft.ML.FastTree/FastTreeTweedie.cs | 18 +++-- .../GamClassification.cs | 15 ++-- src/Microsoft.ML.FastTree/GamRegression.cs | 14 ++-- src/Microsoft.ML.FastTree/GamTrainer.cs | 15 +++- src/Microsoft.ML.FastTree/RandomForest.cs | 15 +++- .../RandomForestClassification.cs | 18 +++-- .../RandomForestRegression.cs | 18 +++-- .../TreeTrainersCatalog.cs | 52 +++++++++++--- .../TreeTrainersStatic.cs | 9 +-- .../OlsLinearRegression.cs | 2 - .../LightGbmBinaryTrainer.cs | 18 ++--- src/Microsoft.ML.LightGBM/LightGbmCatalog.cs | 28 +++++--- .../LightGbmMulticlassTrainer.cs | 19 ++---- .../LightGbmRankingTrainer.cs | 18 ++--- .../LightGbmRegressionTrainer.cs | 18 ++--- src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 11 +-- .../LightGbmTrainerBase.cs | 54 ++++++--------- .../FactorizationMachineCatalog.cs | 6 +- .../FactorizationMachineStatic.cs | 9 ++- .../Standard/LinearClassificationTrainer.cs | 47 +++---------- .../LogisticRegression/LbfgsPredictorBase.cs | 68 +++++++++---------- .../Standard/SdcaCatalog.cs | 17 +++-- .../Standard/SdcaMultiClass.cs | 5 +- .../Standard/SdcaRegression.cs | 5 +- .../Standard/SdcaStatic.cs | 20 ++++-- 32 files changed, 381 insertions(+), 357 deletions(-) diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index f4ea646999..ad2916368d 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -41,11 +41,6 @@ public abstract class TrainerEstimatorBase : ITrainerEstim /// public readonly SchemaShape.Column WeightColumn; - /// - /// The optional groupID column that the ranking trainers expects. - /// - public readonly SchemaShape.Column GroupIdColumn; - protected readonly IHost Host; /// @@ -58,20 +53,17 @@ public abstract class TrainerEstimatorBase : ITrainerEstim public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, - SchemaShape.Column weight = null, - SchemaShape.Column groupId = null) + SchemaShape.Column weight = null) { Contracts.CheckValue(host, nameof(host)); Host = host; Host.CheckValue(feature, nameof(feature)); Host.CheckValueOrNull(label); Host.CheckValueOrNull(weight); - Host.CheckValueOrNull(groupId); FeatureColumn = feature; LabelColumn = label; WeightColumn = weight; - GroupIdColumn = groupId; } public TTransformer Fit(IDataView input) => TrainTransformer(input); @@ -160,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); - private RoleMappedData MakeRoles(IDataView data) => - new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + protected virtual RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); IPredictor ITrainer.Train(TrainContext context) => Train(context); } + + /// + /// This represents a basic class for 'simple trainer'. 
+ /// In addition to one feature column, one label column, and an optional weight column, this variant accepts an optional group ID column, as used by the ranking trainers. + /// It produces a 'prediction transformer'. + /// + public abstract class TrainerEstimatorBaseWithGroupId : TrainerEstimatorBase + where TTransformer : ISingleFeaturePredictionTransformer + where TModel : IPredictor + { + /// + /// The optional group ID column that the ranking trainers expect. + /// + public readonly SchemaShape.Column GroupIdColumn; + + public TrainerEstimatorBaseWithGroupId(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null, + SchemaShape.Column groupId = null) + : base(host, feature, label, weight) + { + Host.CheckValueOrNull(groupId); + GroupIdColumn = groupId; + } + + protected override RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + + } } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 3c1987206e..e966687a9e 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -389,63 +389,6 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, b return null; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } - - private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue) - { - if (argValue != defaultColName) - throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead."); - } - - /// - /// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - - if (args.GroupIdColumn != null) - CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn); - } - - /// - /// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly.
- /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - } - - /// - /// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - } - - /// - /// If, after applying the advancedArgs delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user. - /// - public static void CheckArgsAndAdvancedSettingMismatch(IChannel channel, T methodParam, T defaultVal, T setting, string argName) - { - if (!setting.Equals(defaultVal) && !setting.Equals(methodParam)) - channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}"); - } } /// diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 7a68902038..05b8958d8c 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -21,10 +21,24 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh { } - protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + protected BoostingFastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) + : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { + + if (Args.LearningRates != learningRate) + { + using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments.")) + Args.LearningRates = learningRate; + } } protected override void CheckArgs(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index acb20621d9..23bed0acef 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; @@ -45,7 +46,7 @@ internal static class 
FastTreeShared } public abstract class FastTreeTrainerBase : - TrainerEstimatorBase + TrainerEstimatorBaseWithGroupId where TTransformer: ISingleFeaturePredictionTransformer where TArgs : TreeArgs, new() where TModel : IPredictorProducing @@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase : /// /// Constructor to use when instantiating the classes deriving from here through the API. /// - private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + private protected FastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + Action advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new TArgs(); + // Set the directly provided values, + // overriding the defaults of the newly created Args. + Args.NumLeaves = numLeaves; + Args.NumTrees = numTrees; + Args.MinDocumentsInLeafs = minDocumentsInLeafs; + // Apply the advanced args, if the user supplied any; these take priority over the values set above. advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. @@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel() return Float.PositiveInfinity; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. - /// - protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves, - int numTrees, - int minDocumentsInLeafs, - double learningRate, - BoostedTreeArgs snapshot, - BoostedTreeArgs currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly.
- TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate)); - } - } - private void Initialize(IHostEnvironment env) { int numThreads = Args.NumThreads ?? Environment.ProcessorCount; diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 7f7080281d..b864b484ca 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -122,11 +122,11 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -136,22 +136,10 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); - - //override with the directly provided values. - Args.NumLeaves = numLeaves; - Args.NumTrees = numTrees; - Args.MinDocumentsInLeafs = minDocumentsInLeafs; - Args.LearningRates = learningRate; } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index d96f82f741..f9d438ef83 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -67,13 +67,23 @@ public sealed partial class FastTreeRankingTrainer /// The name of the feature column. /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. 
/// A delegate to apply all the advanced arguments to the algorithm. - public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastTreeRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 4cc09c9243..8186ef2bb8 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -58,11 +58,11 @@ public sealed partial class FastTreeRegressionTrainer /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -72,13 +72,8 @@ public FastTreeRegressionTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 6ec0bd99d0..66079ae207 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -54,12 +54,22 @@ public sealed partial class FastTreeTweedieTrainer /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. 
+ /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. - public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + public FastTreeTweedieTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 9672bcfcdf..57c3f37eeb 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -62,13 +62,18 @@ internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// A delegate to apply all the advanced arguments to the algorithm. - public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public BinaryClassificationGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - _sigmoidParameter = 1; } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 40a0ff0c93..90fd903a4c 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -51,12 +51,18 @@ internal RegressionGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// A delegate to apply all the advanced arguments to the algorithm. 
- public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public RegressionGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } internal override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index 0a673bd0a9..e4318187d0 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -132,15 +132,26 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight protected IParallelTraining ParallelTraining; - private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, Action advancedSettings = null) + private protected GamTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Args = new TArgs(); + Args.MinDocuments = minDocumentsInLeafs; + Args.LearningRates = learningRate; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; if (weightColumn != null) Args.WeightColumn = weightColumn; diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 6043774b43..f29383473c 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -28,9 +28,18 @@ protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape. /// /// Constructor invoked by the API code-path. 
/// - protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + protected RandomForestTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings, + bool quantileEnabled = false) + : base(env, label, featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { _quantileEnabled = quantileEnabled; } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index cc4f91d895..317c1eaae0 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -139,12 +139,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. - public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastForestClassification(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 3af5440d13..23b110f072 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -160,12 +160,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. 
- public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) + public FastForestRegression(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs index 5e1b26a07e..2d6d1c25ef 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -19,7 +19,7 @@ public static class TreeExtensions /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. @@ -46,7 +46,7 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. @@ -73,20 +73,28 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The groupId column. /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// Algorithm advanced settings. public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, string label = DefaultColumnNames.Label, string groupId = DefaultColumnNames.GroupId, string features = DefaultColumnNames.Features, string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); + return new FastTreeRankingTrainer(env, label, features, groupId, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); } /// @@ -94,18 +102,26 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. 
+ /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// Algorithm advanced settings. public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new BinaryClassificationGamTrainer(env, label, features, weights, advancedSettings); + return new BinaryClassificationGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); } /// @@ -113,18 +129,26 @@ public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this Reg /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// Algorithm advanced settings. public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new RegressionGamTrainer(env, label, features, weights, advancedSettings); + return new RegressionGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); } /// @@ -132,18 +156,26 @@ public static RegressionGamTrainer GeneralizedAdditiveMethods(this BinaryClassif /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. /// Algorithm advanced settings. 
public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, Action advancedSettings = null) { Contracts.CheckValue(ctx, nameof(ctx)); var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeTweedieTrainer(env, label, features, weights, advancedSettings: advancedSettings); + return new FastTreeTweedieTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); } } } diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 9078ad75cd..6d8c7d73f5 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -22,7 +22,7 @@ public static class TreeRegressionExtensions /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. @@ -71,7 +71,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The optional weights column. /// Total number of decision trees to create in the ensemble. /// The maximum number of leaves per decision tree. @@ -117,7 +117,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The groupId column. /// The optional weights column. /// Total number of decision trees to create in the ensemble. 
@@ -145,7 +145,8 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + numTrees, minDatapointsInLeafs, learningRate, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 24aab18377..919bf2634c 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -76,8 +76,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st string weightColumn = null, Action advancedSettings = null) : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings)) { - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); } /// diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index fe770b80a4..bd9adbebbc 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -106,11 +106,14 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -118,19 +121,8 @@ public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string fe double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? 
Args.MinDataPerLeaf; } private protected override IPredictorWithFeatureWeights CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 605d35b0b1..56a892be4a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -19,13 +19,16 @@ public static class LightGbmExtensions /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmRegressorTrainer LightGbm(this RegressionContext.RegressionTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -46,13 +49,16 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionContext.Regressio /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -74,14 +80,17 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.Bi /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The groupId column. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmRankingTrainer LightGbm(this RankingContext.RankingTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -104,13 +113,16 @@ public static LightGbmRankingTrainer LightGbm(this RankingContext.RankingTrainer /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. 
+ /// The columns names, however need to be provided directly, not through the . public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index cdb41afb73..bc73e3a41f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -51,7 +51,10 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to apply all the advanced arguments to the algorithm. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -61,20 +64,8 @@ public LightGbmMulticlassTrainer(IHostEnvironment env, double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; - _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index b52236f374..fdd4b09959 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -98,7 +98,10 @@ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to apply all the advanced arguments to the algorithm. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -109,20 +112,9 @@ public LightGbmRankingTrainer(IHostEnvironment env, double? 
learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; } protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 1d94738f79..612ef15f6d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -92,11 +92,14 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBaseThe name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -104,19 +107,8 @@ public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? 
Args.MinDataPerLeaf; } internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index e527a0ed92..1bf7a6ab7d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -22,7 +22,7 @@ public static class LightGbmTrainers /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The number of leaves to use. /// Number of iterations. @@ -70,7 +70,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The weights column. /// The number of leaves to use. /// Number of iterations. @@ -115,7 +115,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// /// The . /// The label column. - /// The features colum. + /// The features column. /// The groupId column. /// The weights column. /// The number of leaves to use. @@ -167,7 +167,10 @@ public static Scalar LightGbm(this RankingContext.RankingTrainers c /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index 4f2f84f013..0d6313a20c 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Trainers.FastTree.Internal; @@ -26,7 +27,7 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. /// - public abstract class LightGbmTrainerBase : TrainerEstimatorBase + public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictorProducing { @@ -57,26 +58,37 @@ private sealed class CategoricalMetaData private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); public override TrainerInfo Info => _info; - private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) + private protected LightGbmTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int? numLeaves, + int? minDataPerLeaf, + double? 
learningRate, + int numBoostRound, + Action advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new LightGbmArguments(); + Args.NumLeaves = numLeaves; + Args.MinDataPerLeaf = minDataPerLeaf; + Args.LearningRate = learningRate; + Args.NumBoostRound = numBoostRound; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); InitParallelTraining(); } @@ -161,32 +173,6 @@ protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) ch.CheckParam(data.Schema.Label != null, nameof(data), "Need a label column"); } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. - /// - protected void CheckArgsAndAdvancedSettingMismatch(int? numLeaves, - int? minDataPerLeaf, - double? learningRate, - int numBoostRound, - LightGbmArguments snapshot, - LightGbmArguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate)); - } - } - protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false) { double learningRate = Args.LearningRate ?? 
DefaultLearningRate(numRow, hasCategarical, totalCats); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs index 564347d9c7..f94511ec76 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs @@ -21,8 +21,10 @@ public static class FactorizationMachineExtensions /// The label, or dependent variable. /// The features, or independent variables. /// The optional example weights. - /// A delegate to set more settings. - /// + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label, string[] features, string weights = null, diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index 733c98d28a..2a95df5dd7 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -28,8 +28,10 @@ public static class FactorizationMachineExtensions /// Initial learning rate. /// Number of training iterations. /// Latent space dimensions. - /// A delegate to set more settings. - /// A delegate that is called every time the + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the ./// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to @@ -57,10 +59,11 @@ public static (Scalar score, Scalar predictedLabel) FieldAwareFacto var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings: args => { - advancedSettings?.Invoke(args); args.LearningRate = learningRate; args.Iters = numIterations; args.LatentDim = numLatentDimensions; + + advancedSettings?.Invoke(args); }); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 94f00e5505..e21ea3a316 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1433,9 +1433,12 @@ internal override void Check(IHostEnvironment env) /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. 
+ /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LinearClassificationTrainer(IHostEnvironment env, - string featureColumn, + string featureColumn, string labelColumn, string weightColumn = null, ISupportSdcaClassificationLoss loss = null, @@ -1682,21 +1685,17 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); _args = new Arguments(); - advancedSettings?.Invoke(_args); - - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, _args); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(maxIterations, initLearningRate, l2Weight, loss, new Arguments(), _args); + _args.MaxIterations = maxIterations; + _args.InitLearningRate = initLearningRate; + _args.L2Weight = l2Weight; // Apply the advanced args, if the user supplied any. + advancedSettings?.Invoke(_args); + _args.FeatureColumn = featureColumn; _args.LabelColumn = labelColumn; _args.WeightColumn = weightColumn; - _args.MaxIterations = maxIterations; - _args.InitLearningRate = initLearningRate; - _args.L2Weight = l2Weight; + if (loss != null) _args.LossFunction = loss; _args.Check(env); @@ -1719,30 +1718,6 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Ar _args = args; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// - internal void CheckArgsAndAdvancedSettingMismatch(int maxIterations, - double initLearningRate, - float l2Weight, - ISupportClassificationLossFactory loss, - Arguments snapshot, - Arguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - // Check that the user didn't supply different parameters in the args, from what it specified directly. 
- TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, maxIterations, snapshot.MaxIterations, currentArgs.MaxIterations, nameof(maxIterations)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, initLearningRate, snapshot.InitLearningRate, currentArgs.InitLearningRate, nameof(initLearningRate)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, l2Weight, snapshot.L2Weight, currentArgs.L2Weight, nameof(l2Weight)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, loss, snapshot.LossFunction, currentArgs.LossFunction, nameof(loss)); - } - } - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { return new[] diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 2733a55067..a29da7c472 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -151,31 +151,58 @@ internal static class Defaults private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true); public override TrainerInfo Info => _info; - internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings, float l1Weight, + internal LbfgsTrainerBase(IHostEnvironment env, + string featureColumn, + SchemaShape.Column labelColumn, + string weightColumn, + Action advancedSettings, + float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity) - : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn, - l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) + : this(env, new TArgs(), labelColumn, + l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity, advancedSettings) { } - internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn, + internal LbfgsTrainerBase(IHostEnvironment env, + TArgs args, + SchemaShape.Column labelColumn, float? l1Weight = null, float? l2Weight = null, float? optimizationTolerance = null, int? memorySize = null, - bool? enforceNoNegativity = null) + bool? enforceNoNegativity = null, + Action advancedSettings = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; + L2Weight = l2Weight ?? Args.L2Weight; + L1Weight = l1Weight ?? Args.L1Weight; + OptTol = optimizationTolerance ?? Args.OptTol; + MemorySize = memorySize ?? Args.MemorySize; + MaxIterations = Args.MaxIterations; + SgdInitializationTolerance = Args.SgdInitializationTolerance; + Quiet = Args.Quiet; + InitWtsDiameter = Args.InitWtsDiameter; + UseThreads = Args.UseThreads; + NumThreads = Args.NumThreads; + DenseOptimizer = Args.DenseOptimizer; + EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity; + + // Apply the advanced args, if the user supplied any. 
+ advancedSettings?.Invoke(args); + + args.FeatureColumn = FeatureColumn.Name; + args.LabelColumn = LabelColumn.Name; + args.WeightColumn = WeightColumn.Name; + Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, - nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); + nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); @@ -189,20 +216,6 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided."); Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided."); - // Review: Warn about the overriding behavior - L2Weight = l2Weight ?? Args.L2Weight; - L1Weight = l1Weight ?? Args.L1Weight; - OptTol = optimizationTolerance ?? Args.OptTol; - MemorySize = memorySize ?? Args.MemorySize; - MaxIterations = Args.MaxIterations; - SgdInitializationTolerance = Args.SgdInitializationTolerance; - Quiet = Args.Quiet; - InitWtsDiameter = Args.InitWtsDiameter; - UseThreads = Args.UseThreads; - NumThreads = Args.NumThreads; - DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity; - if (EnforceNonNegativity && ShowTrainingStats) { ShowTrainingStats = false; @@ -216,19 +229,6 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l _srcPredictor = default; } - private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings) - { - var args = new TArgs(); - - // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(args); - args.FeatureColumn = featureColumn; - args.LabelColumn = labelColumn.Name; - args.WeightColumn = weightColumn; - return args; - } - protected virtual int ClassCount => 1; protected int BiasCount => ClassCount; protected int WeightCount => ClassCount * NumFeatures; diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs index 35aff21a43..6b2cbd9aa2 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs @@ -25,9 +25,12 @@ public static class SdcaRegressionExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, - string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, ISupportSdcaRegressionLoss loss = null, float? l2Const = null, float? 
l1Threshold = null, @@ -53,7 +56,10 @@ public static class SdcaBinaryClassificationExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// /// /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 2eea592187..ddf239d638 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -57,7 +57,10 @@ public sealed class Arguments : ArgumentsBase /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public SdcaMultiClassTrainer(IHostEnvironment env, string featureColumn, string labelColumn, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 3e92670cb2..ae20a5ed45 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -62,7 +62,10 @@ public Arguments() /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . 
public SdcaRegressionTrainer(IHostEnvironment env, string featureColumn, string labelColumn, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index 2315ac78dd..009e6013c4 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -27,7 +27,10 @@ public static class SdcaExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -81,7 +84,10 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -141,7 +147,10 @@ public static (Scalar score, Scalar probability, Scalar pred /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -206,7 +215,10 @@ public static (Scalar score, Scalar predictedLabel) Sdca( /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. 
This delegate will receive From e03ea66dd2fc2b9eead151b922aad5f69f6d3338 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 26 Oct 2018 14:59:38 -0700 Subject: [PATCH 5/8] null check --- .../Standard/LogisticRegression/LbfgsPredictorBase.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index a29da7c472..e46f73c0ef 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -199,7 +199,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, args.FeatureColumn = FeatureColumn.Name; args.LabelColumn = LabelColumn.Name; - args.WeightColumn = WeightColumn.Name; + args.WeightColumn = WeightColumn?.Name; Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); From cea3b6710c78e772fb4a1be8910713e04741494f Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 26 Oct 2018 16:25:11 -0700 Subject: [PATCH 6/8] fixing passing down the correct column names in lbfgs predictor base. --- .../LogisticRegression/LbfgsPredictorBase.cs | 80 +++++++++++++------ 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index e46f73c0ef..048dc11d34 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -161,19 +161,24 @@ internal LbfgsTrainerBase(IHostEnvironment env, float optimizationTolerance, int memorySize, bool enforceNoNegativity) - : this(env, new TArgs(), labelColumn, - l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity, advancedSettings) + : this(env, new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }, + labelColumn, advancedSettings) { } internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn, - float? l1Weight = null, - float? l2Weight = null, - float? optimizationTolerance = null, - int? memorySize = null, - bool? enforceNoNegativity = null, Action advancedSettings = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) @@ -181,28 +186,14 @@ internal LbfgsTrainerBase(IHostEnvironment env, Host.CheckValue(args, nameof(args)); Args = args; - L2Weight = l2Weight ?? Args.L2Weight; - L1Weight = l1Weight ?? Args.L1Weight; - OptTol = optimizationTolerance ?? Args.OptTol; - MemorySize = memorySize ?? 
Args.MemorySize; - MaxIterations = Args.MaxIterations; - SgdInitializationTolerance = Args.SgdInitializationTolerance; - Quiet = Args.Quiet; - InitWtsDiameter = Args.InitWtsDiameter; - UseThreads = Args.UseThreads; - NumThreads = Args.NumThreads; - DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = enforceNoNegativity ?? Args.EnforceNonNegativity; - // Apply the advanced args, if the user supplied any. advancedSettings?.Invoke(args); args.FeatureColumn = FeatureColumn.Name; args.LabelColumn = LabelColumn.Name; args.WeightColumn = WeightColumn?.Name; - Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, - nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); + nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); @@ -211,10 +202,23 @@ internal LbfgsTrainerBase(IHostEnvironment env, Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); - Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided."); - Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided"); - Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided."); - Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided."); + Host.CheckParam(!(Args.L2Weight < 0), nameof(Args.L2Weight), "Must be non-negative, if provided."); + Host.CheckParam(!(Args.L1Weight < 0), nameof(Args.L1Weight), "Must be non-negative, if provided"); + Host.CheckParam(!(Args.OptTol <= 0), nameof(Args.OptTol), "Must be positive, if provided."); + Host.CheckParam(!(Args.MemorySize <= 0), nameof(Args.MemorySize), "Must be positive, if provided."); + + L2Weight = Args.L2Weight; + L1Weight = Args.L1Weight; + OptTol = Args.OptTol; + MemorySize =Args.MemorySize; + MaxIterations = Args.MaxIterations; + SgdInitializationTolerance = Args.SgdInitializationTolerance; + Quiet = Args.Quiet; + InitWtsDiameter = Args.InitWtsDiameter; + UseThreads = Args.UseThreads; + NumThreads = Args.NumThreads; + DenseOptimizer = Args.DenseOptimizer; + EnforceNonNegativity = Args.EnforceNonNegativity; if (EnforceNonNegativity && ShowTrainingStats) { @@ -229,6 +233,30 @@ internal LbfgsTrainerBase(IHostEnvironment env, _srcPredictor = default; } + private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, + string weightColumn, + float l1Weight, + float l2Weight, + float optimizationTolerance, + int memorySize, + bool enforceNoNegativity) + { + var args = new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? 
Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }; + + args.WeightColumn = weightColumn; + return args; + } + protected virtual int ClassCount => 1; protected int BiasCount => ClassCount; protected int WeightCount => ClassCount * NumFeatures; From 577c9c407d2540f68ba522d1c00eafc1e25c52f2 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 26 Oct 2018 21:53:48 -0700 Subject: [PATCH 7/8] formatting and comments --- src/Microsoft.ML.FastTree/BoostingFastTree.cs | 2 +- src/Microsoft.ML.FastTree/FastTreeRanking.cs | 2 +- src/Microsoft.ML.FastTree/GamRegression.cs | 6 +++--- .../RandomForestClassification.cs | 10 +++++----- src/Microsoft.ML.FastTree/TreeTrainersStatic.cs | 1 + src/Microsoft.ML.LightGBM/LightGbmStatic.cs | 8 ++++---- .../Standard/LinearClassificationTrainer.cs | 2 +- 7 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 05b8958d8c..99658d2a89 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -37,7 +37,7 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, if (Args.LearningRates != learningRate) { using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments.")) - Args.LearningRates = learningRate; + Args.LearningRates = learningRate; } } diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index f9d438ef83..b7baf5999e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -91,7 +91,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, /// Initializes a new instance of by using the legacy class. /// internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 90fd903a4c..481a991a8e 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -6,11 +6,11 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; [assembly: LoadableClass(RegressionGamTrainer.Summary, @@ -51,8 +51,8 @@ internal RegressionGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. 
public RegressionGamTrainer(IHostEnvironment env, string labelColumn, diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 317c1eaae0..bfa5efa460 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -7,12 +7,12 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; using System.Linq; @@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo() public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) - { } + { } private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) @@ -140,10 +140,10 @@ public sealed class Arguments : FastForestArgumentsBase /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// The learning rate. - /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. public FastForestClassification(IHostEnvironment env, string labelColumn, diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 6d8c7d73f5..4ef3dddf28 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -126,6 +126,7 @@ public static (Scalar score, Scalar probability, Scalar pred /// The learning rate. /// Algorithm advanced settings. /// A delegate that is called every time the + /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the linear model that was trained. Note that this action cannot change the result in any way; diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index 1bf7a6ab7d..445a1ad6ad 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.StaticPipe public static class LightGbmTrainers { /// - /// Predict a target using a tree regression model trained with the LightGbm trainer. + /// Predict a target using a tree regression model trained with the . /// /// The . /// The label column. @@ -66,7 +66,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c } /// - /// Predict a target using a tree binary classification model trained with the LightGbm trainer. + /// Predict a target using a tree binary classification model trained with the . /// /// The . /// The label column. 
@@ -111,7 +111,7 @@ public static (Scalar score, Scalar probability, Scalar pred } /// - /// LightGbm extension method. + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . /// /// The . /// The label column. @@ -157,7 +157,7 @@ public static Scalar LightGbm(this RankingContext.RankingTrainers c } /// - /// Predict a target using a tree multiclass classification model trained with the LightGbm trainer. + /// Predict a target using a tree multiclass classification model trained with the . /// /// The multiclass classification context trainer object. /// The label, or dependent variable. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 79c486b76c..3c3e0ea13b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1438,7 +1438,7 @@ internal override void Check(IHostEnvironment env) /// if both are present and have different values. /// The columns names, however need to be provided directly, not through the . public LinearClassificationTrainer(IHostEnvironment env, - string featureColumn, + string featureColumn, string labelColumn, string weightColumn = null, ISupportSdcaClassificationLoss loss = null, From 2c6d44613ba09551dd382aad4251b1e74f9f7437 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Fri, 26 Oct 2018 22:15:40 -0700 Subject: [PATCH 8/8] fix XML --- src/Microsoft.ML.FastTree/TreeTrainersStatic.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 4ef3dddf28..f3008957f4 100644 --- a/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -126,13 +126,12 @@ public static (Scalar score, Scalar probability, Scalar pred /// The learning rate. /// Algorithm advanced settings. /// A delegate that is called every time the - /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the linear model that was trained. Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. - public static Scalar FastTree(this RankingContext.RankingTrainers ctx, + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees,