diff --git a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs
index 4fef5699c0..78056c4e47 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs
@@ -4,6 +4,7 @@

 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Microsoft.ML.CommandLine;
 using Microsoft.ML.Data;
@@ -124,7 +125,7 @@ private protected OptionsBase() { }
             /// </summary>
             [Argument(ArgumentType.Multiple,
                 HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.",
-                Name="Booster",
+                Name = "Booster",
                 SortOrder = 3)]
             internal IBoosterParameterFactory BoosterFactory = new GradientBooster.Options();

@@ -285,6 +286,7 @@ private sealed class CategoricalMetaData
             public int[] OnehotIndices;
             public int[] OnehotBias;
             public bool[] IsCategoricalFeature;
+            public int[] CatIndices;
         }

         // Contains the passed in options when the API is called
@@ -317,16 +319,16 @@ private protected LightGbmTrainerBase(IHostEnvironment env,
             double? learningRate,
             int numberOfIterations)
             : this(env, name, new TOptions()
-                {
-                    NumberOfLeaves = numberOfLeaves,
-                    MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
-                    LearningRate = learningRate,
-                    NumberOfIterations = numberOfIterations,
-                    LabelColumnName = labelColumn.Name,
-                    FeatureColumnName = featureColumnName,
-                    ExampleWeightColumnName = exampleWeightColumnName,
-                    RowGroupColumnName = rowGroupColumnName
-                },
+            {
+                NumberOfLeaves = numberOfLeaves,
+                MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
+                LearningRate = learningRate,
+                NumberOfIterations = numberOfIterations,
+                LabelColumnName = labelColumn.Name,
+                FeatureColumnName = featureColumnName,
+                ExampleWeightColumnName = exampleWeightColumnName,
+                RowGroupColumnName = rowGroupColumnName
+            },
             labelColumn)
         {
         }
@@ -384,7 +386,7 @@ private protected override TModel TrainModelCore(TrainContext context)
             return CreatePredictor();
         }

-        private protected virtual void InitializeBeforeTraining(){}
+        private protected virtual void InitializeBeforeTraining() { }

         private void InitParallelTraining()
         {
@@ -522,6 +524,7 @@ private static List<string> ConstructCategoricalFeatureMetaData(int[] categorica
                     ++j;
                 }
             }
+            catMetaData.CatIndices = catIndices.Select(int.Parse).ToArray();
             return catIndices;
         }

@@ -761,7 +764,7 @@ private void GetFeatureValueDense(IChannel ch, FloatLabelCursor cursor, Categori
                         }
                     }
                     // All-Zero is category 0.
-                    fv = hotIdx - catMetaData.CategoricalBoudaries[i] + 1;
+                    fv = hotIdx - catMetaData.CategoricalBoudaries[i];
                 }
                 featureValuesTemp[i] = fv;
             }
@@ -781,8 +784,9 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
             var cursorFeaturesIndices = cursor.Features.GetIndices();
             if (catMetaData.CategoricalBoudaries != null)
             {
-                List<int> featureIndices = new List<int>();
-                List<float> values = new List<float>();
+                Dictionary<int, float> ivPair = new Dictionary<int, float>();
+                foreach (var idx in catMetaData.CatIndices)
+                    ivPair[idx] = -1;
                 int lastIdx = -1;
                 int nhot = 0;
                 for (int i = 0; i < cursorFeaturesValues.Length; ++i)
@@ -791,11 +795,10 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
                     int colIdx = cursorFeaturesIndices[i];
                     int newColIdx = catMetaData.OnehotIndices[colIdx];
                     if (catMetaData.IsCategoricalFeature[newColIdx])
-                        fv = catMetaData.OnehotBias[colIdx] + 1;
+                        fv = catMetaData.OnehotBias[colIdx];
                     if (newColIdx != lastIdx)
                     {
-                        featureIndices.Add(newColIdx);
-                        values.Add(fv);
+                        ivPair[newColIdx] = fv;
                         nhot = 1;
                     }
                     else
@@ -804,13 +807,14 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
                         ++nhot;
                         var prob = rand.NextSingle();
                         if (prob < 1.0f / nhot)
-                            values[values.Count - 1] = fv;
+                            ivPair[newColIdx] = fv;
                     }
                     lastIdx = newColIdx;
                 }
-                indices = featureIndices.ToArray();
-                featureValues = values.ToArray();
-                cnt = featureIndices.Count;
+                var sortedIVPair = new SortedDictionary<int, float>(ivPair);
+                indices = sortedIVPair.Keys.ToArray();
+                featureValues = sortedIVPair.Values.ToArray();
+                cnt = ivPair.Count;
             }
             else
             {
diff --git a/src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs b/src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
index 882600a7bf..a02d21de6b 100644
--- a/src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
+++ b/src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
@@ -179,7 +179,7 @@ private static int[] GetCatThresholds(UInt32[] catThreshold, int lowerBound, int
                 for (int k = 0; k < 32; ++k)
                 {
                     int cat = (j - lowerBound) * 32 + k;
-                    if (FindInBitset(catThreshold, lowerBound, upperBound, cat) && cat > 0)
+                    if (FindInBitset(catThreshold, lowerBound, upperBound, cat))
                         cats.Add(cat);
                 }
             }
@@ -239,7 +239,7 @@ public InternalTreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
                categoricalSplitFeatures[node] = new int[cats.Length];
                // Convert Cat thresholds to feature indices.
                for (int j = 0; j < cats.Length; ++j)
-                   categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
+                   categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j];
                splitFeature[node] = -1;
                categoricalSplit[node] = true;
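The new sparse path above is the subtlest part of the change: instead of appending to parallel `featureIndices`/`values` lists, it pre-seeds a dictionary with every categorical column index and emits the pairs in sorted order, so each categorical column always appears in the row (with the patch's default of -1) even if the cursor never produced a value for it. The snippet below is a minimal standalone sketch of that pattern, not ML.NET code: the class and method names, the tuple return type, and the use of `Random.NextDouble` in place of the `NextSingle` call in the diff are all illustrative assumptions.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

internal static class SparseCategoricalSketch
{
    // Illustrative sketch of the patched sparse path: collapse one-hot slots into
    // per-column categorical values and emit (index, value) pairs in ascending order.
    public static (int[] Indices, float[] Values) Build(
        int[] sparseIndices,    // column indices present in the cursor's sparse row
        float[] sparseValues,   // corresponding values
        int[] onehotIndices,    // raw column -> collapsed column index
        int[] onehotBias,       // raw column -> category id within its collapsed column
        bool[] isCategorical,   // collapsed column -> categorical flag
        int[] catIndices,       // collapsed indices of all categorical columns
        Random rand)
    {
        // Pre-seed every categorical column so it is always emitted,
        // defaulting to -1 as the patch does.
        var ivPair = new Dictionary<int, float>();
        foreach (var idx in catIndices)
            ivPair[idx] = -1;

        int lastIdx = -1;
        int nhot = 0;
        for (int i = 0; i < sparseValues.Length; ++i)
        {
            float fv = sparseValues[i];
            int colIdx = sparseIndices[i];
            int newColIdx = onehotIndices[colIdx];
            if (isCategorical[newColIdx])
                fv = onehotBias[colIdx];    // replace the one-hot value with the category id

            if (newColIdx != lastIdx)
            {
                ivPair[newColIdx] = fv;     // first value seen for this collapsed column
                nhot = 1;
            }
            else
            {
                // Several hot bits mapped to the same categorical column:
                // keep one of them uniformly at random (reservoir sampling).
                ++nhot;
                if (rand.NextDouble() < 1.0 / nhot)
                    ivPair[newColIdx] = fv;
            }
            lastIdx = newColIdx;
        }

        // SortedDictionary yields keys in ascending order, so indices come out sorted.
        var sorted = new SortedDictionary<int, float>(ivPair);
        return (sorted.Keys.ToArray(), sorted.Values.ToArray());
    }
}
```

Pre-seeding plus the sorted emit is what lets the patch drop the old append-only lists: the output count now includes categorical columns the cursor never touched, and the indices come out ascending even with the pre-seeded categorical slots mixed in.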