Skip to content

Commit 0831865

Browse files
authored
Fixing ModelParameter discrepancies (#2968)
* fixing model parameter discrepencies * multiclass LR singe that refactoring is happening in a parallel PR * review comments. Added Multiclass to NaiveBayes * Drop Classification from trainer names - v1 (more trainers to follow) * multiclass LR will be handled separately * Drop Classification from trainer names - v2 (all trainers taken care of) * fix entrypoint file
1 parent 71693b3 commit 0831865

File tree

66 files changed

+436
-436
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+436
-436
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ private class BinaryOutputRow
4747
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
4848
=> output.AboveAverage = input.MedianHomeValue > 22.6;
4949

50-
public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
50+
public static float[] GetLinearModelWeights(OlsModelParameters linearModel)
5151
{
5252
return linearModel.Weights.ToArray();
5353
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public static void Example()
6161
// we could do so by tweaking the 'advancedSetting'.
6262
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
6363
.Append(mlContext.BinaryClassification.Trainers.SdcaCalibrated(
64-
new SdcaCalibratedBinaryClassificationTrainer.Options {
64+
new SdcaCalibratedBinaryTrainer.Options {
6565
LabelColumnName = "Sentiment",
6666
FeatureColumnName = "Features",
6767
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentWithOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static void Example()
2222
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
2323

2424
// Define the trainer options.
25-
var options = new SdcaCalibratedBinaryClassificationTrainer.Options()
25+
var options = new SdcaCalibratedBinaryTrainer.Options()
2626
{
2727
// Make the convergence tolerance tighter.
2828
ConvergenceTolerance = 0.05f,

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public static void Example()
2626
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
2727
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
2828

29-
var options = new SdcaMulticlassClassificationTrainer.Options
29+
var options = new SdcaMulticlassTrainer.Options
3030
{
3131
// Add custom loss
3232
LossFunction = new HingeLoss(),

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
using Microsoft.ML.Runtime;
1010
using Microsoft.ML.Trainers.FastTree;
1111

12-
[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Options))]
12+
[assembly: EntryPointModule(typeof(FastTreeBinaryTrainer.Options))]
1313
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))]
1414
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))]
1515
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))]
@@ -52,10 +52,10 @@ public enum EarlyStoppingRankingMetric
5252
}
5353

5454
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
55-
public sealed partial class FastTreeBinaryClassificationTrainer
55+
public sealed partial class FastTreeBinaryTrainer
5656
{
5757
/// <summary>
58-
/// Options for the <see cref="FastTreeBinaryClassificationTrainer"/>.
58+
/// Options for the <see cref="FastTreeBinaryTrainer"/>.
5959
/// </summary>
6060
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
6161
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
@@ -102,7 +102,7 @@ public Options()
102102
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
103103
}
104104

105-
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
105+
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryTrainer(env, this);
106106
}
107107
}
108108

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
using Microsoft.ML.Runtime;
1414
using Microsoft.ML.Trainers.FastTree;
1515

16-
[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options),
16+
[assembly: LoadableClass(FastTreeBinaryTrainer.Summary, typeof(FastTreeBinaryTrainer), typeof(FastTreeBinaryTrainer.Options),
1717
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
18-
FastTreeBinaryClassificationTrainer.UserNameValue,
19-
FastTreeBinaryClassificationTrainer.LoadNameValue,
18+
FastTreeBinaryTrainer.UserNameValue,
19+
FastTreeBinaryTrainer.LoadNameValue,
2020
"FastTreeClassification",
2121
"FastTree",
2222
"ft",
23-
FastTreeBinaryClassificationTrainer.ShortName,
23+
FastTreeBinaryTrainer.ShortName,
2424

2525
// FastRank names
2626
"FastRankBinaryClassification",
@@ -101,8 +101,8 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
101101
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
102102
/// </summary>
103103
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
104-
public sealed partial class FastTreeBinaryClassificationTrainer :
105-
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
104+
public sealed partial class FastTreeBinaryTrainer :
105+
BoostingFastTreeTrainerBase<FastTreeBinaryTrainer.Options,
106106
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
107107
CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>
108108
{
@@ -118,7 +118,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
118118
private double _sigmoidParameter;
119119

120120
/// <summary>
121-
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/>
121+
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/>
122122
/// </summary>
123123
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
124124
/// <param name="labelColumnName">The name of the label column.</param>
@@ -128,7 +128,7 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
128128
/// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
129129
/// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
130130
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
131-
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
131+
internal FastTreeBinaryTrainer(IHostEnvironment env,
132132
string labelColumnName = DefaultColumnNames.Label,
133133
string featureColumnName = DefaultColumnNames.Features,
134134
string exampleWeightColumnName = null,
@@ -143,11 +143,11 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
143143
}
144144

145145
/// <summary>
146-
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the <see cref="Options"/> class.
146+
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/> by using the <see cref="Options"/> class.
147147
/// </summary>
148148
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
149149
/// <param name="options">Algorithm advanced settings.</param>
150-
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options)
150+
internal FastTreeBinaryTrainer(IHostEnvironment env, Options options)
151151
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
152152
{
153153
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
@@ -278,7 +278,7 @@ private protected override BinaryPredictionTransformer<CalibratedModelParameters
278278
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
279279

280280
/// <summary>
281-
/// Trains a <see cref="FastTreeBinaryClassificationTrainer"/> using both training and validation data, returns
281+
/// Trains a <see cref="FastTreeBinaryTrainer"/> using both training and validation data, returns
282282
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
283283
/// </summary>
284284
public BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
@@ -403,18 +403,18 @@ public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree,
403403
internal static partial class FastTree
404404
{
405405
[TlcModule.EntryPoint(Name = "Trainers.FastTreeBinaryClassifier",
406-
Desc = FastTreeBinaryClassificationTrainer.Summary,
407-
UserName = FastTreeBinaryClassificationTrainer.UserNameValue,
408-
ShortName = FastTreeBinaryClassificationTrainer.ShortName)]
409-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Options input)
406+
Desc = FastTreeBinaryTrainer.Summary,
407+
UserName = FastTreeBinaryTrainer.UserNameValue,
408+
ShortName = FastTreeBinaryTrainer.ShortName)]
409+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryTrainer.Options input)
410410
{
411411
Contracts.CheckValue(env, nameof(env));
412412
var host = env.Register("TrainFastTree");
413413
host.CheckValue(input, nameof(input));
414414
EntryPointUtils.CheckInputArgs(host, input);
415415

416-
return TrainerEntryPointsUtils.Train<FastTreeBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
417-
() => new FastTreeBinaryClassificationTrainer(host, input),
416+
return TrainerEntryPointsUtils.Train<FastTreeBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
417+
() => new FastTreeBinaryTrainer(host, input),
418418
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
419419
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
420420
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName));

0 commit comments

Comments
 (0)