diff --git a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
index a05d57fd1b..765b0d297a 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmArguments.cs
@@ -58,7 +58,6 @@ public BoosterParameterBase(OptionsBase options)
         public abstract class OptionsBase : IBoosterParameterFactory
         {
             internal BoosterParameterBase GetBooster() { return null; }
-
             /// <summary>
             /// The minimum loss reduction required to make a further partition on a leaf node of the tree.
             /// </summary>
diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
index f82b937a14..353856e7bd 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
@@ -75,6 +75,12 @@ public enum EvaluateMetricType
                 LogLoss,
             }
 
+            /// <summary>
+            /// Whether training data is unbalanced.
+            /// </summary>
+            [Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
+            public bool UnbalancedSets = false;
+
             /// <summary>
             /// Whether to use softmax loss.
             /// </summary>
@@ -110,6 +116,7 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
             {
                 var res = base.ToDictionary(host);
 
+                res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
                 res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
                 res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
 
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index b7ea7b1534..c5a734b6b3 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -11974,6 +11974,18 @@
           "IsNullable": false,
           "Default": "Auto"
         },
+        {
+          "Name": "UnbalancedSets",
+          "Type": "Bool",
+          "Desc": "Use for multi-class classification when training data is not balanced",
+          "Aliases": [
+            "us"
+          ],
+          "Required": false,
+          "SortOrder": 150.0,
+          "IsNullable": false,
+          "Default": false
+        },
         {
           "Name": "UseSoftmax",
           "Type": "Bool",
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
index ea8026f99e..71253a32b3 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
@@ -812,7 +812,7 @@ public void TreeEnsembleFeaturizingPipelineMulticlass()
 
         private class RowWithKey
         {
-            [KeyType()]
+            [KeyType(4)]
             public uint KeyLabel { get; set; }
         }
 
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index 73a5fa4149..b95f84b9a6 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
                 NumberOfLeaves = 10,
                 NumberOfThreads = 1,
                 MinimumExampleCountPerLeaf = 2,
+                UnbalancedSets = false, // default value
+            });
+
+            var pipeWithTrainer = pipe.Append(trainer);
+            TestEstimatorCore(pipeWithTrainer, dataView);
+
+            var transformedDataView = pipe.Fit(dataView).Transform(dataView);
+            var model = trainer.Fit(transformedDataView, transformedDataView);
+            Done();
+        }
+
+        [LightGBMFact]
+        public void LightGBMBinaryEstimatorUnbalanced()
+        {
+            var (pipe, dataView) = GetBinaryClassificationPipeline();
+
+            var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
+            {
+                NumberOfLeaves = 10,
+                NumberOfThreads = 1,
+                MinimumExampleCountPerLeaf = 2,
+                UnbalancedSets = true,
             });
 
             var pipeWithTrainer = pipe.Append(trainer);
@@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
             Done();
         }
 
+        /// <summary>
+        /// LightGbmMulticlass Test of Balanced Data
+        /// </summary>
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorBalanced()
+        {
+            var (pipeline, dataView) = GetMulticlassPipeline();
+
+            var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+            {
+                UnbalancedSets = false
+            });
+
+            var pipe = pipeline.Append(trainer)
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+            TestEstimatorCore(pipe, dataView);
+            Done();
+        }
+
+        /// <summary>
+        /// LightGbmMulticlass Test of Unbalanced Data
+        /// </summary>
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorUnbalanced()
+        {
+            var (pipeline, dataView) = GetMulticlassPipeline();
+
+            var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
+            {
+                UnbalancedSets = true
+            });
+
+            var pipe = pipeline.Append(trainer)
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+            TestEstimatorCore(pipe, dataView);
+            Done();
+        }
+
         // Number of examples
         private const int _rowNumber = 1000;
         // Number of features
@@ -338,7 +398,7 @@ private class GbmExample
             public float[] Score;
         }
 
-        private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
+        private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
         {
             // Prepare data and train LightGBM model via ML.NET
             // Training matrix. It contains all feature vectors.
@@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
                     MinimumExampleCountPerGroup = 1,
                     MinimumExampleCountPerLeaf = 1,
                     UseSoftmax = useSoftmax,
-                    Sigmoid = sigmoid // Custom sigmoid value.
+                    Sigmoid = sigmoid, // Custom sigmoid value.
+                    UnbalancedSets = unbalancedSets // false by default
                 });
 
             var gbm = gbmTrainer.Fit(dataView);
@@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
             Done();
         }
 
+        [LightGBMFact]
+        public void LightGbmMulticlassEstimatorCompareUnbalanced()
+        {
+            // Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
+            LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets:true);
+
+            // The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
+            // Assume that we have n classes in total. The i-th class probability can be computed via
+            // p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
+            Assert.True(modelString != null);
+            // Compare native LightGBM's and ML.NET's LightGBM results example by example
+            for (int i = 0; i < _rowNumber; ++i)
+            {
+                double sum = 0;
+                for (int j = 0; j < _classNumber; ++j)
+                {
+                    Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
+                    sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
+                }
+                for (int j = 0; j < _classNumber; ++j)
+                {
+                    double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
+                    Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
+                }
+            }
+
+            Done();
+        }
+
         [LightGBMFact]
         public void LightGbmInDifferentCulture()
         {