@@ -21,6 +21,8 @@ internal class CodeGenerator : IProjectGenerator
2121 private readonly ColumnInferenceResults columnInferenceResult ;
2222 private readonly HashSet < string > LightGBMTrainers = new HashSet < string > ( ) { TrainerName . LightGbmBinary . ToString ( ) , TrainerName . LightGbmMulti . ToString ( ) , TrainerName . LightGbmRegression . ToString ( ) } ;
2323 private readonly HashSet < string > mklComponentsTrainers = new HashSet < string > ( ) { TrainerName . OlsRegression . ToString ( ) , TrainerName . SymbolicSgdLogisticRegressionBinary . ToString ( ) } ;
24+ private readonly HashSet < string > FastTreeTrainers = new HashSet < string > ( ) { TrainerName . FastForestBinary . ToString ( ) , TrainerName . FastForestRegression . ToString ( ) , TrainerName . FastTreeBinary . ToString ( ) , TrainerName . FastTreeRegression . ToString ( ) , TrainerName . FastTreeTweedieRegression . ToString ( ) } ;
25+
2426
2527 internal CodeGenerator ( Pipeline pipeline , ColumnInferenceResults columnInferenceResult , CodeGeneratorSettings settings )
2628 {
@@ -36,15 +38,16 @@ public void GenerateOutput()
3638
3739 bool includeLightGbmPackage = false ;
3840 bool includeMklComponentsPackage = false ;
39- SetRequiredNugetPackages ( trainerNodes , ref includeLightGbmPackage , ref includeMklComponentsPackage ) ;
41+ bool includeFastTreeePackage = false ;
42+ SetRequiredNugetPackages ( trainerNodes , ref includeLightGbmPackage , ref includeMklComponentsPackage , ref includeFastTreeePackage ) ;
4043
4144 // Get Namespace
4245 var namespaceValue = Utils . Normalize ( settings . OutputName ) ;
4346 var labelType = columnInferenceResult . TextLoaderOptions . Columns . Where ( t => t . Name == columnInferenceResult . ColumnInformation . LabelColumnName ) . First ( ) . DataKind ;
4447 Type labelTypeCsharp = Utils . GetCSharpType ( labelType ) ;
4548
4649 // Generate Model Project
47- var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
50+ var modelProjectContents = GenerateModelProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage , includeFastTreeePackage ) ;
4851
4952 // Write files to disk.
5053 var modelprojectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .Model") ;
@@ -56,7 +59,7 @@ public void GenerateOutput()
5659 Utils . WriteOutputToFiles ( modelProjectContents . ModelProjectFileContent , modelProjectName , modelprojectDir ) ;
5760
5861 // Generate ConsoleApp Project
59- var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage ) ;
62+ var consoleAppProjectContents = GenerateConsoleAppProjectContents ( namespaceValue , labelTypeCsharp , includeLightGbmPackage , includeMklComponentsPackage , includeFastTreeePackage ) ;
6063
6164 // Write files to disk.
6265 var consoleAppProjectDir = Path . Combine ( settings . OutputBaseDir , $ "{ settings . OutputName } .ConsoleApp") ;
@@ -74,7 +77,7 @@ public void GenerateOutput()
7477 Utils . AddProjectsToSolution ( modelprojectDir , modelProjectName , consoleAppProjectDir , consoleAppProjectName , solutionPath ) ;
7578 }
7679
77- private void SetRequiredNugetPackages ( IEnumerable < PipelineNode > trainerNodes , ref bool includeLightGbmPackage , ref bool includeMklComponentsPackage )
80+ private void SetRequiredNugetPackages ( IEnumerable < PipelineNode > trainerNodes , ref bool includeLightGbmPackage , ref bool includeMklComponentsPackage , ref bool includeFastTreePackage )
7881 {
7982 foreach ( var node in trainerNodes )
8083 {
@@ -92,15 +95,19 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
9295 {
9396 includeMklComponentsPackage = true ;
9497 }
98+ else if ( FastTreeTrainers . Contains ( currentNode . Name ) )
99+ {
100+ includeFastTreePackage = true ;
101+ }
95102 }
96103 }
97104
98- internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
105+ internal ( string ConsoleAppProgramCSFileContent , string ConsoleAppProjectFileContent , string modelBuilderCSFileContent ) GenerateConsoleAppProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage , bool includeFastTreePackage )
99106 {
100107 var predictProgramCSFileContent = GeneratePredictProgramCSFileContent ( namespaceValue ) ;
101108 predictProgramCSFileContent = Utils . FormatCode ( predictProgramCSFileContent ) ;
102109
103- var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , includeLightGbmPackage , includeMklComponentsPackage ) ;
110+ var predictProjectFileContent = GeneratPredictProjectFileContent ( namespaceValue , includeLightGbmPackage , includeMklComponentsPackage , includeFastTreePackage ) ;
104111
105112 var transformsAndTrainers = GenerateTransformsAndTrainers ( ) ;
106113 var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent ( transformsAndTrainers . Usings , transformsAndTrainers . TrainerMethod , transformsAndTrainers . PreTrainerTransforms , transformsAndTrainers . PostTrainerTransforms , namespaceValue , pipeline . CacheBeforeTrainer , labelTypeCsharp . Name ) ;
@@ -109,14 +116,14 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
109116 return ( predictProgramCSFileContent , predictProjectFileContent , modelBuilderCSFileContent ) ;
110117 }
111118
112- internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage )
119+ internal ( string ObservationCSFileContent , string PredictionCSFileContent , string ModelProjectFileContent ) GenerateModelProjectContents ( string namespaceValue , Type labelTypeCsharp , bool includeLightGbmPackage , bool includeMklComponentsPackage , bool includeFastTreePackage )
113120 {
114121 var classLabels = this . GenerateClassLabels ( ) ;
115122 var observationCSFileContent = GenerateObservationCSFileContent ( namespaceValue , classLabels ) ;
116123 observationCSFileContent = Utils . FormatCode ( observationCSFileContent ) ;
117124 var predictionCSFileContent = GeneratePredictionCSFileContent ( labelTypeCsharp . Name , namespaceValue ) ;
118125 predictionCSFileContent = Utils . FormatCode ( predictionCSFileContent ) ;
119- var modelProjectFileContent = GenerateModelProjectFileContent ( includeLightGbmPackage , includeMklComponentsPackage ) ;
126+ var modelProjectFileContent = GenerateModelProjectFileContent ( includeLightGbmPackage , includeMklComponentsPackage , includeFastTreePackage ) ;
120127 return ( observationCSFileContent , predictionCSFileContent , modelProjectFileContent ) ;
121128 }
122129
@@ -248,9 +255,9 @@ internal IList<string> GenerateClassLabels()
248255 }
249256
250257 #region Model project
251- private static string GenerateModelProjectFileContent ( bool includeLightGbmPackage , bool includeMklComponentsPackage )
258+ private static string GenerateModelProjectFileContent ( bool includeLightGbmPackage , bool includeMklComponentsPackage , bool includeFastTreePackage )
252259 {
253- ModelProject modelProject = new ModelProject ( ) { IncludeLightGBMPackage = includeLightGbmPackage , IncludeMklComponentsPackage = includeMklComponentsPackage } ;
260+ ModelProject modelProject = new ModelProject ( ) { IncludeLightGBMPackage = includeLightGbmPackage , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeFastTreePackage = includeFastTreePackage } ;
254261 return modelProject . TransformText ( ) ;
255262 }
256263
@@ -268,9 +275,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
268275 #endregion
269276
270277 #region Predict Project
271- private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeLightGbmPackage , bool includeMklComponentsPackage )
278+ private static string GeneratPredictProjectFileContent ( string namespaceValue , bool includeLightGbmPackage , bool includeMklComponentsPackage , bool includeFastTreePackage )
272279 {
273- var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGbmPackage } ;
280+ var predictProjectFileContent = new PredictProject ( ) { Namespace = namespaceValue , IncludeMklComponentsPackage = includeMklComponentsPackage , IncludeLightGBMPackage = includeLightGbmPackage , IncludeFastTreePackage = includeFastTreePackage } ;
274281 return predictProjectFileContent . TransformText ( ) ;
275282 }
276283
0 commit comments