Skip to content

Commit 78f4f86

Browse files
committed
update tests to check categorical splits
1 parent 791f31d commit 78f4f86

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,18 @@ public void TestEstimatorMulticlassNaiveBayesTrainer()
177177
return (pipeline, data);
178178
}
179179

180+
/// <summary>
181+
/// Builds on <see cref="GetBinaryClassificationPipeline"/>: same data, but with an additional
182+
/// OneHotEncoding step on the "Features" column, intended to obtain categorical splits in tree models.
183+
/// </summary>
/// <returns>
/// A tuple of the extended estimator chain and the data view returned by
/// <see cref="GetBinaryClassificationPipeline"/>, passed through unchanged.
/// </returns>
184+
private (IEstimator<ITransformer>, IDataView) GetOneHotBinaryClassificationPipeline()
185+
{
186+
// Reuse the base binary-classification pipeline and its data instead of rebuilding them.
var (pipeline, data) = GetBinaryClassificationPipeline();
187+
// Append the one-hot encoding of "Features" after the base pipeline's steps.
var oneHotPipeline = pipeline.Append(ML.Transforms.Categorical.OneHotEncoding("Features"));
188+
189+
return (oneHotPipeline, data);
190+
}
191+
180192

181193
private (IEstimator<ITransformer>, IDataView) GetRankingPipeline()
182194
{

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,9 @@ private void CheckSummary(ICanGetSummaryAsIDataView modelParameters, double bias
775775
IEnumerable<SummaryDataRow> summaryDataEnumerable;
776776

777777
if (quantileTrees == null)
778-
summaryDataEnumerable = ML.Data.CreateEnumerable<SummaryDataRow>(summaryDataView, true);
778+
summaryDataEnumerable = ML.Data.CreateEnumerable<SummaryDataRow>(summaryDataView, false);
779779
else
780-
summaryDataEnumerable = ML.Data.CreateEnumerable<QuantileTestSummaryDataRow>(summaryDataView, true);
780+
summaryDataEnumerable = ML.Data.CreateEnumerable<QuantileTestSummaryDataRow>(summaryDataView, false);
781781

782782
var summaryDataEnumerator = summaryDataEnumerable.GetEnumerator();
783783

@@ -810,7 +810,7 @@ public void FastTreeRegressorTestSummary()
810810
{
811811
var dataView = GetRegressionPipeline();
812812
var trainer = ML.Regression.Trainers.FastTree(
813-
new FastTreeRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true});
813+
new FastTreeRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
814814

815815
var transformer = trainer.Fit(dataView);
816816

@@ -828,7 +828,7 @@ public void FastForestRegressorTestSummary()
828828
{
829829
var dataView = GetRegressionPipeline();
830830
var trainer = ML.Regression.Trainers.FastForest(
831-
new FastForestRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true });
831+
new FastForestRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
832832

833833
var transformer = trainer.Fit(dataView);
834834

@@ -846,7 +846,7 @@ public void FastTreeTweedieRegressorTestSummary()
846846
{
847847
var dataView = GetRegressionPipeline();
848848
var trainer = ML.Regression.Trainers.FastTreeTweedie(
849-
new FastTreeTweedieTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true });
849+
new FastTreeTweedieTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
850850

851851
var transformer = trainer.Fit(dataView);
852852

@@ -864,7 +864,7 @@ public void LightGbmRegressorTestSummary()
864864
{
865865
var dataView = GetRegressionPipeline();
866866
var trainer = ML.Regression.Trainers.LightGbm(
867-
new LightGbmRegressionTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5, UseCategoricalSplit = true });
867+
new LightGbmRegressionTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
868868

869869
var transformer = trainer.Fit(dataView);
870870

@@ -882,7 +882,7 @@ public void FastTreeBinaryClassificationTestSummary()
882882
{
883883
var (pipeline, dataView) = GetBinaryClassificationPipeline();
884884
var estimator = pipeline.Append(ML.BinaryClassification.Trainers.FastTree(
885-
new FastTreeBinaryTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true }));
885+
new FastTreeBinaryTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, NumberOfLeaves = 5}));
886886

887887
var transformer = estimator.Fit(dataView);
888888

@@ -898,9 +898,9 @@ public void FastTreeBinaryClassificationTestSummary()
898898
[Fact]
899899
public void FastForestBinaryClassificationTestSummary()
900900
{
901-
var (pipeline, dataView) = GetBinaryClassificationPipeline();
901+
var (pipeline, dataView) = GetOneHotBinaryClassificationPipeline();
902902
var estimator = pipeline.Append(ML.BinaryClassification.Trainers.FastForest(
903-
new FastForestBinaryTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true }));
903+
new FastForestBinaryTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, NumberOfLeaves = 4, CategoricalSplit = true }));
904904

905905
var transformer = estimator.Fit(dataView);
906906

@@ -916,7 +916,7 @@ public void FastForestBinaryClassificationTestSummary()
916916
[LightGBMFact]
917917
public void LightGbmBinaryClassificationTestSummary()
918918
{
919-
var (pipeline, dataView) = GetBinaryClassificationPipeline();
919+
var (pipeline, dataView) = GetOneHotBinaryClassificationPipeline();
920920
var trainer = pipeline.Append(ML.BinaryClassification.Trainers.LightGbm(
921921
new LightGbmBinaryTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5, UseCategoricalSplit = true }));
922922

0 commit comments

Comments
 (0)