@@ -19,6 +19,8 @@ internal class CodeGenerator : IProjectGenerator
1919 private readonly Pipeline pipeline ;
2020 private readonly CodeGeneratorSettings settings ;
2121 private readonly ColumnInferenceResults columnInferenceResult ;
22+ private readonly HashSet < string > LightGBMTrainers = new HashSet < string > ( ) { TrainerName . LightGbmBinary . ToString ( ) , TrainerName . LightGbmMulti . ToString ( ) , TrainerName . LightGbmRegression . ToString ( ) } ;
23+ private readonly HashSet < string > mklComponentsTrainers = new HashSet < string > ( ) { TrainerName . OlsRegression . ToString ( ) , TrainerName . SymbolicSgdLogisticRegressionBinary . ToString ( ) } ;
2224
2325 internal CodeGenerator ( Pipeline pipeline , ColumnInferenceResults columnInferenceResult , CodeGeneratorSettings settings )
2426 {
@@ -29,25 +31,32 @@ internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInference
2931
3032 public void GenerateOutput ( )
3133 {
34+ // Get the extra nuget packages to be included in the generated project.
35+ var trainerNodes = pipeline . Nodes . Where ( t => t . NodeType == PipelineNodeType . Trainer ) ;
36+
37+ bool includeLightGbmPackage = false ;
38+ bool includeMklComponentsPackage = false ;
39+ SetRequiredNugetPackages ( trainerNodes , ref includeLightGbmPackage , ref includeMklComponentsPackage ) ;
40+
3241 // Get Namespace
3342 var namespaceValue = Utils . Normalize ( settings . OutputName ) ;
3443 var labelType = columnInferenceResult . TextLoaderOptions . Columns . Where ( t => t . Name == columnInferenceResult . ColumnInformation . LabelColumnName ) . First ( ) . DataKind ;
3544 Type labelTypeCsharp = Utils . GetCSharpType ( labelType ) ;
3645
3746 // Generate Model Project
38- var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp ) ;
47+ var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
3948
4049 // Write files to disk.
4150 var modelprojectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .Model") ;
4251 var dataModelsDir = Path . Combine ( modelprojectDir , "DataModels" ) ;
4352 var modelProjectName = $ "{ settings . OutputName } .Model.csproj";
4453
45- Utils . WriteOutputToFiles ( modelProjectContents . ObservationCSFileContent , "Observation .cs" , dataModelsDir ) ;
46- Utils . WriteOutputToFiles ( modelProjectContents . PredictionCSFileContent , "Prediction .cs" , dataModelsDir ) ;
54+ Utils . WriteOutputToFiles ( modelProjectContents . ObservationCSFileContent , "SampleObservation .cs" , dataModelsDir ) ;
55+ Utils . WriteOutputToFiles ( modelProjectContents . PredictionCSFileContent , "SamplePrediction .cs" , dataModelsDir ) ;
4756 Utils . WriteOutputToFiles ( modelProjectContents . ModelProjectFileContent , modelProjectName , modelprojectDir ) ;
4857
4958 // Generate ConsoleApp Project
50- var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp ) ;
59+ var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
5160
5261 // Write files to disk.
5362 var consoleAppProjectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .ConsoleApp") ;
@@ -65,12 +74,33 @@ public void GenerateOutput()
6574 Utils . AddProjectsToSolution ( modelprojectDir , modelProjectName , consoleAppProjectDir , consoleAppProjectName , solutionPath ) ;
6675 }
6776
68- internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp )
77+ private void SetRequiredNugetPackages ( IEnumerable < PipelineNode > trainerNodes , ref bool includeLightGbmPackage , ref bool includeMklComponentsPackage )
78+ {
79+ foreach ( var node in trainerNodes )
80+ {
81+ PipelineNode currentNode = node ;
82+ if ( currentNode . Name == TrainerName . Ova . ToString ( ) )
83+ {
84+ currentNode = ( PipelineNode ) currentNode . Properties [ "BinaryTrainer" ] ;
85+ }
86+
87+ if ( LightGBMTrainers . Contains ( currentNode . Name ) )
88+ {
89+ includeLightGbmPackage = true ;
90+ }
91+ else if ( mklComponentsTrainers . Contains ( currentNode . Name ) )
92+ {
93+ includeMklComponentsPackage = true ;
94+ }
95+ }
96+ }
97+
98+ internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
6999 {
70100 var predictProgramCSFileContent = GeneratePredictProgramCSFileContent ( namespaceValue ) ;
71101 predictProgramCSFileContent = Utils . FormatCode ( predictProgramCSFileContent ) ;
72102
73- var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , true , true ) ;
103+ var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , includeLightGbmPackage , includeMklComponentsPackage ) ;
74104
75105 var transformsAndTrainers = GenerateTransformsAndTrainers ( ) ;
76106 var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent ( transformsAndTrainers . Usings , transformsAndTrainers . TrainerMethod , transformsAndTrainers . PreTrainerTransforms , transformsAndTrainers . PostTrainerTransforms , namespaceValue , pipeline . CacheBeforeTrainer , labelTypeCsharp . Name ) ;
@@ -79,14 +109,14 @@ public void GenerateOutput()
79109 return ( predictProgramCSFileContent , predictProjectFileContent , modelBuilderCSFileContent ) ;
80110 }
81111
82- internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp )
112+ internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
83113 {
84114 var classLabels = this . GenerateClassLabels ( ) ;
85115 var observationCSFileContent = GenerateObservationCSFileContent ( namespaceValue , classLabels ) ;
86116 observationCSFileContent = Utils . FormatCode ( observationCSFileContent ) ;
87117 var predictionCSFileContent = GeneratePredictionCSFileContent ( labelTypeCsharp . Name , namespaceValue ) ;
88118 predictionCSFileContent = Utils . FormatCode ( predictionCSFileContent ) ;
89- var modelProjectFileContent = GenerateModelProjectFileContent ( ) ;
119+ var modelProjectFileContent = GenerateModelProjectFileContent ( includeLightGbmPackage , includeMklComponentsPackage ) ;
90120 return ( observationCSFileContent , predictionCSFileContent , modelProjectFileContent ) ;
91121 }
92122
@@ -218,9 +248,9 @@ internal IList<string> GenerateClassLabels()
218248 }
219249
220250 #region Model project
221- private static string GenerateModelProjectFileContent ( )
251+ private static string GenerateModelProjectFileContent ( bool includeLightGbmPackage , bool includeMklComponentsPackage )
222252 {
223- ModelProject modelProject = new ModelProject ( ) ;
253+ ModelProject modelProject = new ModelProject ( ) { IncludeLightGBMPackage = includeLightGbmPackage , IncludeMklComponentsPackage = includeMklComponentsPackage } ;
224254 return modelProject . TransformText ( ) ;
225255 }
226256
@@ -238,9 +268,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
238268 #endregion
239269
240270 #region Predict Project
241- private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeMklComponentsPackage , bool includeLightGBMPackage )
271+ private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeLightGbmPackage , bool includeMklComponentsPackage )
242272 {
243- var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGBMPackage } ;
273+ var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGbmPackage } ;
244274 return predictProjectFileContent . TransformText ( ) ;
245275 }
246276
@@ -290,6 +320,5 @@ private string GenerateModelBuilderCSFileContent(string usings,
290320 return modelBuilder . TransformText ( ) ;
291321 }
292322 #endregion
293-
294323 }
295324}
0 commit comments