44
55using System ;
66using System . IO ;
7+ using System . IO . Compression ;
78using System . Linq ;
89using Microsoft . ML . Calibrators ;
910using Microsoft . ML . Data ;
11+ using Microsoft . ML . Functional . Tests . Datasets ;
1012using Microsoft . ML . RunTests ;
1113using Microsoft . ML . Trainers . FastTree ;
1214using Microsoft . ML . Transforms ;
1517
1618namespace Microsoft . ML . Functional . Tests
1719{
18- public partial class ModelLoadingTests : TestDataPipeBase
20+ public partial class ModelFiles : TestDataPipeBase
1921 {
20- public ModelLoadingTests ( ITestOutputHelper output ) : base ( output )
22+ public ModelFiles ( ITestOutputHelper output ) : base ( output )
2123 {
2224 }
2325
@@ -30,6 +32,101 @@ private class InputData
3032 public float [ ] Features { get ; set ; }
3133 }
3234
/// <summary>
/// Model Files: The (minimum) nuget version can be found in the model file.
/// </summary>
/// <remarks>
/// Fits a small FastTree regression pipeline on the housing data, saves it as a zip archive,
/// and then reads the TrainingInfo/Version.txt entry of that archive to check the version string.
/// </remarks>
[Fact]
public void DetermineNugetVersionFromModel()
{
    var mlContext = new MLContext(seed: 1);

    // Get the dataset.
    var data = mlContext.Data.LoadFromTextFile<HousingRegression>(GetDataPath(TestDatasets.housing.trainFilename), hasHeader: true);

    // Create a pipeline to train on the housing data.
    var pipeline = mlContext.Transforms.Concatenate("Features", HousingRegression.Features)
        .Append(mlContext.Regression.Trainers.FastTree(
            new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 }));

    // Fit the pipeline.
    var model = pipeline.Fit(data);

    // Save model to a file.
    var modelPath = DeleteOutputPath("determineNugetVersionFromModel.zip");
    mlContext.Model.Save(model, data.Schema, modelPath);

    // Check that the version can be extracted from the model.
    // The entry name is compared after normalizing '\' to '/', so the lookup succeeds whether the
    // archive was written with the Windows separator or the forward slash.
    const string versionEntryName = "TrainingInfo/Version.txt";
    using (ZipArchive archive = ZipFile.OpenRead(modelPath))
    {
        // The version of the entire model is kept in the version file.
        // FirstOrDefault (not First) so that a missing entry is reported by the assertion
        // below rather than by an InvalidOperationException thrown from LINQ.
        var versionEntry = archive.Entries.FirstOrDefault(x => x.FullName.Replace('\\', '/') == versionEntryName);
        Assert.NotNull(versionEntry);
        using (var stream = versionEntry.Open())
        using (var reader = new StreamReader(stream))
        {
            // The only line in the file is the version of the model.
            var line = reader.ReadLine();
            Assert.Equal(@"1.0.0.0", line);
        }
    }
}
74+
/// <summary>
/// Model Files: Save a model, including all transforms, then load and make predictions.
/// </summary>
/// <remarks>
/// Covers two scenarios:
///  1. A trained model, together with its transforms, can be written to a file.
///  2. Training and prediction can live in different processes (or on different machines).
/// Everything here runs in a single process; the saved model file stands in for the
/// "communication pipe" between the trainer and the consumer.
/// </remarks>
[Fact]
public void FitPipelineSaveModelAndPredict()
{
    var context = new MLContext(seed: 1);

    // Load the housing training set.
    var trainingFile = GetDataPath(TestDatasets.housing.trainFilename);
    var housing = context.Data.LoadFromTextFile<HousingRegression>(trainingFile, hasHeader: true);

    // Assemble the estimator chain: concatenate the features, then a small FastTree regressor.
    var treeOptions = new FastTreeRegressionTrainer.Options { NumberOfThreads = 1, NumberOfTrees = 10 };
    var estimator = context.Transforms.Concatenate("Features", HousingRegression.Features)
        .Append(context.Regression.Trainers.FastTree(treeOptions));

    // Train.
    var trained = estimator.Fit(housing);

    // Write the trained model to disk.
    var zipPath = DeleteOutputPath("fitPipelineSaveModelAndPredict.zip");
    context.Model.Save(trained, housing.Schema, zipPath);

    // Read the model back and confirm the stored input schema matches the training data's schema.
    ITransformer reloaded;
    using (var input = File.OpenRead(zipPath))
    {
        reloaded = context.Model.Load(input, out var reloadedSchema);
        CheckSameSchemas(housing.Schema, reloadedSchema);
    }

    // One prediction engine per model: the in-memory one and the round-tripped one.
    var engineBeforeSave = context.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(trained);
    var engineAfterLoad = context.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(reloaded);

    // Score the first five rows with both engines; the scores must be identical.
    var sample = context.Data.TakeRows(housing, 5);
    foreach (var example in context.Data.CreateEnumerable<HousingRegression>(sample, false))
    {
        Assert.Equal(engineBeforeSave.Predict(example).Score, engineAfterLoad.Predict(example).Score);
    }

    Done();
}
129+
33130 [ Fact ]
34131 public void LoadModelAndExtractPredictor ( )
35132 {
0 commit comments