diff --git a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
index a05d57fd1b..765b0d297a 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
@@ -58,7 +58,6 @@ public BoosterParameterBase(OptionsBase options)
public abstract class OptionsBase : IBoosterParameterFactory
{
internal BoosterParameterBase GetBooster() { return null; }
-
/// <summary>
/// The minimum loss reduction required to make a further partition on a leaf node of the tree.
/// </summary>
diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
index f82b937a14..353856e7bd 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
@@ -75,6 +75,12 @@ public enum EvaluateMetricType
LogLoss,
}
+ /// <summary>
+ /// Whether training data is unbalanced.
+ /// </summary>
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
+ public bool UnbalancedSets = false;
+
/// <summary>
/// Whether to use softmax loss.
/// </summary>
@@ -110,6 +116,7 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
{
var res = base.ToDictionary(host);
+ res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index b7ea7b1534..c5a734b6b3 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -11974,6 +11974,18 @@
"IsNullable": false,
"Default": "Auto"
},
+ {
+ "Name": "UnbalancedSets",
+ "Type": "Bool",
+ "Desc": "Use for multi-class classification when training data is not balanced",
+ "Aliases": [
+ "us"
+ ],
+ "Required": false,
+ "SortOrder": 150.0,
+ "IsNullable": false,
+ "Default": false
+ },
{
"Name": "UseSoftmax",
"Type": "Bool",
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
index ea8026f99e..71253a32b3 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
@@ -812,7 +812,7 @@ public void TreeEnsembleFeaturizingPipelineMulticlass()
private class RowWithKey
{
- [KeyType()]
+ [KeyType(4)]
public uint KeyLabel { get; set; }
}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index 73a5fa4149..b95f84b9a6 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
NumberOfLeaves = 10,
NumberOfThreads = 1,
MinimumExampleCountPerLeaf = 2,
+ UnbalancedSets = false, // default value
+ });
+
+ var pipeWithTrainer = pipe.Append(trainer);
+ TestEstimatorCore(pipeWithTrainer, dataView);
+
+ var transformedDataView = pipe.Fit(dataView).Transform(dataView);
+ var model = trainer.Fit(transformedDataView, transformedDataView);
+ Done();
+ }
+
+ [LightGBMFact]
+ public void LightGBMBinaryEstimatorUnbalanced()
+ {
+ var (pipe, dataView) = GetBinaryClassificationPipeline();
+
+ var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
+ {
+ NumberOfLeaves = 10,
+ NumberOfThreads = 1,
+ MinimumExampleCountPerLeaf = 2,
+ UnbalancedSets = true,
});
var pipeWithTrainer = pipe.Append(trainer);
@@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
Done();
}
+ /// <summary>
+ /// LightGbmMulticlass Test of Balanced Data
+ /// </summary>
+ [LightGBMFact]
+ public void LightGbmMulticlassEstimatorBalanced()
+ {
+ var (pipeline, dataView) = GetMulticlassPipeline();
+
+ var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+ {
+ UnbalancedSets = false
+ });
+
+ var pipe = pipeline.Append(trainer)
+ .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+ TestEstimatorCore(pipe, dataView);
+ Done();
+ }
+
+ /// <summary>
+ /// LightGbmMulticlass Test of Unbalanced Data
+ /// </summary>
+ [LightGBMFact]
+ public void LightGbmMulticlassEstimatorUnbalanced()
+ {
+ var (pipeline, dataView) = GetMulticlassPipeline();
+
+ var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+ {
+ UnbalancedSets = true
+ });
+
+ var pipe = pipeline.Append(trainer)
+ .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+ TestEstimatorCore(pipe, dataView);
+ Done();
+ }
+
// Number of examples
private const int _rowNumber = 1000;
// Number of features
@@ -338,7 +398,7 @@ private class GbmExample
public float[] Score;
}
- private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
+ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
{
// Prepare data and train LightGBM model via ML.NET
// Training matrix. It contains all feature vectors.
@@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
MinimumExampleCountPerGroup = 1,
MinimumExampleCountPerLeaf = 1,
UseSoftmax = useSoftmax,
- Sigmoid = sigmoid // Custom sigmoid value.
+ Sigmoid = sigmoid, // Custom sigmoid value.
+ UnbalancedSets = unbalancedSets // false by default
});
var gbm = gbmTrainer.Fit(dataView);
@@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
Done();
}
+ [LightGBMFact]
+ public void LightGbmMulticlassEstimatorCompareUnbalanced()
+ {
+ // Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
+ LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets: true);
+
+ // The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
+ // Assume that we have n classes in total. The i-th class probability can be computed via
+ // p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
+ Assert.True(modelString != null);
+ // Compare native LightGBM's and ML.NET's LightGBM results example by example
+ for (int i = 0; i < _rowNumber; ++i)
+ {
+ double sum = 0;
+ for (int j = 0; j < _classNumber; ++j)
+ {
+ Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
+ sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
+ }
+ for (int j = 0; j < _classNumber; ++j)
+ {
+ double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
+ Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
+ }
+ }
+
+ Done();
+ }
+
[LightGBMFact]
public void LightGbmInDifferentCulture()
{