Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -124,7 +125,7 @@ private protected OptionsBase() { }
/// </value>
[Argument(ArgumentType.Multiple,
HelpText = "Which booster to use, can be gbtree, gblinear or dart. gbtree and dart use tree based model while gblinear uses linear function.",
Name="Booster",
Name = "Booster",
SortOrder = 3)]
internal IBoosterParameterFactory BoosterFactory = new GradientBooster.Options();

Expand Down Expand Up @@ -285,6 +286,7 @@ private sealed class CategoricalMetaData
public int[] OnehotIndices;
public int[] OnehotBias;
public bool[] IsCategoricalFeature;
public int[] CatIndices;
}

// Contains the passed in options when the API is called
Expand Down Expand Up @@ -317,16 +319,16 @@ private protected LightGbmTrainerBase(IHostEnvironment env,
double? learningRate,
int numberOfIterations)
: this(env, name, new TOptions()
{
NumberOfLeaves = numberOfLeaves,
MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
LearningRate = learningRate,
NumberOfIterations = numberOfIterations,
LabelColumnName = labelColumn.Name,
FeatureColumnName = featureColumnName,
ExampleWeightColumnName = exampleWeightColumnName,
RowGroupColumnName = rowGroupColumnName
},
{
NumberOfLeaves = numberOfLeaves,
MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
LearningRate = learningRate,
NumberOfIterations = numberOfIterations,
LabelColumnName = labelColumn.Name,
FeatureColumnName = featureColumnName,
ExampleWeightColumnName = exampleWeightColumnName,
RowGroupColumnName = rowGroupColumnName
},
labelColumn)
{
}
Expand Down Expand Up @@ -384,7 +386,7 @@ private protected override TModel TrainModelCore(TrainContext context)
return CreatePredictor();
}

private protected virtual void InitializeBeforeTraining(){}
private protected virtual void InitializeBeforeTraining() { }

private void InitParallelTraining()
{
Expand Down Expand Up @@ -522,6 +524,7 @@ private static List<string> ConstructCategoricalFeatureMetaData(int[] categorica
++j;
}
}
catMetaData.CatIndices = catIndices.Select(int.Parse).ToArray();
return catIndices;
}

Expand Down Expand Up @@ -761,7 +764,7 @@ private void GetFeatureValueDense(IChannel ch, FloatLabelCursor cursor, Categori
}
}
// All-Zero is category 0.
fv = hotIdx - catMetaData.CategoricalBoudaries[i] + 1;
fv = hotIdx - catMetaData.CategoricalBoudaries[i];
}
featureValuesTemp[i] = fv;
}
Expand All @@ -781,8 +784,9 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
var cursorFeaturesIndices = cursor.Features.GetIndices();
if (catMetaData.CategoricalBoudaries != null)
{
List<int> featureIndices = new List<int>();
List<float> values = new List<float>();
Dictionary<int, float> ivPair = new Dictionary<int, float>();
foreach (var idx in catMetaData.CatIndices)
ivPair[idx] = -1;
int lastIdx = -1;
int nhot = 0;
for (int i = 0; i < cursorFeaturesValues.Length; ++i)
Expand All @@ -791,11 +795,10 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
int colIdx = cursorFeaturesIndices[i];
int newColIdx = catMetaData.OnehotIndices[colIdx];
if (catMetaData.IsCategoricalFeature[newColIdx])
fv = catMetaData.OnehotBias[colIdx] + 1;
fv = catMetaData.OnehotBias[colIdx];
if (newColIdx != lastIdx)
{
featureIndices.Add(newColIdx);
values.Add(fv);
ivPair[newColIdx] = fv;
nhot = 1;
}
else
Expand All @@ -804,13 +807,14 @@ private void GetFeatureValueSparse(IChannel ch, FloatLabelCursor cursor,
++nhot;
var prob = rand.NextSingle();
if (prob < 1.0f / nhot)
values[values.Count - 1] = fv;
ivPair[newColIdx] = fv;
}
lastIdx = newColIdx;
}
indices = featureIndices.ToArray();
featureValues = values.ToArray();
cnt = featureIndices.Count;
var sortedIVPair = new SortedDictionary<int, float>(ivPair);
indices = sortedIVPair.Keys.ToArray();
featureValues = sortedIVPair.Values.ToArray();
cnt = ivPair.Count;
}
else
{
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private static int[] GetCatThresholds(UInt32[] catThreshold, int lowerBound, int
for (int k = 0; k < 32; ++k)
{
int cat = (j - lowerBound) * 32 + k;
if (FindInBitset(catThreshold, lowerBound, upperBound, cat) && cat > 0)
if (FindInBitset(catThreshold, lowerBound, upperBound, cat))
cats.Add(cat);
}
}
Expand Down Expand Up @@ -239,7 +239,7 @@ public InternalTreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
categoricalSplitFeatures[node] = new int[cats.Length];
// Convert Cat thresholds to feature indices.
for (int j = 0; j < cats.Length; ++j)
categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j];

splitFeature[node] = -1;
categoricalSplit[node] = true;
Expand Down