Skip to content

Commit c9cb3cc

Browse files
authored
[AutoML] Generated project - FastTree nuget package inclusion dynamically (#3567)
* added support for fast tree nuget pack inclusion in generated project * fix testcase * changed the tool name in telemetry message * dummy commit * remove space * dummy commit to trigger build
1 parent 7191ebe commit c9cb3cc

File tree

10 files changed

+60
-27
lines changed

10 files changed

+60
-27
lines changed

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ internal class CodeGenerator : IProjectGenerator
2121
private readonly ColumnInferenceResults columnInferenceResult;
2222
private readonly HashSet<string> LightGBMTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
2323
private readonly HashSet<string> mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
24+
private readonly HashSet<string> FastTreeTrainers = new HashSet<string>() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(), TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString() };
25+
2426

2527
internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
2628
{
@@ -36,15 +38,16 @@ public void GenerateOutput()
3638

3739
bool includeLightGbmPackage = false;
3840
bool includeMklComponentsPackage = false;
39-
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage);
41+
bool includeFastTreeePackage = false;
42+
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage, ref includeFastTreeePackage);
4043

4144
// Get Namespace
4245
var namespaceValue = Utils.Normalize(settings.OutputName);
4346
var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
4447
Type labelTypeCsharp = Utils.GetCSharpType(labelType);
4548

4649
// Generate Model Project
47-
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
50+
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);
4851

4952
// Write files to disk.
5053
var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
@@ -56,7 +59,7 @@ public void GenerateOutput()
5659
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);
5760

5861
// Generate ConsoleApp Project
59-
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
62+
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);
6063

6164
// Write files to disk.
6265
var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
@@ -74,7 +77,7 @@ public void GenerateOutput()
7477
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
7578
}
7679

77-
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage)
80+
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage, ref bool includeFastTreePackage)
7881
{
7982
foreach (var node in trainerNodes)
8083
{
@@ -92,15 +95,19 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
9295
{
9396
includeMklComponentsPackage = true;
9497
}
98+
else if (FastTreeTrainers.Contains(currentNode.Name))
99+
{
100+
includeFastTreePackage = true;
101+
}
95102
}
96103
}
97104

98-
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
105+
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
99106
{
100107
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
101108
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);
102109

103-
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage);
110+
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);
104111

105112
var transformsAndTrainers = GenerateTransformsAndTrainers();
106113
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
@@ -109,14 +116,14 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
109116
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
110117
}
111118

112-
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
119+
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
113120
{
114121
var classLabels = this.GenerateClassLabels();
115122
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
116123
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
117124
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
118125
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
119-
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage);
126+
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);
120127
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
121128
}
122129

@@ -248,9 +255,9 @@ internal IList<string> GenerateClassLabels()
248255
}
249256

250257
#region Model project
251-
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage)
258+
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
252259
{
253-
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage };
260+
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeFastTreePackage = includeFastTreePackage };
254261
return modelProject.TransformText();
255262
}
256263

@@ -268,9 +275,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
268275
#endregion
269276

270277
#region Predict Project
271-
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage)
278+
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
272279
{
273-
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage };
280+
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage, IncludeFastTreePackage = includeFastTreePackage };
274281
return predictProjectFileContent.TransformText();
275282
}
276283

src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Linq;
76
using Microsoft.ML.Auto;
87
using static Microsoft.ML.CLI.CodeGenerator.CSharp.TrainerGenerators;
98

src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,6 @@ public override string[] GenerateUsings()
555555
{
556556
return binaryTrainerUsings;
557557
}
558-
559558
}
560-
561559
}
562560
}

src/mlnet/Telemetry/MlTelemetry.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class MlTelemetry
1616

1717
public void SetCommandAndParameters(string command, IEnumerable<string> parameters)
1818
{
19-
if(parameters != null)
19+
if (parameters != null)
2020
{
2121
_parameters.AddRange(parameters);
2222
}
@@ -28,7 +28,7 @@ public void LogAutoTrainMlCommand(string dataFileName, string task, long dataFil
2828
{
2929
CheckFistTimeUse();
3030

31-
if(!_enabled)
31+
if (!_enabled)
3232
{
3333
return;
3434
}
@@ -71,7 +71,7 @@ private void CheckFistTimeUse()
7171
@"Welcome to the ML.NET CLI!
7272
--------------------------
7373
Learn more about ML.NET CLI: https://aka.ms/mlnet-cli
74-
Use 'dotnet ml --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs
74+
Use 'mlnet --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs
7575
7676
Telemetry
7777
---------

src/mlnet/Templates/Console/ModelProject.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ public virtual string TransformText()
5757
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
5858
}
5959

60+
#line default
61+
#line hidden
62+
63+
#line 19 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
64+
if(IncludeFastTreePackage){
65+
66+
#line default
67+
#line hidden
68+
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
69+
"\n");
70+
71+
#line 21 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
72+
}
73+
6074
#line default
6175
#line hidden
6276
this.Write(" </ItemGroup>\r\n\r\n <ItemGroup>\r\n <None Update=\"MLModel.zip\">\r\n <CopyToOu" +
@@ -65,10 +79,11 @@ public virtual string TransformText()
6579
return this.GenerationEnvironment.ToString();
6680
}
6781

68-
#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
82+
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
6983

7084
public bool IncludeLightGBMPackage {get;set;}
7185
public bool IncludeMklComponentsPackage {get;set;}
86+
public bool IncludeFastTreePackage {get;set;}
7287

7388

7489
#line default

src/mlnet/Templates/Console/ModelProject.tt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
<#}#>
1616
<# if(IncludeMklComponentsPackage){ #>
1717
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
18+
<#}#>
19+
<# if(IncludeFastTreePackage){ #>
20+
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
1821
<#}#>
1922
</ItemGroup>
2023

@@ -28,4 +31,5 @@
2831
<#+
2932
public bool IncludeLightGBMPackage {get;set;}
3033
public bool IncludeMklComponentsPackage {get;set;}
34+
public bool IncludeFastTreePackage {get;set;}
3135
#>

src/mlnet/Templates/Console/PredictProject.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ public virtual string TransformText()
3737
if(IncludeMklComponentsPackage){
3838
this.Write(" <PackageReference Include=\"Microsoft.ML.Mkl.Components\" Version=\"1.0.0-previe" +
3939
"w\" />\r\n");
40+
}
41+
if(IncludeFastTreePackage){
42+
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
43+
"\n");
4044
}
4145
this.Write(" </ItemGroup>\r\n <ItemGroup>\r\n <ProjectReference Include=\"..\\");
4246
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
@@ -49,6 +53,7 @@ public virtual string TransformText()
4953
public string Namespace {get;set;}
5054
public bool IncludeLightGBMPackage {get;set;}
5155
public bool IncludeMklComponentsPackage {get;set;}
56+
public bool IncludeFastTreePackage {get;set;}
5257

5358
}
5459
#region Base class

src/mlnet/Templates/Console/PredictProject.tt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
<#}#>
1818
<# if(IncludeMklComponentsPackage){ #>
1919
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
20+
<#}#>
21+
<# if(IncludeFastTreePackage){ #>
22+
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
2023
<#}#>
2124
</ItemGroup>
2225
<ItemGroup>
@@ -27,4 +30,5 @@
2730
public string Namespace {get;set;}
2831
public bool IncludeLightGBMPackage {get;set;}
2932
public bool IncludeMklComponentsPackage {get;set;}
33+
public bool IncludeFastTreePackage {get;set;}
3034
#>

test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
88
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
99
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
10+
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
1011
</ItemGroup>
1112

1213
<ItemGroup>

test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void ConsoleAppModelBuilderCSFileContentOvaTest()
4242
LabelName = "Label",
4343
ModelPath = "x:\\models\\model.zip"
4444
});
45-
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
45+
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);
4646

4747
Approvals.Verify(result.modelBuilderCSFileContent);
4848
}
@@ -65,7 +65,7 @@ public void ConsoleAppModelBuilderCSFileContentBinaryTest()
6565
LabelName = "Label",
6666
ModelPath = "x:\\models\\model.zip"
6767
});
68-
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
68+
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);
6969

7070
Approvals.Verify(result.modelBuilderCSFileContent);
7171
}
@@ -88,7 +88,7 @@ public void ConsoleAppModelBuilderCSFileContentRegressionTest()
8888
LabelName = "Label",
8989
ModelPath = "x:\\models\\model.zip"
9090
});
91-
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
91+
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);
9292

9393
Approvals.Verify(result.modelBuilderCSFileContent);
9494
}
@@ -111,7 +111,7 @@ public void ModelProjectFileContentTest()
111111
LabelName = "Label",
112112
ModelPath = "x:\\models\\model.zip"
113113
});
114-
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
114+
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, true);
115115

116116
Approvals.Verify(result.ModelProjectFileContent);
117117
}
@@ -134,7 +134,7 @@ public void ObservationCSFileContentTest()
134134
LabelName = "Label",
135135
ModelPath = "x:\\models\\model.zip"
136136
});
137-
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
137+
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);
138138

139139
Approvals.Verify(result.ObservationCSFileContent);
140140
}
@@ -158,7 +158,7 @@ public void PredictionCSFileContentTest()
158158
LabelName = "Label",
159159
ModelPath = "x:\\models\\model.zip"
160160
});
161-
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
161+
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);
162162

163163
Approvals.Verify(result.PredictionCSFileContent);
164164
}
@@ -181,7 +181,7 @@ public void ConsoleAppProgramCSFileContentTest()
181181
LabelName = "Label",
182182
ModelPath = "x:\\models\\model.zip"
183183
});
184-
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
184+
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);
185185

186186
Approvals.Verify(result.ConsoleAppProgramCSFileContent);
187187
}
@@ -204,7 +204,7 @@ public void ConsoleAppProjectFileContentTest()
204204
LabelName = "Label",
205205
ModelPath = "x:\\models\\model.zip"
206206
});
207-
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
207+
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);
208208

209209
Approvals.Verify(result.ConsoleAppProjectFileContent);
210210
}

0 commit comments

Comments
 (0)