From 96703e39cb8280af6642d29129deb2b01d2bcf3f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 11 May 2018 10:19:57 -0700 Subject: [PATCH 1/5] Fix a bug in Tree leaf featurizer entry point, and add a test for it. --- .../TreeEnsembleFeaturizer.cs | 14 ++--- .../Microsoft.ML.Core.Tests.csproj | 1 + .../UnitTests/TestEntryPoints.cs | 53 ++++++++++++++++++- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index e173e01cbe..253d76d654 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -703,10 +703,12 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments using (var ch = host.Start("Create Tree Ensemble Scorer")) { var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments() { Suffix = args.Suffix }; - var predictor = args.PredictorModel?.Predictor; + var predictor = args.PredictorModel.Predictor; ch.Trace("Prepare data"); RoleMappedData data = null; - args.PredictorModel?.PrepareData(env, input, out data, out var predictor2); + args.PredictorModel.PrepareData(env, input, out data, out var predictor2); + ch.AssertValue(data); + ch.Assert(predictor == predictor2); // Make sure that the given predictor has the correct number of input features. if (predictor is CalibratedPredictorBase) @@ -715,16 +717,16 @@ public static IDataTransform CreateForEntryPoint(IHostEnvironment env, Arguments // be non-null. var vm = predictor as IValueMapper; ch.CheckUserArg(vm != null, nameof(args.PredictorModel), "Predictor does not have compatible type"); - if (data != null && vm?.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) + if (data != null && vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize) { throw ch.ExceptUserArg(nameof(args.PredictorModel), "Predictor expects {0} features, but data has {1} features", - vm?.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); + vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize); } var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor); - var bound = bindable.Bind(env, data?.Schema); - xf = new GenericScorer(env, scorerArgs, input, bound, data?.Schema); + var bound = bindable.Bind(env, data.Schema); + xf = new GenericScorer(env, scorerArgs, data.Data, bound, data.Schema); ch.Done(); } return xf; diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index 028e5bebd9..0b74998c12 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -17,6 +17,7 @@ + \ No newline at end of file diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 4fe503b9b7..04d40e345c 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -6,13 +6,13 @@ using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Core.Tests.UnitTests; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; +using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Learners; using Newtonsoft.Json; @@ -2521,5 +2521,56 @@ public void EntryPointPrepareLabelConvertPredictedLabel() } } } + + [Fact] + public void EntryPointTreeLeafFeaturizer() + { + var dataPath = GetDataPath(@"adult.tiny.with-schema.txt"); + var inputFile = new SimpleFileHandle(Env, dataPath, false, false); + var dataView = ImportTextData.ImportText(Env, new ImportTextData.Input { InputFile = inputFile }).Data; + var cat = Categorical.CatTransformDict(Env, new CategoricalTransform.Arguments() + { + Data = dataView, + Column = new[] { new CategoricalTransform.Column { Name = "Categories", Source = "Categories" } } + }); + var concat = SchemaManipulation.ConcatColumns(Env, new ConcatTransform.Arguments() + { + Data = cat.OutputData, + Column = new[] { new ConcatTransform.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } } + }); + + var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments + { + FeatureColumn = "Features", + NumLeaves = 80, + LabelColumn = DefaultColumnNames.Label, + TrainingData = concat.OutputData + }); + + var combine = ModelOperations.CombineModels(Env, new ModelOperations.PredictorModelInput() + { + PredictorModel = fastTree.PredictorModel, + TransformModels = new[] { cat.Model, concat.Model } + }); + + var treeLeaf = TreeFeaturize.Featurizer(Env, new TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint + { + Data = dataView, + PredictorModel = combine.PredictorModel + }); + + var view = treeLeaf.OutputData; + Assert.True(view.Schema.TryGetColumnIndex("Trees", out int col)); + VBuffer features = default(VBuffer); + using (var curs = view.GetRowCursor(c => c == col)) + { + var getter = curs.GetGetter>(col); + while (curs.MoveNext()) + { + getter(ref features); + Assert.True(features.Count > 0); + } + } + } } } \ No newline at end of file From 86d2a155d573d2b97f5c54165af2c912c7e5b73c Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 15 May 2018 11:00:05 -0700 Subject: [PATCH 2/5] Improve unit test --- .../UnitTests/TestEntryPoints.cs | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index a1e1bf8341..cf0a99072d 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2560,15 +2560,27 @@ public void EntryPointTreeLeafFeaturizer() }); var view = treeLeaf.OutputData; - Assert.True(view.Schema.TryGetColumnIndex("Trees", out int col)); - VBuffer features = default(VBuffer); - using (var curs = view.GetRowCursor(c => c == col)) + Assert.True(view.Schema.TryGetColumnIndex("Trees", out int treesCol)); + Assert.True(view.Schema.TryGetColumnIndex("Leaves", out int leavesCol)); + Assert.True(view.Schema.TryGetColumnIndex("Paths", out int pathsCol)); + VBuffer treeValues = default(VBuffer); + VBuffer leafIndicators = default(VBuffer); + VBuffer pathIndicators = default(VBuffer); + using (var curs = view.GetRowCursor(c => c == treesCol || c == leavesCol || c == pathsCol)) { - var getter = curs.GetGetter>(col); + var treesGetter = curs.GetGetter>(treesCol); + var leavesGetter = curs.GetGetter>(leavesCol); + var pathsGetter = curs.GetGetter>(pathsCol); while (curs.MoveNext()) { - getter(ref features); - Assert.True(features.Count > 0); + treesGetter(ref treeValues); + leavesGetter(ref leafIndicators); + pathsGetter(ref pathIndicators); + Assert.True(treeValues.Length == 100); + Assert.True(treeValues.Count == 100); + Assert.True(leafIndicators.Length == 3955); + Assert.True(leafIndicators.Count == 100); + Assert.True(pathIndicators.Length == 3855); } } } From 431ca8e3de9bb7c58c04d4ac92c8612d74482817 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 15 May 2018 11:26:13 -0700 Subject: [PATCH 3/5] Update unit test --- .../UnitTests/TestEntryPoints.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index cf0a99072d..a2e28756bb 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2576,11 +2576,11 @@ public void EntryPointTreeLeafFeaturizer() treesGetter(ref treeValues); leavesGetter(ref leafIndicators); pathsGetter(ref pathIndicators); - Assert.True(treeValues.Length == 100); - Assert.True(treeValues.Count == 100); - Assert.True(leafIndicators.Length == 3955); - Assert.True(leafIndicators.Count == 100); - Assert.True(pathIndicators.Length == 3855); + Assert.Equal(100, treeValues.Length); + Assert.Equal(100, treeValues.Count); + Assert.Equal(3955, leafIndicators.Length); + Assert.Equal(100, leafIndicators.Count); + Assert.Equal(3855, pathIndicators.Length); } } } From fb7e5e5eefffc72c68f8a87131f8aa0be5238e77 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 15 May 2018 13:05:15 -0700 Subject: [PATCH 4/5] Decrease number of trees and leaves in unit test --- .../UnitTests/TestEntryPoints.cs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index a2e28756bb..ad8ab1bbe0 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2542,7 +2542,8 @@ public void EntryPointTreeLeafFeaturizer() var fastTree = FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Arguments { FeatureColumn = "Features", - NumLeaves = 80, + NumTrees = 5, + NumLeaves = 4, LabelColumn = DefaultColumnNames.Label, TrainingData = concat.OutputData }); @@ -2576,11 +2577,11 @@ public void EntryPointTreeLeafFeaturizer() treesGetter(ref treeValues); leavesGetter(ref leafIndicators); pathsGetter(ref pathIndicators); - Assert.Equal(100, treeValues.Length); - Assert.Equal(100, treeValues.Count); - Assert.Equal(3955, leafIndicators.Length); - Assert.Equal(100, leafIndicators.Count); - Assert.Equal(3855, pathIndicators.Length); + Assert.Equal(5, treeValues.Length); + Assert.Equal(5, treeValues.Count); + Assert.Equal(20, leafIndicators.Length); + Assert.Equal(5, leafIndicators.Count); + Assert.Equal(15, pathIndicators.Length); } } } From eec515ed50646160b3f5f0465ee2db88667e0dc5 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 15 May 2018 13:19:27 -0700 Subject: [PATCH 5/5] Trigger build --- test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index ad8ab1bbe0..e8be6c0370 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2577,6 +2577,7 @@ public void EntryPointTreeLeafFeaturizer() treesGetter(ref treeValues); leavesGetter(ref leafIndicators); pathsGetter(ref pathIndicators); + Assert.Equal(5, treeValues.Length); Assert.Equal(5, treeValues.Count); Assert.Equal(20, leafIndicators.Length);