Skip to content

Commit 6722dbf

Browse files
Bugfix/hardwired sigmoid (dotnet#3850)
* fixed Hardcoded sigmoid value. Let Microsoft.ML.Tests see internals of Microsoft.ML.StandardTrainers * Removed extra whitespace between comment and code * changed which parameter was internal * changed unneeded internal parameter back to private * Put comments in correct format, sorted the using list * Added direct tests between ML.NET and LightGBM with non-default sigmoid. * Update test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs Co-Authored-By: Wei-Sheng Chin <[email protected]> * Update test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs Co-Authored-By: Wei-Sheng Chin <[email protected]> * Update test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs Co-Authored-By: Wei-Sheng Chin <[email protected]> * renamed sigmoid to sigmoidScale * changed sigmoid from a default parameter so it can come before the out params
1 parent f6ba818 commit 6722dbf

File tree

5 files changed

+140
-9
lines changed

5 files changed

+140
-9
lines changed

src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ private protected override CalibratedModelParametersBase<LightGbmBinaryModelPara
232232
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
233233
var innerArgs = LightGbmInterfaceUtils.JoinParameters(base.GbmOptions);
234234
var pred = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, innerArgs);
235-
var cali = new PlattCalibrator(Host, -0.5, 0);
235+
var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0);
236236
return new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
237237
}
238238

src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ private protected override OneVersusAllModelParameters CreatePredictor()
185185
for (int i = 0; i < _tlcNumClass; ++i)
186186
{
187187
var pred = CreateBinaryPredictor(i, innerArgs);
188-
var cali = new PlattCalibrator(Host, -0.5, 0);
188+
var cali = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0);
189189
predictors[i] = new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
190190
}
191191
string obj = (string)GetGbmParameters()["objective"];

src/Microsoft.ML.StandardTrainers/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Ensemble" + PublicKey.Value)]
99
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
1010
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
11+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
1112
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
1213

1314
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ private static VersionInfo GetVersionInfo()
267267
/// <summary>
268268
/// Retrieves the model parameters.
269269
/// </summary>
270-
private ImmutableArray<object> SubModelParameters => _impl.Predictors.Cast<object>().ToImmutableArray();
270+
internal ImmutableArray<object> SubModelParameters => _impl.Predictors.Cast<object>().ToImmutableArray();
271271

272272
/// <summary>
273273
/// The type of the prediction task.

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 136 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
using System.Collections.Generic;
77
using System.Linq;
88
using System.Threading;
9+
using Microsoft.ML.Calibrators;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.Internal.Utilities;
11-
using Microsoft.ML.Trainers.LightGbm;
1212
using Microsoft.ML.RunTests;
1313
using Microsoft.ML.Runtime;
1414
using Microsoft.ML.TestFramework.Attributes;
1515
using Microsoft.ML.Trainers.FastTree;
16+
using Microsoft.ML.Trainers.LightGbm;
1617
using Microsoft.ML.Transforms;
1718
using Xunit;
1819

@@ -64,6 +65,31 @@ public void LightGBMBinaryEstimator()
6465
Done();
6566
}
6667

68+
/// <summary>
69+
/// LightGBMBinaryTrainer CorrectSigmoid test
70+
/// </summary>
71+
[LightGBMFact]
72+
public void LightGBMBinaryEstimatorCorrectSigmoid()
73+
{
74+
var (pipe, dataView) = GetBinaryClassificationPipeline();
75+
var sigmoid = .789;
76+
77+
var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
78+
{
79+
NumberOfLeaves = 10,
80+
NumberOfThreads = 1,
81+
MinimumExampleCountPerLeaf = 2,
82+
Sigmoid = sigmoid
83+
});
84+
85+
var transformedDataView = pipe.Fit(dataView).Transform(dataView);
86+
var model = trainer.Fit(transformedDataView, transformedDataView);
87+
88+
// The slope in the model calibrator should be equal to the negative of the sigmoid passed into the trainer.
89+
Assert.Equal(sigmoid, -model.Model.Calibrator.Slope);
90+
Done();
91+
}
92+
6793

6894
[Fact]
6995
public void GAMClassificationEstimator()
@@ -251,6 +277,32 @@ public void LightGbmMulticlassEstimator()
251277
Done();
252278
}
253279

280+
/// <summary>
281+
/// LightGbmMulticlass CorrectSigmoid test
282+
/// </summary>
283+
[LightGBMFact]
284+
public void LightGbmMulticlassEstimatorCorrectSigmoid()
285+
{
286+
var (pipeline, dataView) = GetMulticlassPipeline();
287+
var sigmoid = .789;
288+
289+
var trainer = ML.MulticlassClassification.Trainers.LightGbm(new LightGbmMulticlassTrainer.Options
290+
{
291+
Sigmoid = sigmoid
292+
});
293+
294+
var pipe = pipeline.Append(trainer)
295+
.Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
296+
297+
var transformedDataView = pipe.Fit(dataView).Transform(dataView);
298+
var model = trainer.Fit(transformedDataView, transformedDataView);
299+
300+
// The slope in the all the calibrators should be equal to the negative of the sigmoid passed into the trainer.
301+
Assert.True(model.Model.SubModelParameters.All(predictor =>
302+
((FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>)predictor).Calibrator.Slope == -sigmoid));
303+
Done();
304+
}
305+
254306
// Number of examples
255307
private const int _rowNumber = 1000;
256308
// Number of features
@@ -267,7 +319,7 @@ private class GbmExample
267319
public float[] Score;
268320
}
269321

270-
private void LightGbmHelper(bool useSoftmax, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
322+
private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
271323
{
272324
// Prepare data and train LightGBM model via ML.NET
273325
// Training matrix. It contains all feature vectors.
@@ -300,7 +352,8 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List<Gb
300352
NumberOfIterations = numberOfTrainingIterations,
301353
MinimumExampleCountPerGroup = 1,
302354
MinimumExampleCountPerLeaf = 1,
303-
UseSoftmax = useSoftmax
355+
UseSoftmax = useSoftmax,
356+
Sigmoid = sigmoid // Custom sigmoid value.
304357
});
305358

306359
var gbm = gbmTrainer.Fit(dataView);
@@ -376,14 +429,15 @@ private void LightGbmHelper(bool useSoftmax, out string modelString, out List<Gb
376429
[LightGBMFact]
377430
public void LightGbmMulticlassEstimatorCompareOva()
378431
{
432+
float sigmoidScale = 0.5f; // Constant used train LightGBM. See gbmParams["sigmoid"] in the helper function.
433+
379434
// Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
380-
LightGbmHelper(useSoftmax: false, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
435+
LightGbmHelper(useSoftmax: false, sigmoid: sigmoidScale, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
381436

382437
// The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
383438
// Assume that we have n classes in total. The i-th class probability can be computed via
384439
// p_i = sigmoid(sigmoidScale * z_i) / (sigmoid(sigmoidScale * z_1) + ... + sigmoid(sigmoidScale * z_n)).
385440
Assert.True(modelString != null);
386-
float sigmoidScale = 0.5f; // Constant used train LightGBM. See gbmParams["sigmoid"] in the helper function.
387441
// Compare native LightGBM's and ML.NET's LightGBM results example by example
388442
for (int i = 0; i < _rowNumber; ++i)
389443
{
@@ -405,11 +459,87 @@ public void LightGbmMulticlassEstimatorCompareOva()
405459
Done();
406460
}
407461

462+
/// <summary>
463+
/// Test LightGBM's sigmoid parameter with a custom value. This test checks if ML.NET and LightGBM produce the same result.
464+
/// </summary>
465+
[LightGBMFact]
466+
public void LightGbmMulticlassEstimatorCompareOvaUsingSigmoids()
467+
{
468+
var sigmoidScale = .790;
469+
// Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
470+
LightGbmHelper(useSoftmax: false, sigmoid: sigmoidScale, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
471+
472+
// The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
473+
// Assume that we have n classes in total. The i-th class probability can be computed via
474+
// p_i = sigmoid(sigmoidScale * z_i) / (sigmoid(sigmoidScale * z_1) + ... + sigmoid(sigmoidScale * z_n)).
475+
Assert.True(modelString != null);
476+
477+
// Compare native LightGBM's and ML.NET's LightGBM results example by example
478+
for (int i = 0; i < _rowNumber; ++i)
479+
{
480+
double sum = 0;
481+
for (int j = 0; j < _classNumber; ++j)
482+
{
483+
Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
484+
if (float.IsNaN((float)nativeResult1[j + i * _classNumber]))
485+
continue;
486+
sum += MathUtils.SigmoidSlow((float)sigmoidScale * (float)nativeResult1[j + i * _classNumber]);
487+
}
488+
for (int j = 0; j < _classNumber; ++j)
489+
{
490+
double prob = MathUtils.SigmoidSlow((float)sigmoidScale * (float)nativeResult1[j + i * _classNumber]);
491+
Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
492+
}
493+
}
494+
495+
Done();
496+
}
497+
498+
/// <summary>
499+
/// Make sure different sigmoid parameters produce different scores. In this test, two LightGBM models are trained with two different sigmoid values.
500+
/// </summary>
501+
[LightGBMFact]
502+
public void LightGbmMulticlassEstimatorCompareOvaUsingDifferentSigmoids()
503+
{
504+
// Run native implemenation twice, see that results are different with different sigmoid values.
505+
var firstSigmoidScale = .790;
506+
var secondSigmoidScale = .2;
507+
508+
// Train native LightGBM with both sigmoid values and apply the trained models to the training set.
509+
LightGbmHelper(useSoftmax: false, sigmoid: firstSigmoidScale, out string firstModelString, out List<GbmExample> firstMlnetPredictions, out double[] firstNativeResult1, out double[] firstNativeResult0);
510+
LightGbmHelper(useSoftmax: false, sigmoid: secondSigmoidScale, out string secondModelString, out List<GbmExample> secondMlnetPredictions, out double[] secondNativeResult1, out double[] secondNativeResult0);
511+
512+
// Compare native LightGBM's results when 2 different sigmoid values are used.
513+
for (int i = 0; i < _rowNumber; ++i)
514+
{
515+
var areEqual = true;
516+
for (int j = 0; j < _classNumber; ++j)
517+
{
518+
if (float.IsNaN((float)firstNativeResult1[j + i * _classNumber]))
519+
continue;
520+
if (float.IsNaN((float)secondNativeResult1[j + i * _classNumber]))
521+
continue;
522+
523+
// Testing to make sure that at least 1 value is different. This avoids false positives when values are 0
524+
// even for the same sigmoid value.
525+
areEqual &= firstMlnetPredictions[i].Score[j].Equals(secondMlnetPredictions[i].Score[j]);
526+
527+
// Testing that the native result is different before we apply the sigmoid.
528+
Assert.NotEqual((float)firstNativeResult1[j + i * _classNumber], (float)secondNativeResult1[j + i * _classNumber], 6);
529+
}
530+
531+
// There should be at least 1 value that is different in the row.
532+
Assert.False(areEqual);
533+
}
534+
535+
Done();
536+
}
537+
408538
[LightGBMFact]
409539
public void LightGbmMulticlassEstimatorCompareSoftMax()
410540
{
411541
// Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
412-
LightGbmHelper(useSoftmax: true, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
542+
LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
413543

414544
// The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
415545
// Assume that we have n classes in total. The i-th class probability can be computed via

0 commit comments

Comments
 (0)