Skip to content

Commit 1d88e46

Browse files
codemzseerhardt
authored andcommitted
LightGBM (dotnet#392)
* LightGBM and test. * add test baselines and nuget source for lightGBM binaries. * Add entrypoint for lightGBM. * add unsafe flag for release build. * update nuget version. * make lightgbm test single threaded. * install gcc on OS machines to resolve dependencies on openmp thatis needed by lightgbm native code. * PR comments. Leave BREW and GCC in bash script to verify macOS tests work. * remove brew and gcc from build script. * PR feedback. * disable test on macOS. * disable test on macOS. * PR feedback.
1 parent bdfe25d commit 1d88e46

File tree

116 files changed

+72868
-50
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+72868
-50
lines changed

Microsoft.ML.sln

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstanda
116116
EndProject
117117
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}"
118118
EndProject
119+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.LightGBM", "src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj", "{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}"
120+
EndProject
119121
Global
120122
GlobalSection(SolutionConfigurationPlatforms) = preSolution
121123
Debug|Any CPU = Debug|Any CPU
@@ -222,6 +224,10 @@ Global
222224
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
223225
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
224226
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.Build.0 = Release|Any CPU
227+
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
228+
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug|Any CPU.Build.0 = Debug|Any CPU
229+
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.ActiveCfg = Release|Any CPU
230+
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.Build.0 = Release|Any CPU
225231
EndGlobalSection
226232
GlobalSection(SolutionProperties) = preSolution
227233
HideSolutionNode = FALSE
@@ -260,6 +266,7 @@ Global
260266
{487213C9-E8A9-4F94-85D7-28A05DBBFE3A} = {DEC8F776-49F7-4D87-836C-FE4DC057D08C}
261267
{9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A}
262268
{3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
269+
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
263270
EndGlobalSection
264271
GlobalSection(ExtensibilityGlobals) = postSolution
265272
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

build/Dependencies.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
<SystemCodeDomPackageVersion>4.4.0</SystemCodeDomPackageVersion>
88
<SystemReflectionEmitLightweightPackageVersion>4.3.0</SystemReflectionEmitLightweightPackageVersion>
99
<PublishSymbolsPackageVersion>1.0.0-beta-62824-02</PublishSymbolsPackageVersion>
10+
<LightGBMPackageVersion>2.1.2.2</LightGBMPackageVersion>
1011
</PropertyGroup>
1112
</Project>

docs/building/unix-instructions.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ macOS 10.12 or higher is needed to build dotnet/machinelearning.
4242

4343
On macOS a few components are needed which are not provided by a default developer setup:
4444
* cmake 3.10.3
45+
* gcc
4546
* All the requirements necessary to run .NET Core 2.0 applications. To view macOS prerequisites click [here](https://docs.microsoft.com/en-us/dotnet/core/macos-prerequisites?tabs=netcore2x).
4647

47-
One way of obtaining CMake is via [Homebrew](http://brew.sh):
48+
One way of obtaining CMake and gcc is via [Homebrew](http://brew.sh):
4849
```sh
4950
$ brew install cmake
51+
$ brew install gcc
5052
```
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<PackageDescription>ML.NET component for LightGBM</PackageDescription>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
10+
<PackageReference Include="LightGBM" Version="$(LightGBMPackageVersion)" />
11+
</ItemGroup>
12+
13+
</Project>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<Project DefaultTargets="Pack">
2+
3+
<Import Project="Microsoft.ML.LightGBM.nupkgproj" />
4+
5+
</Project>

src/Microsoft.ML.LightGBM/LightGbmArguments.cs

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.EntryPoints;
8+
using Microsoft.ML.Runtime.FastTree;
9+
using Microsoft.ML.Runtime.Internal.Calibration;
10+
using Microsoft.ML.Runtime.Internal.Internallearn;
11+
using Microsoft.ML.Runtime.LightGBM;
12+
using Microsoft.ML.Runtime.Model;
13+
14+
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments),
15+
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) },
16+
"LightGBM Binary Classification", LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")]
17+
18+
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(LightGbmBinaryPredictor), null, typeof(SignatureLoadModel),
19+
"LightGBM Binary Executor",
20+
LightGbmBinaryPredictor.LoaderSignature)]
21+
22+
[assembly: LoadableClass(typeof(void), typeof(LightGbm), null, typeof(SignatureEntryPointModule), "LightGBM")]
23+
24+
namespace Microsoft.ML.Runtime.LightGBM
25+
{
26+
public sealed class LightGbmBinaryPredictor : FastTreePredictionWrapper
27+
{
28+
public const string LoaderSignature = "LightGBMBinaryExec";
29+
public const string RegistrationName = "LightGBMBinaryPredictor";
30+
private static VersionInfo GetVersionInfo()
31+
{
32+
// REVIEW: can we decouple the version from FastTree predictor version ?
33+
return new VersionInfo(
34+
modelSignature: "LGBBINCL",
35+
// verWrittenCur: 0x00010001, // Initial
36+
// verWrittenCur: 0x00010002, // _numFeatures serialized
37+
// verWrittenCur: 0x00010003, // Ini content out of predictor
38+
//verWrittenCur: 0x00010004, // Add _defaultValueForMissing
39+
verWrittenCur: 0x00010005, // Categorical splits.
40+
verReadableCur: 0x00010004,
41+
verWeCanReadBack: 0x00010001,
42+
loaderSignature: LoaderSignature);
43+
}
44+
45+
protected override uint VerNumFeaturesSerialized { get { return 0x00010002; } }
46+
47+
protected override uint VerDefaultValueSerialized { get { return 0x00010004; } }
48+
49+
protected override uint VerCategoricalSplitSerialized { get { return 0x00010005; } }
50+
51+
internal LightGbmBinaryPredictor(IHostEnvironment env, FastTree.Internal.Ensemble trainedEnsemble, int featureCount, string innerArgs)
52+
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
53+
{
54+
}
55+
56+
private LightGbmBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx)
57+
: base(env, RegistrationName, ctx, GetVersionInfo())
58+
{
59+
}
60+
61+
protected override void SaveCore(ModelSaveContext ctx)
62+
{
63+
base.SaveCore(ctx);
64+
ctx.SetVersionInfo(GetVersionInfo());
65+
}
66+
67+
public static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
68+
{
69+
Contracts.CheckValue(env, nameof(env));
70+
env.CheckValue(ctx, nameof(ctx));
71+
ctx.CheckAtModel(GetVersionInfo());
72+
var predictor = new LightGbmBinaryPredictor(env, ctx);
73+
ICalibrator calibrator;
74+
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
75+
if (calibrator == null)
76+
return predictor;
77+
return new CalibratedPredictor(env, predictor, calibrator);
78+
}
79+
80+
public override PredictionKind PredictionKind { get { return PredictionKind.BinaryClassification; } }
81+
}
82+
83+
public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase<float, IPredictorWithFeatureWeights<float>>
84+
{
85+
public const string Summary = "LightGBM Binary Classifier";
86+
public const string LoadNameValue = "LightGBMBinary";
87+
public const string ShortName = "LightGBM";
88+
89+
public LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
90+
: base(env, args, PredictionKind.BinaryClassification, "LGBBINCL")
91+
{
92+
}
93+
94+
public override IPredictorWithFeatureWeights<float> CreatePredictor()
95+
{
96+
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
97+
var innerArgs = LightGbmInterfaceUtils.JoinParameters(Options);
98+
var pred = new LightGbmBinaryPredictor(Host, TrainedEnsemble, FeatureCount, innerArgs);
99+
var cali = new PlattCalibrator(Host, -0.5, 0);
100+
return new FeatureWeightsCalibratedPredictor(Host, pred, cali);
101+
}
102+
103+
protected override void CheckDataValid(IChannel ch, RoleMappedData data)
104+
{
105+
Host.AssertValue(ch);
106+
base.CheckDataValid(ch, data);
107+
var labelType = data.Schema.Label.Type;
108+
if (!(labelType.IsBool || labelType.IsKey || labelType == NumberType.R4))
109+
{
110+
throw ch.ExceptParam(nameof(data),
111+
$"Label column '{data.Schema.Label.Name}' is of type '{labelType}', but must be key, boolean or R4.");
112+
}
113+
}
114+
115+
protected override void CheckAndUpdateParametersBeforeTraining(IChannel ch, RoleMappedData data, float[] labels, int[] groups)
116+
{
117+
Options["objective"] = "binary";
118+
// Add default metric.
119+
if (!Options.ContainsKey("metric"))
120+
Options["metric"] = "binary_logloss";
121+
}
122+
}
123+
124+
/// <summary>
125+
/// A component to train an LightGBM model.
126+
/// </summary>
127+
public static partial class LightGbm
128+
{
129+
[TlcModule.EntryPoint(
130+
Name = "Trainers.LightGbmBinaryClassifier",
131+
Desc = "Train an LightGBM binary class model",
132+
UserName = LightGbmBinaryTrainer.Summary,
133+
ShortName = LightGbmBinaryTrainer.ShortName)]
134+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmArguments input)
135+
{
136+
Contracts.CheckValue(env, nameof(env));
137+
var host = env.Register("TrainLightGBM");
138+
host.CheckValue(input, nameof(input));
139+
EntryPointUtils.CheckInputArgs(host, input);
140+
141+
return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.BinaryClassificationOutput>(host, input,
142+
() => new LightGbmBinaryTrainer(host, input),
143+
getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
144+
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
145+
}
146+
}
147+
}

0 commit comments

Comments
 (0)