1- using System . Linq ;
1+ using System . Collections . Generic ;
2+ using System . IO ;
3+ using System . Linq ;
4+ using System . Runtime . InteropServices ;
5+ using System . Text . RegularExpressions ;
26using Google . Protobuf ;
37using Microsoft . ML . Data ;
48using Microsoft . ML . Model . Onnx ;
59using Microsoft . ML . RunTests ;
610using Microsoft . ML . Transforms ;
11+ using Microsoft . ML . UniversalModelFormat . Onnx ;
12+ using Newtonsoft . Json ;
713using Xunit ;
814using Xunit . Abstractions ;
915
@@ -32,6 +38,9 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
3238 [ Fact ]
3339 public void SimpleEndToEndOnnxConversionTest ( )
3440 {
41+ if ( ! RuntimeInformation . IsOSPlatform ( OSPlatform . Windows ) )
42+ return ;
43+
3544 // Step 1: Create and train a ML.NET pipeline.
3645 var trainDataPath = GetDataPath ( TestDatasets . generatedRegressionDataset . trainFilename ) ;
3746 var mlContext = new MLContext ( ) ;
@@ -50,34 +59,150 @@ public void SimpleEndToEndOnnxConversionTest()
5059 // Step 2: Convert ML.NET model to ONNX format and save it as a file.
5160 var onnxModel = TransformerChainOnnxConverter . Convert ( model , data ) ;
5261 var onnxFileName = "model.onnx" ;
53- var onnxFilePath = GetOutputPath ( onnxFileName ) ;
54- using ( var file = ( mlContext as IHostEnvironment ) . CreateOutputFile ( onnxFilePath ) )
55- using ( var stream = file . CreateWriteStream ( ) )
56- onnxModel . WriteTo ( stream ) ;
62+ var onnxModelPath = GetOutputPath ( onnxFileName ) ;
63+ SaveOnnxModel ( onnxModel , onnxModelPath , null ) ;
5764
5865 // Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
5966 string [ ] inputNames = onnxModel . Graph . Input . Select ( valueInfoProto => valueInfoProto . Name ) . ToArray ( ) ;
6067 string [ ] outputNames = onnxModel . Graph . Output . Select ( valueInfoProto => valueInfoProto . Name ) . ToArray ( ) ;
61- var onnxEstimator = new OnnxScoringEstimator ( mlContext , onnxFilePath , inputNames , outputNames ) ;
68+ var onnxEstimator = new OnnxScoringEstimator ( mlContext , onnxModelPath , inputNames , outputNames ) ;
6269 var onnxTransformer = onnxEstimator . Fit ( data ) ;
6370 var onnxResult = onnxTransformer . Transform ( data ) ;
6471
6572 // Step 4: Compare ONNX and ML.NET results.
66- using ( var expectedCursor = transformedData . GetRowCursor ( columnIndex => columnIndex == transformedData . Schema [ "Score" ] . Index ) )
67- using ( var actualCursor = onnxResult . GetRowCursor ( columnIndex => columnIndex == onnxResult . Schema [ "Score0" ] . Index ) )
73+ CompareSelectedR4ScalarColumns ( "Score" , "Score0" , transformedData , onnxResult , 2 ) ;
74+ Done ( ) ;
75+ }
76+
/// <summary>
/// Row type used to load the breast-cancer text file for the K-means conversion test.
/// </summary>
private class BreastCancerFeatureVector
{
    /// <summary>
    /// Values loaded from text columns 1 through 9, declared as a vector of size 9.
    /// </summary>
    [LoadColumn(1, 9)]
    [VectorType(9)]
    public float[] Features;
}
82+
/// <summary>
/// Builds a throwaway <see cref="BreastCancerFeatureVector"/> and assigns its
/// <c>Features</c> field. The instance is never used afterwards.
/// </summary>
/// <remarks>
/// NOTE(review): judging by the method name, the only purpose is to have an explicit
/// assignment to the field somewhere in code so the compiler does not complain about a
/// never-assigned field - confirm before removing.
/// </remarks>
private void CreateDummyExamplesToMakeComplierHappy()
{
    var unusedExample = new BreastCancerFeatureVector();
    unusedExample.Features = null;
}
87+
/// <summary>
/// Trains a normalize + K-means pipeline on "breast-cancer.txt", converts the fitted model
/// to ONNX, saves it, scores the same data through <c>OnnxScoringEstimator</c>, and checks
/// that the ML.NET "Score" column matches the ONNX "Score0" column to 3 decimal places.
/// </summary>
[Fact]
public void KmeansOnnxConversionTest()
{
    // The test body only runs on Windows; on any other platform it returns immediately.
    bool runningOnWindows = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
    if (!runningOnWindows)
    {
        return;
    }

    // Context for all ML.NET operations below; the seed is fixed to 1.
    var mlContext = new MLContext(seed: 1);

    // Readers are lazy: the file is only read once the data is actually accessed.
    var breastCancerPath = GetDataPath("breast-cancer.txt");
    var data = mlContext.Data.ReadFromTextFile<BreastCancerFeatureVector>(
        breastCancerPath,
        hasHeader: true,
        separatorChar: '\t');

    // Normalize the features, then cluster them with K-means.
    var normalizer = mlContext.Transforms.Normalize("Features");
    var kMeansTrainer = mlContext.Clustering.Trainers.KMeans(features: "Features", advancedSettings: options =>
    {
        options.MaxIterations = 1;
        options.K = 4;
        options.NumThreads = 1;
        options.InitAlgorithm = Trainers.KMeans.KMeansPlusPlusTrainer.InitAlgorithm.KMeansPlusPlus;
    });
    var pipeline = normalizer.Append(kMeansTrainer);

    var model = pipeline.Fit(data);
    var transformedData = model.Transform(data);

    // Convert the trained chain to ONNX and persist it in binary form only.
    var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
    var onnxModelPath = GetOutputPath("model.onnx");
    SaveOnnxModel(onnxModel, onnxModelPath, null);

    // Score the training data again, this time through the saved ONNX model.
    var graph = onnxModel.Graph;
    string[] inputNames = graph.Input.Select(input => input.Name).ToArray();
    string[] outputNames = graph.Output.Select(output => output.Name).ToArray();
    var onnxResult = new OnnxScoringEstimator(mlContext, onnxModelPath, inputNames, outputNames)
        .Fit(data)
        .Transform(data);

    CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);
    Done();
}
132+
/// <summary>
/// Asserts that a float-vector column of <paramref name="left"/> and a float-vector column of
/// <paramref name="right"/> contain the same values, row by row and element by element.
/// </summary>
/// <param name="leftColumnName">Name of the column in <paramref name="left"/> holding the expected vectors.</param>
/// <param name="rightColumnName">Name of the column in <paramref name="right"/> holding the actual vectors.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
/// <param name="precision">Number of decimal places passed to <c>Assert.Equal</c> for each element.</param>
private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
    var leftColumnIndex = left.Schema[leftColumnName].Index;
    var rightColumnIndex = right.Schema[rightColumnName].Index;

    // Activate only the column being compared on each side.
    using (var expectedCursor = left.GetRowCursor(columnIndex => leftColumnIndex == columnIndex))
    using (var actualCursor = right.GetRowCursor(columnIndex => rightColumnIndex == columnIndex))
    {
        VBuffer<float> expected = default;
        VBuffer<float> actual = default;
        var expectedGetter = expectedCursor.GetGetter<VBuffer<float>>(leftColumnIndex);
        var actualGetter = actualCursor.GetGetter<VBuffer<float>>(rightColumnIndex);

        while (true)
        {
            // Advance both cursors unconditionally. A short-circuiting "&&" would end the loop
            // as soon as one side ran out and silently accept views with different row counts.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.Equal(hasExpected, hasActual);
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            Assert.Equal(expected.Length, actual.Length);
            for (int i = 0; i < expected.Length; ++i)
                Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision);
        }
    }
}
156+
/// <summary>
/// Asserts that a scalar float column of <paramref name="left"/> matches a single-element
/// float-vector column of <paramref name="right"/>, row by row.
/// </summary>
/// <param name="leftColumnName">Name of the scalar column in <paramref name="left"/> holding the expected values.</param>
/// <param name="rightColumnName">Name of the vector column in <paramref name="right"/> holding the actual values.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
/// <param name="precision">Number of decimal places passed to <c>Assert.Equal</c> for each value.</param>
private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
    var leftColumnIndex = left.Schema[leftColumnName].Index;
    var rightColumnIndex = right.Schema[rightColumnName].Index;

    // Activate only the column being compared on each side.
    using (var expectedCursor = left.GetRowCursor(columnIndex => leftColumnIndex == columnIndex))
    using (var actualCursor = right.GetRowCursor(columnIndex => rightColumnIndex == columnIndex))
    {
        float expected = default;
        VBuffer<float> actual = default;
        var expectedGetter = expectedCursor.GetGetter<float>(leftColumnIndex);
        var actualGetter = actualCursor.GetGetter<VBuffer<float>>(rightColumnIndex);

        while (true)
        {
            // Advance both cursors unconditionally. A short-circuiting "&&" would end the loop
            // as soon as one side ran out and silently accept views with different row counts.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.Equal(hasExpected, hasActual);
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            // Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consistency of making batch prediction.
            Assert.Equal(1, actual.Length);
            Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
        }
    }
}
180+
/// <summary>
/// Writes an ONNX model to disk in binary form, in indented-JSON text form, or both.
/// </summary>
/// <param name="model">The ONNX model to save.</param>
/// <param name="binaryFormatPath">Destination of the binary (protobuf) file, or <c>null</c> to skip it.</param>
/// <param name="textFormatPath">Destination of the JSON text file, or <c>null</c> to skip it.</param>
private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath)
{
    if (binaryFormatPath != null)
    {
        // Clean if such a file exists. Done inside the null guard so a skipped output
        // never reaches DeleteOutputPath as null.
        DeleteOutputPath(binaryFormatPath);

        using (var file = Env.CreateOutputFile(binaryFormatPath))
        using (var stream = file.CreateWriteStream())
            model.WriteTo(stream);
    }

    if (textFormatPath != null)
    {
        DeleteOutputPath(textFormatPath);

        // Re-serialize the model's JSON text with indentation.
        var parsedJson = JsonConvert.DeserializeObject(model.ToString());
        var fileText = JsonConvert.SerializeObject(parsedJson, Formatting.Indented);

        // Strip the version information before writing, so the file is written only once
        // instead of being written, read back, patched and rewritten.
        fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\"");

        using (var file = Env.CreateOutputFile(textFormatPath))
        using (var stream = file.CreateWriteStream())
        using (var writer = new StreamWriter(stream))
            writer.Write(fileText);
    }
}
82207 }
83208}
0 commit comments