Skip to content

Commit d6bf2c4

Browse files
committed
update tests to check categorical splits
1 parent a232546 commit d6bf2c4

File tree

2 files changed: 22 additions and 10 deletions.

2 files changed: 22 additions and 10 deletions.

test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,18 @@ public void TestEstimatorMulticlassNaiveBayesTrainer()
177177
return (pipeline, data);
178178
}
179179

180+
/// <summary>
181+
/// Builds the same pipeline and data as <see cref="GetBinaryClassificationPipeline"/>, then appends
182+
/// a OneHotEncoding of the "Features" column to obtain categorical splits in tree models.
183+
/// </summary>
/// <returns>The pipeline with the OneHotEncoding step appended, paired with the unmodified data.</returns>
184+
private (IEstimator<ITransformer>, IDataView) GetOneHotBinaryClassificationPipeline()
185+
{
186+
var (pipeline, data) = GetBinaryClassificationPipeline();
187+
var oneHotPipeline = pipeline.Append(ML.Transforms.Categorical.OneHotEncoding("Features"));
188+
189+
return (oneHotPipeline, data);
190+
}
191+
180192

181193
private (IEstimator<ITransformer>, IDataView) GetRankingPipeline()
182194
{

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -685,9 +685,9 @@ private void CheckSummary(ICanGetSummaryAsIDataView modelParameters, double bias
685685
IEnumerable<SummaryDataRow> summaryDataEnumerable;
686686

687687
if (quantileTrees == null)
688-
summaryDataEnumerable = ML.Data.CreateEnumerable<SummaryDataRow>(summaryDataView, true);
688+
summaryDataEnumerable = ML.Data.CreateEnumerable<SummaryDataRow>(summaryDataView, false);
689689
else
690-
summaryDataEnumerable = ML.Data.CreateEnumerable<QuantileTestSummaryDataRow>(summaryDataView, true);
690+
summaryDataEnumerable = ML.Data.CreateEnumerable<QuantileTestSummaryDataRow>(summaryDataView, false);
691691

692692
var summaryDataEnumerator = summaryDataEnumerable.GetEnumerator();
693693

@@ -720,7 +720,7 @@ public void FastTreeRegressorTestSummary()
720720
{
721721
var dataView = GetRegressionPipeline();
722722
var trainer = ML.Regression.Trainers.FastTree(
723-
new FastTreeRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true});
723+
new FastTreeRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
724724

725725
var transformer = trainer.Fit(dataView);
726726

@@ -738,7 +738,7 @@ public void FastForestRegressorTestSummary()
738738
{
739739
var dataView = GetRegressionPipeline();
740740
var trainer = ML.Regression.Trainers.FastForest(
741-
new FastForestRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true });
741+
new FastForestRegressionTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
742742

743743
var transformer = trainer.Fit(dataView);
744744

@@ -756,7 +756,7 @@ public void FastTreeTweedieRegressorTestSummary()
756756
{
757757
var dataView = GetRegressionPipeline();
758758
var trainer = ML.Regression.Trainers.FastTreeTweedie(
759-
new FastTreeTweedieTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true });
759+
new FastTreeTweedieTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
760760

761761
var transformer = trainer.Fit(dataView);
762762

@@ -774,7 +774,7 @@ public void LightGbmRegressorTestSummary()
774774
{
775775
var dataView = GetRegressionPipeline();
776776
var trainer = ML.Regression.Trainers.LightGbm(
777-
new LightGbmRegressionTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5, UseCategoricalSplit = true });
777+
new LightGbmRegressionTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5});
778778

779779
var transformer = trainer.Fit(dataView);
780780

@@ -792,7 +792,7 @@ public void FastTreeBinaryClassificationTestSummary()
792792
{
793793
var (pipeline, dataView) = GetBinaryClassificationPipeline();
794794
var estimator = pipeline.Append(ML.BinaryClassification.Trainers.FastTree(
795-
new FastTreeBinaryTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true }));
795+
new FastTreeBinaryTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, NumberOfLeaves = 5}));
796796

797797
var transformer = estimator.Fit(dataView);
798798

@@ -808,9 +808,9 @@ public void FastTreeBinaryClassificationTestSummary()
808808
[Fact]
809809
public void FastForestBinaryClassificationTestSummary()
810810
{
811-
var (pipeline, dataView) = GetBinaryClassificationPipeline();
811+
var (pipeline, dataView) = GetOneHotBinaryClassificationPipeline();
812812
var estimator = pipeline.Append(ML.BinaryClassification.Trainers.FastForest(
813-
new FastForestBinaryTrainer.Options { NumberOfTrees = 10, NumberOfThreads = 1, NumberOfLeaves = 5, CategoricalSplit = true }));
813+
new FastForestBinaryTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, NumberOfLeaves = 4, CategoricalSplit = true }));
814814

815815
var transformer = estimator.Fit(dataView);
816816

@@ -826,7 +826,7 @@ public void FastForestBinaryClassificationTestSummary()
826826
[LightGBMFact]
827827
public void LightGbmBinaryClassificationTestSummary()
828828
{
829-
var (pipeline, dataView) = GetBinaryClassificationPipeline();
829+
var (pipeline, dataView) = GetOneHotBinaryClassificationPipeline();
830830
var trainer = pipeline.Append(ML.BinaryClassification.Trainers.LightGbm(
831831
new LightGbmBinaryTrainer.Options { NumberOfIterations = 10, NumberOfThreads = 1, NumberOfLeaves = 5, UseCategoricalSplit = true }));
832832

0 commit comments

Comments
 (0)