@@ -131,6 +131,14 @@ private class BreastCancerMulticlassExample
131131 [ LoadColumn ( 2 , 9 ) , VectorType ( 8 ) ]
132132 public float [ ] Features ;
133133 }
/// <summary>
/// Input row used by the binary classification tests: a Boolean label read from
/// column 0 and an 8-element feature vector read from columns 2 through 9.
/// </summary>
private class BreastCancerBinaryClassification
{
    /// <summary>Label value, loaded from column 0 of the text file.</summary>
    [LoadColumn(0)]
    public bool Label;

    /// <summary>Feature vector of fixed size 8, loaded from columns 2 to 9.</summary>
    [LoadColumn(2, 9)]
    [VectorType(8)]
    public float[] Features;
}
134142
135143 [ LessThanNetCore30OrNotNetCoreFact ( "netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087" ) ]
136144 public void KmeansOnnxConversionTest ( )
@@ -187,6 +195,55 @@ public void KmeansOnnxConversionTest()
187195 Done ( ) ;
188196 }
189197
/// <summary>
/// Fits every listed binary classification trainer on the breast-cancer data set,
/// converts each fitted pipeline to an ONNX protobuf and, when the ONNX runtime is
/// supported, verifies that the saved ONNX model reproduces the outputs of the
/// ML.NET model on the training data.
/// </summary>
[Fact]
public void binaryClassificationTrainersOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);
    string dataPath = GetDataPath("breast-cancer.txt");
    // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);

    // Each trainer appears exactly once: a repeated entry would only retrain the same
    // configuration and overwrite the ONNX file written for its first occurrence.
    IEstimator<ITransformer>[] estimators = {
        mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
        mlContext.BinaryClassification.Trainers.SgdCalibrated(),
        mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
        mlContext.BinaryClassification.Trainers.FastForest(),
        mlContext.BinaryClassification.Trainers.LinearSvm(),
        mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
        mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
        mlContext.BinaryClassification.Trainers.FastTree(),
        mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
        mlContext.BinaryClassification.Trainers.LightGbm(),
        mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
    };

    // Positions of the compared columns in the ML.NET output schema and in the ONNX graph outputs.
    // NOTE(review): presumably index 4/2 is the predicted label and 5/3 the score - confirm against the schema.
    const int mlNetBooleanColumnIndex = 4;
    const int mlNetFloatColumnIndex = 5;
    const int onnxBooleanOutputIndex = 2;
    const int onnxFloatOutputIndex = 3;
    // Precision argument passed to the float column comparison.
    const int floatPrecision = 3;

    var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features")
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

    foreach (var estimator in estimators)
    {
        var pipeline = initialPipeline.Append(estimator);
        var model = pipeline.Fit(dataView);
        var transformedData = model.Transform(dataView);
        // The conversion itself is exercised even when the ONNX runtime is unavailable.
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

        // Compare model scores produced by ML.NET and ONNX's runtime.
        if (IsOnnxRuntimeSupported())
        {
            var onnxFileName = $"{estimator.ToString()}.onnx";
            var onnxModelPath = GetOutputPath(onnxFileName);
            SaveOnnxModel(onnxModel, onnxModelPath, null);

            // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
            string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
            var onnxTransformer = onnxEstimator.Fit(dataView);
            var onnxResult = onnxTransformer.Transform(dataView);

            CompareSelectedR4ScalarColumns(
                transformedData.Schema[mlNetFloatColumnIndex].Name,
                outputNames[onnxFloatOutputIndex],
                transformedData,
                onnxResult,
                floatPrecision);
            CompareSelectedScalarColumns<bool>(
                transformedData.Schema[mlNetBooleanColumnIndex].Name,
                outputNames[onnxBooleanOutputIndex],
                transformedData,
                onnxResult);
        }
    }

    Done();
}
190247 private class DataPoint
191248 {
192249 [ VectorType ( 3 ) ]
@@ -853,7 +910,8 @@ private void CreateDummyExamplesToMakeComplierHappy()
853910 var dummyExample = new BreastCancerFeatureVector ( ) { Features = null } ;
854911 var dummyExample1 = new BreastCancerCatFeatureExample ( ) { Label = false , F1 = 0 , F2 = "Amy" } ;
855912 var dummyExample2 = new BreastCancerMulticlassExample ( ) { Label = "Amy" , Features = null } ;
856- var dummyExample3 = new SmallSentimentExample ( ) { Tokens = null } ;
913+ var dummyExample3 = new BreastCancerBinaryClassification ( ) { Label = false , Features = null } ;
914+ var dummyExample4 = new SmallSentimentExample ( ) { Tokens = null } ;
857915 }
858916
859917 private void CompareResults ( string leftColumnName , string rightColumnName , IDataView left , IDataView right )
@@ -984,7 +1042,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC
9841042
9851043 // Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consistency of making batch prediction.
9861044 Assert . Equal ( 1 , actual . Length ) ;
987- Assert . Equal ( expected , actual . GetItemOrDefault ( 0 ) , precision ) ;
1045+ CompareNumbersWithTolerance ( expected , actual . GetItemOrDefault ( 0 ) , null , precision ) ;
1046+ }
1047+ }
1048+ }
/// <summary>
/// Walks two data views in lockstep and asserts that, for every row read from both,
/// the scalar value in the left column equals the single element of the vector in
/// the right column.
/// </summary>
/// <typeparam name="T">Item type of the compared columns.</typeparam>
/// <param name="leftColumnName">Name of the scalar column in <paramref name="left"/> holding the expected values.</param>
/// <param name="rightColumnName">Name of the column in <paramref name="right"/>, read as a <see cref="VBuffer{T}"/>, holding the actual values.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        T expected = default;
        VBuffer<T> actual = default;
        var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);

        // The loop ends as soon as either cursor runs out of rows; any remaining rows
        // of the other cursor are not compared.
        while (expectedCursor.MoveNext() && actualCursor.MoveNext())
        {
            expectedGetter(ref expected);
            actualGetter(ref actual);

            // Validate the vector length before touching element 0 so that an
            // unexpected shape is reported as a length mismatch rather than
            // surfacing from the element access.
            Assert.Equal(1, actual.Length);
            var actualVal = actual.GetItemOrDefault(0);

            if (typeof(T) == typeof(ReadOnlyMemory<char>))
            {
                // Text values are compared through their string representation.
                Assert.Equal(expected.ToString(), actualVal.ToString());
            }
            else
            {
                Assert.Equal(expected, actualVal);
            }
        }
    }
}
0 commit comments