55using System ;
66using System . Collections . Generic ;
77using System . IO ;
8- using System . Linq ;
98using Microsoft . Data . DataView ;
109using Microsoft . ML . Auto ;
1110using Microsoft . ML . Core . Data ;
1211using Microsoft . ML . Data ;
12+ using mlnet ;
13+ using mlnet . Utilities ;
14+ using NLog ;
1315
1416namespace Microsoft . ML . CLI
1517{
1618 internal class NewCommand
1719 {
1820 private Options options ;
21+ private static Logger logger = LogManager . GetCurrentClassLogger ( ) ;
1922
2023 internal NewCommand ( Options options )
2124 {
@@ -26,12 +29,13 @@ internal void Run()
2629 {
2730 if ( options . MlTask == TaskKind . MulticlassClassification )
2831 {
29- Console . WriteLine ( $ "Unsupported ml-task : { options . MlTask } ") ;
32+ Console . WriteLine ( $ "{ Strings . UnsupportedMlTask } : { options . MlTask } ") ;
3033 }
3134
3235 var context = new MLContext ( ) ;
3336
3437 //Check what overload method of InferColumns needs to be called.
38+ logger . Log ( LogLevel . Info , Strings . InferColumns ) ;
3539 ( TextLoader . Arguments TextLoaderArgs , IEnumerable < ( string Name , ColumnPurpose Purpose ) > ColumnPurpopses ) columnInference = default ( ( TextLoader . Arguments TextLoaderArgs , IEnumerable < ( string Name , ColumnPurpose Purpose ) > ColumnPurpopses ) ) ;
3640 if ( options . LabelName != null )
3741 {
@@ -42,50 +46,80 @@ internal void Run()
4246 columnInference = context . Data . InferColumns ( options . TrainDataset . FullName , options . LabelIndex , groupColumns : false ) ;
4347 }
4448
49+ logger . Log ( LogLevel . Info , Strings . CreateDataLoader ) ;
4550 var textLoader = context . Data . CreateTextLoader ( columnInference . TextLoaderArgs ) ;
4651
52+ logger . Log ( LogLevel . Info , Strings . LoadData ) ;
4753 IDataView trainData = textLoader . Read ( options . TrainDataset . FullName ) ;
4854 IDataView validationData = options . ValidationDataset == null ? null : textLoader . Read ( options . ValidationDataset . FullName ) ;
4955
5056 //Explore the models
51- Pipeline pipeline = null ;
52- var result = ExploreModels ( context , trainData , validationData , pipeline ) ;
57+ ( Pipeline , ITransformer ) result = default ;
58+ Console . WriteLine ( $ "{ Strings . ExplorePipeline } : { options . MlTask } ") ;
59+ try
60+ {
61+ result = ExploreModels ( context , trainData , validationData ) ;
62+ }
63+ catch ( Exception e )
64+ {
65+ logger . Log ( LogLevel . Error , $ "{ Strings . ExplorePipelineException } :") ;
66+ logger . Log ( LogLevel . Error , e . StackTrace ) ;
67+ logger . Log ( LogLevel . Error , Strings . Exiting ) ;
68+ return ;
69+ }
5370
5471 //Get the best pipeline
72+ Pipeline pipeline = null ;
5573 pipeline = result . Item1 ;
5674 var model = result . Item2 ;
5775
76+ //Save the model
77+ logger . Log ( LogLevel . Info , Strings . SavingBestModel ) ;
78+ var modelPath = Path . Combine ( @options . OutputBaseDir , options . OutputName ) ;
79+ SaveModel ( model , modelPath , $ "{ options . OutputName } _model.zip", context ) ;
80+
81+
5882 //Generate code
59- var codeGenerator = new CodeGenerator ( pipeline , columnInference , new CodeGeneratorOptions ( ) { TrainDataset = options . TrainDataset , MlTask = options . MlTask , TestDataset = options . TestDataset } ) ;
83+ logger . Log ( LogLevel . Info , Strings . GenerateProject ) ;
84+ var codeGenerator = new CodeGenerator (
85+ pipeline ,
86+ columnInference ,
87+ new CodeGeneratorOptions ( )
88+ {
89+ TrainDataset = options . TrainDataset ,
90+ MlTask = options . MlTask ,
91+ TestDataset = options . TestDataset ,
92+ OutputName = options . OutputName ,
93+ OutputBaseDir = options . OutputBaseDir
94+ } ) ;
6095 codeGenerator . GenerateOutput ( ) ;
61-
62- //Save the model
63- SaveModel ( model , @"./BestModel" , "model.zip" , context ) ;
6496 }
6597
6698 private ( Pipeline , ITransformer ) ExploreModels (
6799 MLContext context ,
68100 IDataView trainData ,
69- IDataView validationData ,
70- Pipeline pipeline )
101+ IDataView validationData )
71102 {
72103 ITransformer model = null ;
73104 string label = options . LabelName ?? "Label" ; // It is guaranteed training dataview to have Label column
105+ Pipeline pipeline = null ;
74106
75107 if ( options . MlTask == TaskKind . BinaryClassification )
76108 {
77- var result = context . BinaryClassification . AutoFit ( trainData , label , validationData , options . Timeout ) ;
78- result = result . OrderByDescending ( t => t . Metrics . Accuracy ) . ToList ( ) ;
79- var bestIteration = result . FirstOrDefault ( ) ;
109+ var progressReporter = new ProgressHandlers . BinaryClassificationHandler ( ) ;
110+ var result = context . BinaryClassification . AutoFit ( trainData , label , validationData , options . Timeout , progressCallback : progressReporter ) ;
111+ logger . Log ( LogLevel . Info , Strings . RetrieveBestPipeline ) ;
112+ var bestIteration = result . Best ( ) ;
80113 pipeline = bestIteration . Pipeline ;
81114 model = bestIteration . Model ;
82115 }
83116
84117 if ( options . MlTask == TaskKind . Regression )
85118 {
86- var result = context . Regression . AutoFit ( trainData , label , validationData , options . Timeout ) ;
87- result = result . OrderByDescending ( t => t . Metrics . RSquared ) . ToList ( ) ;
88- var bestIteration = result . FirstOrDefault ( ) ;
119+ var progressReporter = new ProgressHandlers . RegressionHandler ( ) ;
120+ var result = context . Regression . AutoFit ( trainData , label , validationData , options . Timeout , progressCallback : progressReporter ) ;
121+ logger . Log ( LogLevel . Info , Strings . RetrieveBestPipeline ) ;
122+ var bestIteration = result . Best ( ) ;
89123 pipeline = bestIteration . Pipeline ;
90124 model = bestIteration . Model ;
91125 }
@@ -105,7 +139,7 @@ private static void SaveModel(ITransformer model, string ModelPath, string model
105139 {
106140 Directory . CreateDirectory ( ModelPath ) ;
107141 }
108- ModelPath = ModelPath + "/" + modelName ;
142+ ModelPath = Path . Combine ( ModelPath , modelName ) ;
109143 using ( var fs = File . Create ( ModelPath ) )
110144 model . SaveTo ( mlContext , fs ) ;
111145 }
0 commit comments