Commit 95019e5: Scrub Fast Tree Family (#2753)
1 parent e285bed

40 files changed: +1081 -1060 lines

docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs
1 addition, 11 deletions

@@ -30,22 +30,12 @@ public static void Example()
             // We will train a FastTreeRegression model with 1 tree on these two columns to predict Age.
             string outputColumnName = "Features";
             var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" })
-                .Append(ml.Regression.Trainers.FastTree(labelColumnName: "Age", featureColumnName: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1));
+                .Append(ml.Regression.Trainers.FastTree(labelColumnName: "Age", featureColumnName: outputColumnName, numberOfTrees: 1, numberOfLeaves: 2, minimumExampleCountPerLeaf: 1));

             var model = pipeline.Fit(trainData);

             // Get the trained model parameters.
             var modelParams = model.LastTransformer.Model;
-
-            // Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree.
-            var testRow = new VBuffer<float>(2, new[] { 1.0f, 1.0f });
-            // Use the path object to pass to GetLeaf, which will populate path with the IDs of the nodes from root to leaf.
-            List<int> path = default;
-            // Get the ID of the leaf this example ends up in for tree 0.
-            var leafID = modelParams.GetLeaf(0, in testRow, ref path);
-            // Get the leaf value for this leaf ID in tree 0.
-            var leafValue = modelParams.GetLeafValue(0, leafID);
-            Console.WriteLine("The leaf value in tree 0 is: " + leafValue);
         }
     }
 }
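For orientation, a minimal end-to-end sketch of the renamed regression API follows. It is not part of the commit: the Row type and its toy values are hypothetical, and the LoadFromEnumerable usage is an assumption based on the contemporary MLContext data API; the trainer parameter names match the renames above.

    using System;
    using Microsoft.ML;

    // Hypothetical input type, for illustration only.
    public class Row
    {
        public float Parity { get; set; }
        public float Induced { get; set; }
        public float Age { get; set; }
    }

    public static class FastTreeRenameSketch
    {
        public static void Main()
        {
            var ml = new MLContext();
            var trainData = ml.Data.LoadFromEnumerable(new[]
            {
                new Row { Parity = 1, Induced = 1, Age = 26 },
                new Row { Parity = 0, Induced = 1, Age = 42 },
                new Row { Parity = 1, Induced = 0, Age = 39 },
                new Row { Parity = 0, Induced = 0, Age = 34 },
            });

            // Old names: numTrees, numLeaves, minDatapointsInLeaves.
            var pipeline = ml.Transforms.Concatenate("Features", new[] { "Parity", "Induced" })
                .Append(ml.Regression.Trainers.FastTree(
                    labelColumnName: "Age",
                    featureColumnName: "Features",
                    numberOfTrees: 1,
                    numberOfLeaves: 2,
                    minimumExampleCountPerLeaf: 1));

            var model = pipeline.Fit(trainData);
            Console.WriteLine("Model trained.");
        }
    }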

docs/samples/Microsoft.ML.Samples/Dynamic/GeneralizedAdditiveModels.cs
3 additions, 3 deletions

@@ -28,7 +28,7 @@ public static void Example()
                 .ToArray();
             var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
                 .Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels(
-                    labelColumnName: labelName, featureColumnName: "Features", maxBins: 16));
+                    labelColumnName: labelName, featureColumnName: "Features", maxBinCountPerFeature: 16));
             var fitPipeline = pipeline.Fit(data);

             // Extract the model from the pipeline
@@ -37,7 +37,7 @@ public static void Example()
             // Now investigate the properties of the Generalized Additive Model: the intercept and shape functions.

             // The intercept for the GAM models represents the average prediction for the training data
-            var intercept = gamModel.Intercept;
+            var intercept = gamModel.Bias;
             // Expected output: Average predicted cost: 22.53
             Console.WriteLine($"Average predicted cost: {intercept:0.00}");

@@ -93,7 +93,7 @@ public static void Example()
             // Distillation." <a href='https://arxiv.org/abs/1710.06169'>arXiv:1710.06169</a>."
             Console.WriteLine();
             Console.WriteLine("Student-Teacher Ratio");
-            for (int i = 0; i < teacherRatioBinUpperBounds.Length; i++)
+            for (int i = 0; i < teacherRatioBinUpperBounds.Count; i++)
             {
                 Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioBinEffects[i]:0.000}");
             }
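A hedged companion sketch of the renamed GAM surface follows; it is not from the commit. The House type and its toy values are hypothetical, and the GetBinUpperBounds/GetBinEffects accessors are assumed to exist as the sample above uses them, with the bounds now exposed as a list (hence .Count rather than .Length).

    using System;
    using Microsoft.ML;

    // Hypothetical input type, for illustration only.
    public class House
    {
        public float TeacherRatio { get; set; }
        public float RoomCount { get; set; }
        public float MedianValue { get; set; }
    }

    public static class GamRenameSketch
    {
        public static void Main()
        {
            var mlContext = new MLContext();
            var data = mlContext.Data.LoadFromEnumerable(new[]
            {
                new House { TeacherRatio = 15f, RoomCount = 6f, MedianValue = 24f },
                new House { TeacherRatio = 20f, RoomCount = 5f, MedianValue = 18f },
                new House { TeacherRatio = 18f, RoomCount = 7f, MedianValue = 28f },
                new House { TeacherRatio = 22f, RoomCount = 4f, MedianValue = 14f },
            });

            var pipeline = mlContext.Transforms.Concatenate("Features", "TeacherRatio", "RoomCount")
                .Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels(
                    labelColumnName: "MedianValue", featureColumnName: "Features",
                    maxBinCountPerFeature: 16)); // renamed from maxBins

            var model = pipeline.Fit(data);
            var gamModel = model.LastTransformer.Model;

            Console.WriteLine($"Average predicted cost: {gamModel.Bias:0.00}"); // renamed from Intercept

            // Assumed accessors, following the sample's usage; bounds are a list, hence .Count.
            var bounds = gamModel.GetBinUpperBounds(0);
            var effects = gamModel.GetBinEffects(0);
            for (int i = 0; i < bounds.Count; i++)
                Console.WriteLine($"x < {bounds[i]:0.00} => {effects[i]:0.000}");
        }
    }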

docs/samples/Microsoft.ML.Samples/Static/FastTreeBinaryClassification.cs
3 additions, 3 deletions

@@ -78,9 +78,9 @@ public static void FastTreeBinaryClassification()
                     Score: mlContext.BinaryClassification.Trainers.FastTree(
                         row.Label,
                         row.Features,
-                        numTrees: 100, // try: (int) 20-2000
-                        numLeaves: 20, // try: (int) 2-128
-                        minDatapointsInLeaves: 10, // try: (int) 1-100
+                        numberOfTrees: 100, // try: (int) 20-2000
+                        numberOfLeaves: 20, // try: (int) 2-128
+                        minimumExampleCountPerLeaf: 10, // try: (int) 1-100
                         learningRate: 0.2))) // try: (float) 0.025-0.4
                 .Append(row => (
                     Label: row.Label,

docs/samples/Microsoft.ML.Samples/Static/FastTreeRegression.cs
3 additions, 3 deletions

@@ -38,9 +38,9 @@ public static void FastTreeRegression()
                 .Append(r => (r.label, score: mlContext.Regression.Trainers.FastTree(
                     r.label,
                     r.features,
-                    numTrees: 100, // try: (int) 20-2000
-                    numLeaves: 20, // try: (int) 2-128
-                    minDatapointsInLeaves: 10, // try: (int) 1-100
+                    numberOfTrees: 100, // try: (int) 20-2000
+                    numberOfLeaves: 20, // try: (int) 2-128
+                    minimumExampleCountPerLeaf: 10, // try: (int) 1-100
                     learningRate: 0.2, // try: (float) 0.025-0.4
                     onFit: p => pred = p)
                 )
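Both static-API samples above pick up the same three renames. As a cross-check, here is a hedged sketch of the equivalent dynamic-API trainer call; it is not part of the commit, and the "Label"/"Features" column names are illustrative assumptions.

    using Microsoft.ML;

    public static class StaticRenameSketch
    {
        public static void Example()
        {
            var mlContext = new MLContext();

            // The renamed hyperparameters, identical across the static and dynamic APIs.
            var trainer = mlContext.BinaryClassification.Trainers.FastTree(
                labelColumnName: "Label",       // assumed column name
                featureColumnName: "Features",  // assumed column name
                numberOfTrees: 100,             // was numTrees; try 20-2000
                numberOfLeaves: 20,             // was numLeaves; try 2-128
                minimumExampleCountPerLeaf: 10, // was minDatapointsInLeaves; try 1-100
                learningRate: 0.2);             // try 0.025-0.4
        }
    }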

src/Microsoft.ML.FastTree/BoostingFastTree.cs
17 additions, 17 deletions

@@ -18,16 +18,16 @@ private protected BoostingFastTreeTrainerBase(IHostEnvironment env, TOptions opt

         private protected BoostingFastTreeTrainerBase(IHostEnvironment env,
             SchemaShape.Column label,
-            string featureColumn,
-            string weightColumn,
-            string groupIdColumn,
-            int numLeaves,
-            int numTrees,
-            int minDatapointsInLeaves,
+            string featureColumnName,
+            string exampleWeightColumnName,
+            string rowGroupColumnName,
+            int numberOfLeaves,
+            int numberOfTrees,
+            int minimumExampleCountPerLeaf,
             double learningRate)
-            : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
+            : base(env, label, featureColumnName, exampleWeightColumnName, rowGroupColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf)
         {
-            FastTreeTrainerOptions.LearningRates = learningRate;
+            FastTreeTrainerOptions.LearningRate = learningRate;
         }

         private protected override void CheckOptions(IChannel ch)
@@ -40,10 +40,10 @@ private protected override void CheckOptions(IChannel ch)
             if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble)
                 throw ch.Except("Ensemble compression cannot be done when forcing to write last ensemble (hl)");

-            if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
+            if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1)
                 throw ch.Except("Histogram pool size (ps) must be at least 2.");

-            if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
+            if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1)
                 throw ch.Except("Histogram pool size (ps) must be at most numLeaves - 1.");

             if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
@@ -61,12 +61,12 @@ private protected override void CheckOptions(IChannel ch)
         private protected override TreeLearner ConstructTreeLearner(IChannel ch)
         {
             return new LeastSquaresRegressionTreeLearner(
-                TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient,
+                TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient,
                 FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
-                FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction, FastTreeTrainerOptions.FilterZeroLambdas,
-                FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode,
-                FastTreeTrainerOptions.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining,
-                FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias);
+                FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas,
+                FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
+                FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining,
+                FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
         }

         private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
@@ -94,7 +94,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
             optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
             optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
             optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate;
-            optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.RngSeed);
+            optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.Seed);
             optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;

             return optimizationAlgorithm;
@@ -162,7 +162,7 @@ private protected override int GetBestIteration(IChannel ch)
         private protected double BsrMaxTreeOutput()
         {
             if (FastTreeTrainerOptions.BestStepRankingRegressionTrees)
-                return FastTreeTrainerOptions.MaxTreeOutput;
+                return FastTreeTrainerOptions.MaximumTreeOutput;
             else
                 return -1;
         }
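To see the scrubbed option names end to end, here is a hedged sketch (not from the commit) of configuring a FastTree regression trainer through its Options object. The Microsoft.ML.Trainers.FastTree namespace and the FastTree(options) overload are assumptions based on the contemporary API; the property names follow the renames in the diff above, old names noted inline, and all values are placeholders.

    using Microsoft.ML;
    using Microsoft.ML.Trainers.FastTree;

    public static class OptionsRenameSketch
    {
        public static void Example()
        {
            var mlContext = new MLContext();

            // Post-scrub option names, as used in BoostingFastTree.cs above.
            var options = new FastTreeRegressionTrainer.Options
            {
                NumberOfTrees = 100,             // was NumTrees
                NumberOfLeaves = 20,             // was NumLeaves
                MinimumExampleCountPerLeaf = 10, // was MinDocumentsInLeafs
                LearningRate = 0.2,              // was LearningRates
                Seed = 123,                      // was RngSeed
                FeatureFractionPerSplit = 0.9,   // was SplitFraction
                MaximumTreeOutput = 100,         // was MaxTreeOutput
            };

            var trainer = mlContext.Regression.Trainers.FastTree(options);
        }
    }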
