diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index ad074fd737..ad2916368d 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -50,7 +50,10 @@ public abstract class TrainerEstimatorBase : ITrainerEstim public abstract PredictionKind PredictionKind { get; } - public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) + public TrainerEstimatorBase(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null) { Contracts.CheckValue(host, nameof(host)); Host = host; @@ -149,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet, protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema); - private RoleMappedData MakeRoles(IDataView data) => + protected virtual RoleMappedData MakeRoles(IDataView data) => new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name); IPredictor ITrainer.Train(TrainContext context) => Train(context); } + + /// + /// This represents a basic class for a 'simple trainer' that also takes a group ID column. + /// Such a trainer accepts one feature column and one label column, and optionally a weight column and a group ID column, as the ranking trainers do. + /// It produces a 'prediction transformer'. + /// + public abstract class TrainerEstimatorBaseWithGroupId : TrainerEstimatorBase + where TTransformer : ISingleFeaturePredictionTransformer + where TModel : IPredictor + { + /// + /// The optional groupID column that the ranking trainers expect. + /// + public readonly SchemaShape.Column GroupIdColumn; + + public TrainerEstimatorBaseWithGroupId(IHost host, + SchemaShape.Column feature, + SchemaShape.Column label, + SchemaShape.Column weight = null, + SchemaShape.Column groupId = null) + :base(host, feature, label, weight) + { + Host.CheckValueOrNull(groupId); + GroupIdColumn = groupId; + } + + protected override RoleMappedData MakeRoles(IDataView data) => + new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name); + + } } diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index dff5748635..e966687a9e 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn) /// /// The for the label column for regression tasks. /// - /// name of the weight column - public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn) - => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + /// name of the column + public static SchemaShape.Column MakeU4ScalarColumn(string columnName) + { + if (columnName == null) + return null; + + return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true); + } /// /// The for the feature column. @@ -377,69 +382,13 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn) /// The for the weight column.
/// /// name of the weight column - public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn) + /// whether the column is implicitly, or explicitly defined + public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true) { - if (weightColumn == null) + if (weightColumn == null || !isExplicit) return null; return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); } - - private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue) - { - if (argValue != defaultColName) - throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead."); - } - - /// - /// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - - if (args.GroupIdColumn != null) - CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn); - } - - /// - /// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn); - } - - /// - /// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter, - /// for cases when the public constructor is called. - /// The recommendation is to set the column names directly. - /// - public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args) - { - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn); - CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn); - } - - /// - /// If, after applying the advancedArgs delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user. 
- /// - public static void CheckArgsAndAdvancedSettingMismatch(IChannel channel, T methodParam, T defaultVal, T setting, string argName) - { - if (!setting.Equals(defaultVal) && !setting.Equals(methodParam)) - channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}"); - } } /// diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index 7a68902038..99658d2a89 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -21,10 +21,24 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh { } - protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + protected BoostingFastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) + : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { + + if (Args.LearningRates != learningRate) + { + using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments.")) + Args.LearningRates = learningRate; + } } protected override void CheckArgs(IChannel ch) diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 0aa360cd62..23bed0acef 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.Conversion; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Internal.Utilities; @@ -45,7 +46,7 @@ internal static class FastTreeShared } public abstract class FastTreeTrainerBase : - TrainerEstimatorBase + TrainerEstimatorBaseWithGroupId where TTransformer: ISingleFeaturePredictionTransformer where TArgs : TreeArgs, new() where TModel : IPredictorProducing @@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase : /// /// Constructor to use when instantiating the classes deriving from here through the API. 
/// - private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + private protected FastTreeTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new TArgs(); + // override with the directly provided values. + Args.NumLeaves = numLeaves; + Args.NumTrees = numTrees; + Args.MinDocumentsInLeafs = minDocumentsInLeafs; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); // The discretization step renders this trainer non-parametric, and therefore it does not need normalization. // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching. @@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l /// Legacy constructor that is used when invoking the classes deriving from this, through maml. /// private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; @@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel() return Float.PositiveInfinity; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
- /// - protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves, - int numTrees, - int minDocumentsInLeafs, - double learningRate, - BoostedTreeArgs snapshot, - BoostedTreeArgs currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate)); - } - } - private void Initialize(IHostEnvironment env) { int numThreads = Args.NumThreads ?? Environment.ProcessorCount; diff --git a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs b/src/Microsoft.ML.FastTree/FastTreeCatalog.cs deleted file mode 100644 index 103a5676f1..0000000000 --- a/src/Microsoft.ML.FastTree/FastTreeCatalog.cs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using System; - -namespace Microsoft.ML -{ - /// - /// FastTree extension methods. - /// - public static class FastTreeRegressionExtensions - { - /// - /// Predict a target using a decision tree regression model trained with the . - /// - /// The . - /// The label column. - /// The features column. - /// The optional weights column. - /// Total number of decision trees to create in the ensemble. - /// The maximum number of leaves per decision tree. - /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. - /// The learning rate. - /// Algorithm advanced settings. - public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, - string weights = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); - } - } - - public static class FastTreeBinaryClassificationExtensions - { - - /// - /// Predict a target using a decision tree binary classification model trained with the . - /// - /// The . - /// The label column. - /// The features column. - /// The optional weights column. - /// Total number of decision trees to create in the ensemble. - /// The maximum number of leaves per decision tree. - /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. - /// The learning rate. - /// Algorithm advanced settings. 
- public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, - string label = DefaultColumnNames.Label, - string features = DefaultColumnNames.Features, - string weights = null, - int numLeaves = Defaults.NumLeaves, - int numTrees = Defaults.NumTrees, - int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, - double learningRate = Defaults.LearningRates, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); - } - } - - public static class FastTreeRankingExtensions - { - - /// - /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . - /// - /// The . - /// The label column. - /// The features column. - /// The groupId column. - /// The optional weights column. - /// Algorithm advanced settings. - public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, - string label = DefaultColumnNames.Label, - string groupId = DefaultColumnNames.GroupId, - string features = DefaultColumnNames.Features, - string weights = null, - Action advancedSettings = null) - { - Contracts.CheckValue(ctx, nameof(ctx)); - var env = CatalogUtils.GetEnvironment(ctx); - return new FastTreeRankingTrainer(env, label, features, groupId, weights, advancedSettings); - } - } -} diff --git a/src/Microsoft.ML.FastTree/FastTreeClassification.cs b/src/Microsoft.ML.FastTree/FastTreeClassification.cs index 7f7080281d..b864b484ca 100644 --- a/src/Microsoft.ML.FastTree/FastTreeClassification.cs +++ b/src/Microsoft.ML.FastTree/FastTreeClassification.cs @@ -122,11 +122,11 @@ public sealed partial class FastTreeBinaryClassificationTrainer : /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -136,22 +136,10 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss _sigmoidParameter = 2.0 * Args.LearningRates; - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); - - //override with the directly provided values. 
- Args.NumLeaves = numLeaves; - Args.NumTrees = numTrees; - Args.MinDocumentsInLeafs = minDocumentsInLeafs; - Args.LearningRates = learningRate; } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeRanking.cs b/src/Microsoft.ML.FastTree/FastTreeRanking.cs index d96f82f741..b7baf5999e 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRanking.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRanking.cs @@ -67,13 +67,23 @@ public sealed partial class FastTreeRankingTrainer /// The name of the feature column. /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. - public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string groupIdColumn, - string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastTreeRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } @@ -81,7 +91,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string f /// Initializes a new instance of by using the legacy class. /// internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeR4ScalarLabel(args.LabelColumn)) { } diff --git a/src/Microsoft.ML.FastTree/FastTreeRegression.cs b/src/Microsoft.ML.FastTree/FastTreeRegression.cs index 4cc09c9243..8186ef2bb8 100644 --- a/src/Microsoft.ML.FastTree/FastTreeRegression.cs +++ b/src/Microsoft.ML.FastTree/FastTreeRegression.cs @@ -58,11 +58,11 @@ public sealed partial class FastTreeRegressionTrainer /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The learning rate. /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// The max number of leaves in each regression tree. /// Total number of decision trees to create in the ensemble. + /// A delegate to apply all the advanced arguments to the algorithm. 
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn, @@ -72,13 +72,8 @@ public FastTreeRegressionTrainer(IHostEnvironment env, int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, double learningRate = Defaults.LearningRates, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args); } /// diff --git a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs index 479d83399f..66079ae207 100644 --- a/src/Microsoft.ML.FastTree/FastTreeTweedie.cs +++ b/src/Microsoft.ML.FastTree/FastTreeTweedie.cs @@ -35,10 +35,10 @@ namespace Microsoft.ML.Trainers.FastTree public sealed partial class FastTreeTweedieTrainer : BoostingFastTreeTrainerBase, FastTreeTweediePredictor> { - public const string LoadNameValue = "FastTreeTweedieRegression"; - public const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; - public const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression."; - public const string ShortName = "fttweedie"; + internal const string LoadNameValue = "FastTreeTweedieRegression"; + internal const string UserNameValue = "FastTree (Boosted Trees) Tweedie Regression"; + internal const string Summary = "Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression."; + internal const string ShortName = "fttweedie"; private TestHistory _firstTestSetHistory; private Test _trainRegressionTest; @@ -54,12 +54,22 @@ public sealed partial class FastTreeTweedieTrainer /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. 
- public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + public FastTreeTweedieTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/GamClassification.cs b/src/Microsoft.ML.FastTree/GamClassification.cs index 9672bcfcdf..57c3f37eeb 100644 --- a/src/Microsoft.ML.FastTree/GamClassification.cs +++ b/src/Microsoft.ML.FastTree/GamClassification.cs @@ -62,13 +62,18 @@ internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. /// A delegate to apply all the advanced arguments to the algorithm. - public BinaryClassificationGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public BinaryClassificationGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - _sigmoidParameter = 1; } diff --git a/src/Microsoft.ML.FastTree/GamRegression.cs b/src/Microsoft.ML.FastTree/GamRegression.cs index 40a0ff0c93..481a991a8e 100644 --- a/src/Microsoft.ML.FastTree/GamRegression.cs +++ b/src/Microsoft.ML.FastTree/GamRegression.cs @@ -6,11 +6,11 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; [assembly: LoadableClass(RegressionGamTrainer.Summary, @@ -51,12 +51,18 @@ internal RegressionGamTrainer(IHostEnvironment env, Arguments args) /// The name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. 
/// A delegate to apply all the advanced arguments to the algorithm. - public RegressionGamTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, advancedSettings) + public RegressionGamTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, minDocumentsInLeafs, learningRate, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); } internal override void CheckLabel(RoleMappedData data) diff --git a/src/Microsoft.ML.FastTree/GamTrainer.cs b/src/Microsoft.ML.FastTree/GamTrainer.cs index e44196b04c..e4318187d0 100644 --- a/src/Microsoft.ML.FastTree/GamTrainer.cs +++ b/src/Microsoft.ML.FastTree/GamTrainer.cs @@ -132,15 +132,26 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight protected IParallelTraining ParallelTraining; - private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, Action advancedSettings = null) + private protected GamTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Args = new TArgs(); + Args.MinDocuments = minDocumentsInLeafs; + Args.LearningRates = learningRate; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); + Args.LabelColumn = label.Name; + Args.FeatureColumn = featureColumn; if (weightColumn != null) Args.WeightColumn = weightColumn; @@ -154,7 +165,7 @@ private protected GamTrainerBase(IHostEnvironment env, string name, SchemaShape. private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label) : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Contracts.CheckValue(env, nameof(env)); Host.CheckValue(args, nameof(args)); diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 6043774b43..f29383473c 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -28,9 +28,18 @@ protected RandomForestTrainerBase(IHostEnvironment env, TArgs args, SchemaShape. /// /// Constructor invoked by the API code-path. 
/// - protected RandomForestTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, bool quantileEnabled = false, Action advancedSettings = null) - : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings) + protected RandomForestTrainerBase(IHostEnvironment env, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int numLeaves, + int numTrees, + int minDocumentsInLeafs, + double learningRate, + Action advancedSettings, + bool quantileEnabled = false) + : base(env, label, featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings) { _quantileEnabled = quantileEnabled; } diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index cc4f91d895..bfa5efa460 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -7,12 +7,12 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Trainers.FastTree; -using Microsoft.ML.Trainers.FastTree.Internal; using Microsoft.ML.Runtime.Internal.Calibration; using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.FastTree.Internal; using System; using System.Linq; @@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo() public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs) : base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs) - { } + { } private FastForestClassificationPredictor(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx, GetVersionInfo()) @@ -139,12 +139,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. /// A delegate to apply all the advanced arguments to the algorithm. 
- public FastForestClassification(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings: advancedSettings) + public FastForestClassification(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/RandomForestRegression.cs b/src/Microsoft.ML.FastTree/RandomForestRegression.cs index 3af5440d13..23b110f072 100644 --- a/src/Microsoft.ML.FastTree/RandomForestRegression.cs +++ b/src/Microsoft.ML.FastTree/RandomForestRegression.cs @@ -160,12 +160,22 @@ public sealed class Arguments : FastForestArgumentsBase /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. + /// The learning rate. + /// The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data. + /// The max number of leaves in each regression tree. + /// Total number of decision trees to create in the ensemble. /// A delegate to apply all the advanced arguments to the algorithm. - public FastForestRegression(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, true, advancedSettings) + public FastForestRegression(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings) { Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs new file mode 100644 index 0000000000..2d6d1c25ef --- /dev/null +++ b/src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs @@ -0,0 +1,181 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Trainers.FastTree; +using System; + +namespace Microsoft.ML +{ + /// + /// FastTree extension methods. 
+ /// + public static class TreeExtensions + { + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of a regression tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeRegressionTrainer FastTree(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRegressionTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree binary classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeBinaryClassificationTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. 
+ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string groupId = DefaultColumnNames.GroupId, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeRankingTrainer(env, label, features, groupId, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a generalized additive model for regression trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static RegressionGamTrainer GeneralizedAdditiveMethods(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new RegressionGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a generalized additive model for binary classification trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + public static BinaryClassificationGamTrainer GeneralizedAdditiveMethods(this BinaryClassificationContext.BinaryClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new BinaryClassificationGamTrainer(env, label, features, weights, minDatapointsInLeafs, learningRate, advancedSettings); + } + + /// + /// Predict a target using a decision tree regression model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The optional weights column. + /// Total number of decision trees to create in the ensemble. + /// The maximum number of leaves per decision tree. + /// The minimal number of datapoints allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings.
+ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.RegressionTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int numLeaves = Defaults.NumLeaves, + int numTrees = Defaults.NumTrees, + int minDatapointsInLeafs = Defaults.MinDocumentsInLeafs, + double learningRate = Defaults.LearningRates, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new FastTreeTweedieTrainer(env, label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings); + } + } +} diff --git a/src/Microsoft.ML.FastTree/FastTreeStatic.cs b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs similarity index 90% rename from src/Microsoft.ML.FastTree/FastTreeStatic.cs rename to src/Microsoft.ML.FastTree/TreeTrainersStatic.cs index 1aa8528b3e..f3008957f4 100644 --- a/src/Microsoft.ML.FastTree/FastTreeStatic.cs +++ b/src/Microsoft.ML.FastTree/TreeTrainersStatic.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.StaticPipe /// /// FastTree extension methods. /// - public static class FastTreeRegressionExtensions + public static class TreeRegressionExtensions { /// /// FastTree extension method. @@ -50,7 +50,7 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -64,10 +64,6 @@ public static Scalar FastTree(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - public static class FastTreeBinaryClassificationExtensions - { /// /// FastTree extension method. @@ -98,7 +94,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,10 +110,6 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - - public static class FastTreeRankingExtensions - { /// /// FastTree . @@ -139,7 +131,7 @@ public static class FastTreeRankingExtensions /// the linear model that was trained. Note that this action cannot change the result in any way; /// it is only a way for the caller to be informed about what was learnt. /// The Score output column indicating the predicted value. 
- public static Scalar FastTree(this RankingContext.RankingTrainers ctx, + public static Scalar FastTree(this RankingContext.RankingTrainers ctx, Scalar label, Vector features, Key groupId, Scalar weights = null, int numLeaves = Defaults.NumLeaves, int numTrees = Defaults.NumTrees, @@ -148,12 +140,13 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c Action advancedSettings = null, Action onFit = null) { - FastTreeStaticsUtils.CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, numTrees, minDatapointsInLeafs, learningRate, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Ranker( (env, labelName, featuresName, groupIdName, weightsName) => { - var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, advancedSettings); + var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + numTrees, minDatapointsInLeafs, learningRate, advancedSettings); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); return trainer; @@ -161,17 +154,14 @@ public static Scalar FastTree(this RankingContext.RankingTrainers c return rec.Score; } - } - internal class FastTreeStaticsUtils - { internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, - int numLeaves, - int numTrees, - int minDatapointsInLeafs, - double learningRate, - Delegate advancedSettings, - Delegate onFit) + int numLeaves, + int numTrees, + int minDatapointsInLeafs, + double learningRate, + Delegate advancedSettings, + Delegate onFit) { Contracts.CheckValue(label, nameof(label)); Contracts.CheckValue(features, nameof(features)); diff --git a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs index 502b6f8e61..cba842a635 100644 --- a/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs +++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs @@ -76,8 +76,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st string weightColumn = null, Action advancedSettings = null) : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings)) { - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); } /// @@ -85,7 +83,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, st /// internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative"); diff --git a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs index fe770b80a4..bd9adbebbc 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs @@ -106,11 +106,14 @@ internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args) /// The name of the label column. 
/// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -118,19 +121,8 @@ public LightGbmBinaryTrainer(IHostEnvironment env, string labelColumn, string fe double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? Args.MinDataPerLeaf; } private protected override IPredictorWithFeatureWeights CreatePredictor() diff --git a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs index 31ae1a0f53..56a892be4a 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmCatalog.cs @@ -12,7 +12,7 @@ namespace Microsoft.ML /// /// Regression trainer estimators. /// - public static class LightGbmRegressionExtensions + public static class LightGbmExtensions { /// /// Predict a target using a decision tree regression model trained with the . @@ -25,7 +25,10 @@ public static class LightGbmRegressionExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static LightGbmRegressorTrainer LightGbm(this RegressionContext.RegressionTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -40,13 +43,6 @@ public static LightGbmRegressorTrainer LightGbm(this RegressionContext.Regressio var env = CatalogUtils.GetEnvironment(ctx); return new LightGbmRegressorTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } - } - - /// - /// Binary Classification trainer estimators. 
- /// - public static class LightGbmClassificationExtensions - { /// /// Predict a target using a decision tree binary classification model trained with the . @@ -59,7 +55,10 @@ public static class LightGbmClassificationExtensions /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. - /// Algorithm advanced settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The column names, however, need to be provided directly, not through the . public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, @@ -75,5 +74,69 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationContext.Bi return new LightGbmBinaryTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); } + + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// The groupId column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The column names, however, need to be provided directly, not through the . + public static LightGbmRankingTrainer LightGbm(this RankingContext.RankingTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string groupId = DefaultColumnNames.GroupId, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmRankingTrainer(env, label, features, groupId, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } + + /// + /// Predict a target using a decision tree multiclass classification model trained with the . + /// + /// The . + /// The label column. + /// The features column. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The column names, however, need to be provided directly, not through the . + public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + string label = DefaultColumnNames.Label, + string features = DefaultColumnNames.Features, + string weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double?
learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + { + Contracts.CheckValue(ctx, nameof(ctx)); + var env = CatalogUtils.GetEnvironment(ctx); + return new LightGbmMulticlassTrainer(env, label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + } } } diff --git a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs index 579c612579..bc73e3a41f 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmMulticlassTrainer.cs @@ -46,15 +46,26 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. - public LightGbmMulticlassTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn = null, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public LightGbmMulticlassTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeU4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); _numClass = -1; } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs index dfd1652394..fdd4b09959 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRankingTrainer.cs @@ -92,15 +92,28 @@ internal LightGbmRankingTrainer(IHostEnvironment env, LightGbmArguments args) /// The private instance of . /// The name of the label column. /// The name of the feature column. - /// The name for the column containing the group ID. - /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. - public LightGbmRankingTrainer(IHostEnvironment env, string labelColumn, string featureColumn, - string groupIdColumn, string weightColumn = null, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings) + /// The name of the column containing the group ID. + /// The name of the column containing the initial weight. 
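Stepping back to the two catalog overloads added above, a sketch of how they would be invoked; ctx, trainData and the column names ("Label", "Features", "GroupId") are assumptions rather than something taken from the patch, and the usings are the same as in the earlier sketch.

// Ranking: the groupId column ties rows into query groups.
public static void TrainRanker(RankingContext ctx, IDataView trainData)
{
    var ranker = ctx.Trainers.LightGbm(
        label: "Label",
        features: "Features",
        groupId: "GroupId",
        numBoostRound: 100);        // number of boosting iterations

    var model = ranker.Fit(trainData);
}

// Multiclass: same shape, no group id; the label is built as a U4 key column
// (MakeU4ScalarColumn in the trainer constructor above).
public static void TrainMulticlass(MulticlassClassificationContext ctx, IDataView trainData)
{
    var trainer = ctx.Trainers.LightGbm(label: "Label", features: "Features", numLeaves: 20);
    var model = trainer.Fit(trainData);
}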
+ /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + public LightGbmRankingTrainer(IHostEnvironment env, + string labelColumn, + string featureColumn, + string groupIdColumn, + string weightColumn = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn)); } diff --git a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs index 1d94738f79..612ef15f6d 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmRegressionTrainer.cs @@ -92,11 +92,14 @@ public sealed class LightGbmRegressorTrainer : LightGbmTrainerBaseThe name of the label column. /// The name of the feature column. /// The name for the column containing the initial weight. - /// A delegate to apply all the advanced arguments to the algorithm. /// The number of leaves to use. /// Number of iterations. /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string featureColumn, string weightColumn = null, int? numLeaves = null, @@ -104,19 +107,8 @@ public LightGbmRegressorTrainer(IHostEnvironment env, string labelColumn, string double? learningRate = null, int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, Action advancedSettings = null) - : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings) + : base(env, LoadNameValue, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings) { - Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); - Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(numLeaves, minDataPerLeaf, learningRate, numBoostRound, new LightGbmArguments(), Args); - - // override with the directly provided values - Args.NumBoostRound = numBoostRound; - Args.NumLeaves = numLeaves ?? Args.NumLeaves; - Args.LearningRate = learningRate ?? Args.LearningRate; - Args.MinDataPerLeaf = minDataPerLeaf ?? 
Args.MinDataPerLeaf; } internal LightGbmRegressorTrainer(IHostEnvironment env, LightGbmArguments args) diff --git a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs index 361816389d..445a1ad6ad 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmStatic.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmStatic.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.StaticPipe.Runtime; using System; @@ -14,10 +15,10 @@ namespace Microsoft.ML.StaticPipe /// /// Regression trainer estimators. /// - public static partial class RegressionTrainers + public static class LightGbmTrainers { /// - /// LightGbm extension method. + /// Predict a target using a tree regression model trained with the . /// /// The . /// The label column. @@ -49,7 +50,7 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c Action advancedSettings = null, Action onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.Regression( (env, labelName, featuresName, weightsName) => @@ -63,15 +64,9 @@ public static Scalar LightGbm(this RegressionContext.RegressionTrainers c return rec.Score; } - } - - /// - /// Binary Classification trainer estimators. - /// - public static partial class ClassificationTrainers { /// - /// LightGbm extension method. + /// Predict a target using a tree binary classification model trained with the . /// /// The . /// The label column. @@ -98,7 +93,7 @@ public static (Scalar score, Scalar probability, Scalar pred Action advancedSettings = null, Action> onFit = null) { - LightGbmStaticsUtils.CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); var rec = new TrainerEstimatorReconciler.BinaryClassifier( (env, labelName, featuresName, weightsName) => @@ -114,11 +109,103 @@ public static (Scalar score, Scalar probability, Scalar pred return rec.Output; } - } - internal static class LightGbmStaticsUtils { + /// + /// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the . + /// + /// The . + /// The label column. + /// The features column. + /// The groupId column. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// Algorithm advanced settings. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the model that was trained. Note that this action cannot change the result in any way; + /// it is only a way for the caller to be informed about what was learnt. + /// The score column, containing the predicted relevance of each input row within its group.
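Before the static-pipeline ranking overload whose documentation ends above, a quick sketch of its dynamic counterpart, the LightGbmRankingTrainer constructor reworked earlier in this patch. env is assumed to be an existing IHostEnvironment, the column names are placeholders, and the numeric values are arbitrary.

var rankingTrainer = new LightGbmRankingTrainer(env,
    labelColumn: "Label",
    featureColumn: "Features",
    groupIdColumn: "GroupId",      // required; the constructor checks it is non-empty
    numLeaves: 30,
    minDataPerLeaf: 10,
    advancedSettings: s =>
    {
        // Same precedence as the other trainers: values set here win over
        // the direct numLeaves/minDataPerLeaf arguments above.
        s.NumLeaves = 50;
    });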
+ public static Scalar LightGbm(this RankingContext.RankingTrainers ctx, + Scalar label, Vector features, Key groupId, Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + Contracts.CheckValue(groupId, nameof(groupId)); + + var rec = new TrainerEstimatorReconciler.Ranker( + (env, labelName, featuresName, groupIdName, weightsName) => + { + var trainer = new LightGbmRankingTrainer(env, labelName, featuresName, groupIdName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, groupId, weights); + + return rec.Score; + } + + /// + /// Predict a target using a tree multiclass classification model trained with the . + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// The weights column. + /// The number of leaves to use. + /// Number of iterations. + /// The minimal number of documents allowed in a leaf of the tree, out of the subsampled data. + /// The learning rate. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + LightGbm(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Scalar weights = null, + int? numLeaves = null, + int? minDataPerLeaf = null, + double? learningRate = null, + int numBoostRound = LightGbmArguments.Defaults.NumBoostRound, + Action advancedSettings = null, + Action onFit = null) + { + CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, advancedSettings, onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new LightGbmMulticlassTrainer(env, labelName, featuresName, weightsName, numLeaves, + minDataPerLeaf, learningRate, numBoostRound, advancedSettings); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, weights); + + return rec.Output; + } - internal static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, + private static void CheckUserValues(PipelineColumn label, Vector features, Scalar weights, int? numLeaves, int? minDataPerLeaf, double? 
learningRate, diff --git a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs index fb4d21dbd9..0d6313a20c 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Trainers.FastTree.Internal; @@ -26,7 +27,7 @@ internal static class LightGbmShared /// /// Base class for all training with LightGBM. /// - public abstract class LightGbmTrainerBase : TrainerEstimatorBase + public abstract class LightGbmTrainerBase : TrainerEstimatorBaseWithGroupId where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictorProducing { @@ -57,32 +58,43 @@ private sealed class CategoricalMetaData private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true); public override TrainerInfo Info => _info; - private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn, - string weightColumn = null, string groupIdColumn = null, Action advancedSettings = null) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + private protected LightGbmTrainerBase(IHostEnvironment env, + string name, + SchemaShape.Column label, + string featureColumn, + string weightColumn, + string groupIdColumn, + int? numLeaves, + int? minDataPerLeaf, + double? learningRate, + int numBoostRound, + Action advancedSettings) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn)) { Args = new LightGbmArguments(); + Args.NumLeaves = numLeaves; + Args.MinDataPerLeaf = minDataPerLeaf; + Args.LearningRate = learningRate; + Args.NumBoostRound = numBoostRound; + //apply the advanced args, if the user supplied any advancedSettings?.Invoke(Args); - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args); - Args.LabelColumn = label.Name; Args.FeatureColumn = featureColumn; if (weightColumn != null) - Args.WeightColumn = weightColumn; + Args.WeightColumn = Optional.Explicit(weightColumn); if (groupIdColumn != null) - Args.GroupIdColumn = groupIdColumn; + Args.GroupIdColumn = Optional.Explicit(groupIdColumn); InitParallelTraining(); } private protected LightGbmTrainerBase(IHostEnvironment env, string name, LightGbmArguments args, SchemaShape.Column label) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); @@ -161,32 +173,6 @@ protected virtual void CheckDataValid(IChannel ch, RoleMappedData data) ch.CheckParam(data.Schema.Label != null, nameof(data), "Need a label column"); } - /// - /// If, after applying the advancedSettings delegate, the 
args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. - /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation. - /// - protected void CheckArgsAndAdvancedSettingMismatch(int? numLeaves, - int? minDataPerLeaf, - double? learningRate, - int numBoostRound, - LightGbmArguments snapshot, - LightGbmArguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numBoostRound, snapshot.NumBoostRound, currentArgs.NumBoostRound, nameof(numBoostRound)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDataPerLeaf, snapshot.MinDataPerLeaf, currentArgs.MinDataPerLeaf, nameof(minDataPerLeaf)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRate, currentArgs.LearningRate, nameof(learningRate)); - } - } - protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false) { double learningRate = Args.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats); diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs index 564347d9c7..f94511ec76 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineCatalog.cs @@ -21,8 +21,10 @@ public static class FactorizationMachineExtensions /// The label, or dependent variable. /// The features, or independent variables. /// The optional example weights. - /// A delegate to set more settings. - /// + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationContext.BinaryClassificationTrainers ctx, string label, string[] features, string weights = null, diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs index 733c98d28a..2a95df5dd7 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineStatic.cs @@ -28,8 +28,10 @@ public static class FactorizationMachineExtensions /// Initial learning rate. /// Number of training iterations. /// Latent space dimensions. - /// A delegate to set more settings. 
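The factorization machine catalog entry above picks up the same remarks. A call sketch follows; ctx is assumed to be a BinaryClassificationContext, the column names are placeholders, the delegate parameter is assumed to be named advancedSettings, and the Iters/LatentDim field names come from the static-pipe hunk later in this diff.

var ffm = ctx.Trainers.FieldAwareFactorizationMachine(
    label: "Label",
    features: new[] { "FeaturesField1", "FeaturesField2" },   // one column per field
    advancedSettings: args =>
    {
        args.LatentDim = 16;    // latent space dimensions
        args.Iters = 5;         // training iterations
    });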
- /// A delegate that is called every time the + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the ./// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive /// the model that was trained. Note that this action cannot change the result in any way; it is only a way for the caller to @@ -57,10 +59,11 @@ public static (Scalar score, Scalar predictedLabel) FieldAwareFacto var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings: args => { - advancedSettings?.Invoke(args); args.LearningRate = learningRate; args.Iters = numIterations; args.LatentDim = numLatentDimensions; + + advancedSettings?.Invoke(args); }); if (onFit != null) return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 3907131f6b..3c3e0ea13b 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -1433,7 +1433,10 @@ internal override void Check(IHostEnvironment env) /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public LinearClassificationTrainer(IHostEnvironment env, string featureColumn, string labelColumn, @@ -1682,21 +1685,17 @@ public StochasticGradientDescentClassificationTrainer(IHostEnvironment env, stri Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); _args = new Arguments(); - advancedSettings?.Invoke(_args); - - // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly - TrainerUtils.CheckArgsHaveDefaultColNames(Host, _args); - - if (advancedSettings != null) - CheckArgsAndAdvancedSettingMismatch(maxIterations, initLearningRate, l2Weight, loss, new Arguments(), _args); + _args.MaxIterations = maxIterations; + _args.InitLearningRate = initLearningRate; + _args.L2Weight = l2Weight; // Apply the advanced args, if the user supplied any. + advancedSettings?.Invoke(_args); + _args.FeatureColumn = featureColumn; _args.LabelColumn = labelColumn; _args.WeightColumn = weightColumn; - _args.MaxIterations = maxIterations; - _args.InitLearningRate = initLearningRate; - _args.L2Weight = l2Weight; + if (loss != null) _args.LossFunction = loss; _args.Check(env); @@ -1719,30 +1718,6 @@ internal StochasticGradientDescentClassificationTrainer(IHostEnvironment env, Ar _args = args; } - /// - /// If, after applying the advancedSettings delegate, the args are different that the default value - /// and are also different than the value supplied directly to the xtension method, warn the user - /// about which value is being used. 
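The reordered SGD constructor above follows the same pattern: direct values first, then the delegate, then the column names. A caller-side sketch, using named arguments because the full parameter order is not visible in this hunk; env, the column names and the numeric values are placeholders.

var sgd = new StochasticGradientDescentClassificationTrainer(env,
    featureColumn: "Features",
    labelColumn: "Label",
    maxIterations: 20,
    advancedSettings: s =>
    {
        // Overrides the maxIterations passed directly above.
        s.MaxIterations = 30;
    });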
- /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune. - /// This list should follow the one in the constructor, and the extension methods on the . - /// - internal void CheckArgsAndAdvancedSettingMismatch(int maxIterations, - double initLearningRate, - float l2Weight, - ISupportClassificationLossFactory loss, - Arguments snapshot, - Arguments currentArgs) - { - using (var ch = Host.Start("Comparing advanced settings with the directly provided values.")) - { - // Check that the user didn't supply different parameters in the args, from what it specified directly. - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, maxIterations, snapshot.MaxIterations, currentArgs.MaxIterations, nameof(maxIterations)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, initLearningRate, snapshot.InitLearningRate, currentArgs.InitLearningRate, nameof(initLearningRate)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, l2Weight, snapshot.L2Weight, currentArgs.L2Weight, nameof(l2Weight)); - TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, loss, snapshot.LossFunction, currentArgs.LossFunction, nameof(loss)); - } - } - protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) { return new[] diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index 5f7a40c5d2..048dc11d34 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -151,31 +151,49 @@ internal static class Defaults private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true); public override TrainerInfo Info => _info; - internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings, float l1Weight, + internal LbfgsTrainerBase(IHostEnvironment env, + string featureColumn, + SchemaShape.Column labelColumn, + string weightColumn, + Action advancedSettings, + float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity) - : this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings), labelColumn, - l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) + : this(env, new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }, + labelColumn, advancedSettings) { } - internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn, - float? l1Weight = null, - float? l2Weight = null, - float? optimizationTolerance = null, - int? memorySize = null, - bool? 
enforceNoNegativity = null) + internal LbfgsTrainerBase(IHostEnvironment env, + TArgs args, + SchemaShape.Column labelColumn, + Action advancedSettings = null) : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn)) + labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit)) { Host.CheckValue(args, nameof(args)); Args = args; + // Apply the advanced args, if the user supplied any. + advancedSettings?.Invoke(args); + + args.FeatureColumn = FeatureColumn.Name; + args.LabelColumn = LabelColumn.Name; + args.WeightColumn = WeightColumn?.Name; Host.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, - nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); + nameof(Args.NumThreads), "numThreads must be positive (or empty for default)"); Host.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative"); Host.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative"); Host.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive"); @@ -184,16 +202,15 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l Host.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative"); Host.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative"); - Host.CheckParam(!(l2Weight < 0), nameof(l2Weight), "Must be non-negative, if provided."); - Host.CheckParam(!(l1Weight < 0), nameof(l1Weight), "Must be non-negative, if provided"); - Host.CheckParam(!(optimizationTolerance <= 0), nameof(optimizationTolerance), "Must be positive, if provided."); - Host.CheckParam(!(memorySize <= 0), nameof(memorySize), "Must be positive, if provided."); + Host.CheckParam(!(Args.L2Weight < 0), nameof(Args.L2Weight), "Must be non-negative, if provided."); + Host.CheckParam(!(Args.L1Weight < 0), nameof(Args.L1Weight), "Must be non-negative, if provided"); + Host.CheckParam(!(Args.OptTol <= 0), nameof(Args.OptTol), "Must be positive, if provided."); + Host.CheckParam(!(Args.MemorySize <= 0), nameof(Args.MemorySize), "Must be positive, if provided."); - // Review: Warn about the overriding behavior - L2Weight = l2Weight ?? Args.L2Weight; - L1Weight = l1Weight ?? Args.L1Weight; - OptTol = optimizationTolerance ?? Args.OptTol; - MemorySize = memorySize ?? Args.MemorySize; + L2Weight = Args.L2Weight; + L1Weight = Args.L1Weight; + OptTol = Args.OptTol; + MemorySize =Args.MemorySize; MaxIterations = Args.MaxIterations; SgdInitializationTolerance = Args.SgdInitializationTolerance; Quiet = Args.Quiet; @@ -201,7 +218,7 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l UseThreads = Args.UseThreads; NumThreads = Args.NumThreads; DenseOptimizer = Args.DenseOptimizer; - EnforceNonNegativity = enforceNoNegativity ?? 
Args.EnforceNonNegativity; + EnforceNonNegativity = Args.EnforceNonNegativity; if (EnforceNonNegativity && ShowTrainingStats) { @@ -217,14 +234,25 @@ internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column l } private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn, - string weightColumn, Action advancedSettings) + string weightColumn, + float l1Weight, + float l2Weight, + float optimizationTolerance, + int memorySize, + bool enforceNoNegativity) { - var args = new TArgs(); + var args = new TArgs + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn.Name, + WeightColumn = weightColumn ?? Optional.Explicit(weightColumn), + L1Weight = l1Weight, + L2Weight = l2Weight, + OptTol = optimizationTolerance, + MemorySize = memorySize, + EnforceNonNegativity = enforceNoNegativity + }; - // Apply the advanced args, if the user supplied any. - advancedSettings?.Invoke(args); - args.FeatureColumn = featureColumn; - args.LabelColumn = labelColumn.Name; args.WeightColumn = weightColumn; return args; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 5e2003f3e1..70e896feef 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -91,7 +91,7 @@ public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, int memorySize = Arguments.Defaults.MemorySize, bool enforceNoNegativity = Arguments.Defaults.EnforceNonNegativity, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), weightColumn, advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), weightColumn, advancedSettings, l1Weight, l2Weight, optimizationTolerance, memorySize, enforceNoNegativity) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -104,7 +104,7 @@ public MulticlassLogisticRegression(IHostEnvironment env, string featureColumn, /// Initializes a new instance of /// internal MulticlassLogisticRegression(IHostEnvironment env, Arguments args) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { ShowTrainingStats = Args.ShowTrainingStats; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 1713affce5..72dd984328 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -53,7 +53,7 @@ public sealed class Arguments : LearnerInputBaseWithLabel /// The name of the feature column. 
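Returning briefly to the L-BFGS family above before the naive Bayes constructor continues: the arguments object is now built from the direct values, the delegate is applied on top, and the range checks run against the merged result. A construction sketch for the multiclass logistic regression shown in this hunk; named arguments avoid guessing the full parameter order, and env, the column names and the numeric values are assumptions.

var mlr = new MulticlassLogisticRegression(env,
    featureColumn: "Features",
    labelColumn: "Label",
    l2Weight: 0.1f,
    memorySize: 30,
    advancedSettings: s =>
    {
        // Checked after the delegate runs: OptTol must be positive,
        // L1Weight/L2Weight non-negative, MemorySize positive.
        s.OptTol = 1e-4f;
    });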
public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn), - TrainerUtils.MakeU4ScalarLabel(labelColumn)) + TrainerUtils.MakeU4ScalarColumn(labelColumn)) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); @@ -64,7 +64,7 @@ public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, s /// internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), - TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) + TrainerUtils.MakeU4ScalarColumn(args.LabelColumn)) { Host.CheckValue(args, nameof(args)); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs index 517d79e7b8..6e3ebe5fa7 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaCatalog.cs @@ -26,9 +26,12 @@ public static class SdcaRegressionExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public static SdcaRegressionTrainer StochasticDualCoordinateAscent(this RegressionContext.RegressionTrainers ctx, - string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, + string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, string weights = null, ISupportSdcaRegressionLoss loss = null, float? l2Const = null, float? l1Threshold = null, @@ -54,7 +57,10 @@ public static class SdcaBinaryClassificationExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// /// /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . 
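A call sketch for the regression SDCA overload documented above, ahead of the multiclass overload whose signature follows. ctx is assumed to be a RegressionContext, the column names and values are placeholders, and the L2Const field used inside the delegate is an assumption about the arguments object rather than something shown in this hunk.

var sdca = ctx.Trainers.StochasticDualCoordinateAscent(
    label: "Label",
    features: "Features",
    l2Const: 1e-3f,
    l1Threshold: 0.01f,
    advancedSettings: s =>
    {
        // Assumed field; if set, it wins over the direct l2Const above.
        s.L2Const = 1e-4f;
    });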
public static SdcaMultiClassTrainer StochasticDualCoordinateAscent(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, string label = DefaultColumnNames.Label, string features = DefaultColumnNames.Features, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs index 3b774e261c..0ec4ccd878 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs @@ -58,7 +58,10 @@ public sealed class Arguments : ArgumentsBase /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . public SdcaMultiClassTrainer(IHostEnvironment env, string featureColumn, string labelColumn, @@ -68,7 +71,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, float? l1Threshold = null, int? maxIterations = null, Action advancedSettings = null) - : base(env, featureColumn, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, + : base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), advancedSettings, l2Const, l1Threshold, maxIterations) { Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -79,7 +82,7 @@ public SdcaMultiClassTrainer(IHostEnvironment env, internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null) - : base(env, args, TrainerUtils.MakeU4ScalarLabel(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) + : base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) { Host.CheckValue(labelColumn, nameof(labelColumn)); Host.CheckValue(featureColumn, nameof(featureColumn)); diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs index 84e1241a69..9d3e1205cc 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs @@ -63,7 +63,10 @@ public Arguments() /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . 
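And a direct-construction sketch for the multiclass SDCA trainer above, which now builds its label column through MakeU4ScalarColumn. env and the column names are placeholders, and the MaxIterations field used inside the delegate is an assumed arguments field.

var sdcaMulticlass = new SdcaMultiClassTrainer(env,
    featureColumn: "Features",
    labelColumn: "Label",
    maxIterations: 50,
    advancedSettings: s =>
    {
        // Assumed field; overrides the direct maxIterations above.
        s.MaxIterations = 100;
    });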
public SdcaRegressionTrainer(IHostEnvironment env, string featureColumn, string labelColumn, diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs index d8e15cc460..b803c74d36 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.StaticPipe /// /// Extension methods and utilities for instantiating SDCA trainer estimators inside statically typed pipelines. /// - public static class SdcaRegressionExtensions + public static class SdcaExtensions { /// /// Predict a target using a linear regression model trained with the SDCA trainer. @@ -28,7 +28,10 @@ public static class SdcaRegressionExtensions /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. /// The custom loss, if unspecified will be . - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -71,10 +74,6 @@ public static Scalar Sdca(this RegressionContext.RegressionTrainers ctx, return rec.Score; } - } - - public static class SdcaBinaryClassificationExtensions - { /// /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss. @@ -86,7 +85,10 @@ public static class SdcaBinaryClassificationExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -146,7 +148,10 @@ public static (Scalar score, Scalar probability, Scalar pred /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive @@ -199,10 +204,6 @@ public static (Scalar score, Scalar predictedLabel) Sdca( return rec.Output; } - } - - public static class SdcaMulticlassExtensions - { /// /// Predict a target using a linear multiclass classification model trained with the SDCA trainer. @@ -215,7 +216,10 @@ public static class SdcaMulticlassExtensions /// The L2 regularization hyperparameter. /// The L1 regularization hyperparameter. 
Higher values will tend to lead to more sparse model. /// The maximum number of passes to perform over the data. - /// A delegate to set more settings. + /// A delegate to set more settings. + /// The settings here will override the ones provided in the direct method signature, + /// if both are present and have different values. + /// The columns names, however need to be provided directly, not through the . /// A delegate that is called every time the /// method is called on the /// instance created out of this. This delegate will receive diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 458f777d2b..e689384392 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -754,6 +754,85 @@ public void FastTreeRanking() Assert.InRange(metrics.Ndcg[2], 36.5, 37); } + [Fact] + public void LightGBMRanking() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new RankingContext(env); + + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), + separator: '\t', hasHeader: true); + + LightGbmRankingPredictor pred = null; + + var est = reader.MakeNewEstimator() + .Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) + .Append(r => (r.label, r.groupId, score: ctx.Trainers.LightGbm(r.label, r.features, r.groupId, onFit: (p) => { pred = p; }))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + var data = model.Read(dataSource); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score); + Assert.NotNull(metrics); + + Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3); + + Assert.InRange(metrics.Dcg[0], 1.4, 1.6); + Assert.InRange(metrics.Dcg[1], 1.4, 1.8); + Assert.InRange(metrics.Dcg[2], 1.4, 1.8); + + Assert.InRange(metrics.Ndcg[0], 36.5, 37); + Assert.InRange(metrics.Ndcg[1], 36.5, 37); + Assert.InRange(metrics.Ndcg[2], 36.5, 37); + } + + [Fact] + public void MultiClassLightGBM() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new MulticlassClassificationContext(env); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + OvaPredictor pred = null; + + // With a custom loss function we no longer get calibrated predictions. + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: ctx.Trainers.LightGbm( + r.label, + r.features, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + + var data = model.Read(dataSource); + + // Just output some data on the schema for fun. + var schema = data.AsDynamic.Schema; + for (int c = 0; c < schema.ColumnCount; ++c) + Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.True(metrics.LogLoss > 0); + Assert.True(metrics.TopKAccuracy > 0); + } + [Fact] public void MultiClassNaiveBayesTrainer() {