Skip to content

Commit 54a8856

Browse files
committed
Test Kmeans as well
1 parent 2cb004f commit 54a8856

File tree

2 files changed

+155
-16
lines changed

2 files changed

+155
-16
lines changed

src/Microsoft.ML.Onnx/TransformerChainOnnxConverter.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,27 @@ public static ModelProto Convert<T>(TransformerChain<T> chain, IDataView inputDa
1212
var env = new MLContext();
1313
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", OnnxVersion.Stable);
1414
var outputData = chain.Transform(inputData);
15-
IDataView source = null;
16-
IDataView trueEnd = null;
15+
IDataView root = null;
16+
IDataView sink = null;
1717
LinkedList<ITransformCanSaveOnnx> transforms = null;
1818
using (var ch = (env as IChannelProvider).Start("ONNX conversion"))
19-
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out source, out trueEnd, out transforms);
19+
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out root, out sink, out transforms);
2020

21-
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, source, trueEnd, transforms, inputColumnNamesToDrop, outputColumnNamesToDrop);
21+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, root, sink, transforms, inputColumnNamesToDrop, outputColumnNamesToDrop);
22+
}
23+
24+
public static ModelProto Convert(ITransformer transform, IDataView inputData, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null)
25+
{
26+
var env = new MLContext(seed: 1);
27+
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", OnnxVersion.Stable);
28+
var outputData = transform.Transform(inputData);
29+
IDataView root = null;
30+
IDataView sink = null;
31+
LinkedList<ITransformCanSaveOnnx> transforms = null;
32+
using (var ch = (env as IChannelProvider).Start("ONNX conversion"))
33+
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out root, out sink, out transforms);
34+
35+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, root, sink, transforms, inputColumnNamesToDrop, outputColumnNamesToDrop);
2236
}
2337
}
2438
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 137 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
using System.Linq;
1+
using System.Collections.Generic;
2+
using System.IO;
3+
using System.Linq;
4+
using System.Runtime.InteropServices;
5+
using System.Text.RegularExpressions;
26
using Google.Protobuf;
37
using Microsoft.ML.Data;
48
using Microsoft.ML.Model.Onnx;
59
using Microsoft.ML.RunTests;
610
using Microsoft.ML.Transforms;
11+
using Microsoft.ML.UniversalModelFormat.Onnx;
12+
using Newtonsoft.Json;
713
using Xunit;
814
using Xunit.Abstractions;
915

@@ -32,6 +38,9 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
3238
[Fact]
3339
public void SimpleEndToEndOnnxConversionTest()
3440
{
41+
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
42+
return;
43+
3544
// Step 1: Create and train a ML.NET pipeline.
3645
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
3746
var mlContext = new MLContext();
@@ -50,34 +59,150 @@ public void SimpleEndToEndOnnxConversionTest()
5059
// Step 2: Convert ML.NET model to ONNX format and save it as a file.
5160
var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
5261
var onnxFileName = "model.onnx";
53-
var onnxFilePath = GetOutputPath(onnxFileName);
54-
using (var file = (mlContext as IHostEnvironment).CreateOutputFile(onnxFilePath))
55-
using (var stream = file.CreateWriteStream())
56-
onnxModel.WriteTo(stream);
62+
var onnxModelPath = GetOutputPath(onnxFileName);
63+
SaveOnnxModel(onnxModel, onnxModelPath, null);
5764

5865
// Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
5966
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
6067
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
61-
var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxFilePath, inputNames, outputNames);
68+
var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxModelPath, inputNames, outputNames);
6269
var onnxTransformer = onnxEstimator.Fit(data);
6370
var onnxResult = onnxTransformer.Transform(data);
6471

6572
// Step 4: Compare ONNX and ML.NET results.
66-
using (var expectedCursor = transformedData.GetRowCursor(columnIndex => columnIndex == transformedData.Schema["Score"].Index))
67-
using (var actualCursor = onnxResult.GetRowCursor(columnIndex => columnIndex == onnxResult.Schema["Score0"].Index))
73+
CompareSelectedR4ScalarColumns("Score", "Score0", transformedData, onnxResult, 2);
74+
Done();
75+
}
76+
77+
private class BreastCancerFeatureVector
78+
{
79+
[LoadColumn(1, 9), VectorType(9)]
80+
public float[] Features;
81+
}
82+
83+
private void CreateDummyExamplesToMakeComplierHappy()
84+
{
85+
var dummyExample = new BreastCancerFeatureVector() { Features = null };
86+
}
87+
88+
[Fact]
89+
public void KmeansOnnxConversionTest()
90+
{
91+
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
92+
return;
93+
94+
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
95+
// as a catalog of available operations and as the source of randomness.
96+
var mlContext = new MLContext(seed: 1);
97+
98+
string dataPath = GetDataPath("breast-cancer.txt");
99+
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
100+
var data = mlContext.Data.ReadFromTextFile<BreastCancerFeatureVector>(dataPath,
101+
hasHeader: true,
102+
separatorChar: '\t' );
103+
104+
var pipeline = mlContext.Transforms.Normalize("Features").
105+
Append(mlContext.Clustering.Trainers.KMeans(features: "Features", advancedSettings: settings =>
106+
{
107+
settings.MaxIterations = 1;
108+
settings.K = 4;
109+
settings.NumThreads = 1;
110+
settings.InitAlgorithm = Trainers.KMeans.KMeansPlusPlusTrainer.InitAlgorithm.KMeansPlusPlus;
111+
}));
112+
113+
var model = pipeline.Fit(data);
114+
var transformedData = model.Transform(data);
115+
116+
var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
117+
118+
var onnxFileName = "model.onnx";
119+
var onnxModelPath = GetOutputPath(onnxFileName);
120+
SaveOnnxModel(onnxModel, onnxModelPath, null);
121+
122+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
123+
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
124+
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
125+
var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxModelPath, inputNames, outputNames);
126+
var onnxTransformer = onnxEstimator.Fit(data);
127+
var onnxResult = onnxTransformer.Transform(data);
128+
129+
CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);
130+
Done();
131+
}
132+
133+
private void CompareSelectedR4VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision=6)
134+
{
135+
var leftColumnIndex = left.Schema[leftColumnName].Index;
136+
var rightColumnIndex = right.Schema[rightColumnName].Index;
137+
138+
using (var expectedCursor = left.GetRowCursor(columnIndex => leftColumnIndex == columnIndex))
139+
using (var actualCursor = right.GetRowCursor(columnIndex => rightColumnIndex == columnIndex))
140+
{
141+
VBuffer<float> expected = default;
142+
VBuffer<float> actual = default;
143+
var expectedGetter = expectedCursor.GetGetter<VBuffer<float>>(leftColumnIndex);
144+
var actualGetter = actualCursor.GetGetter<VBuffer<float>>(rightColumnIndex);
145+
while (expectedCursor.MoveNext() && actualCursor.MoveNext())
146+
{
147+
expectedGetter(ref expected);
148+
actualGetter(ref actual);
149+
150+
Assert.Equal(expected.Length, actual.Length);
151+
for (int i = 0; i < expected.Length; ++i)
152+
Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision);
153+
}
154+
}
155+
}
156+
157+
private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision=6)
158+
{
159+
var leftColumnIndex = left.Schema[leftColumnName].Index;
160+
var rightColumnIndex = right.Schema[rightColumnName].Index;
161+
162+
using (var expectedCursor = left.GetRowCursor(columnIndex => leftColumnIndex == columnIndex))
163+
using (var actualCursor = right.GetRowCursor(columnIndex => rightColumnIndex == columnIndex))
68164
{
69165
float expected = default;
70166
VBuffer<float> actual = default;
71-
var expectedGetter = expectedCursor.GetGetter<float>(transformedData.Schema["Score"].Index);
72-
var actualGetter = actualCursor.GetGetter<VBuffer<float>>(onnxResult.Schema["Score0"].Index);
73-
while(expectedCursor.MoveNext() && actualCursor.MoveNext())
167+
var expectedGetter = expectedCursor.GetGetter<float>(leftColumnIndex);
168+
var actualGetter = actualCursor.GetGetter<VBuffer<float>>(rightColumnIndex);
169+
while (expectedCursor.MoveNext() && actualCursor.MoveNext())
74170
{
75171
expectedGetter(ref expected);
76172
actualGetter(ref actual);
77173

78-
Assert.Equal(expected, actual.GetValues()[0], 1);
174+
// Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction.
175+
Assert.Equal(1, actual.Length);
176+
Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
79177
}
80178
}
81179
}
180+
181+
private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath)
182+
{
183+
DeleteOutputPath(binaryFormatPath); // Clean if such a file exists.
184+
DeleteOutputPath(textFormatPath);
185+
186+
if (binaryFormatPath != null)
187+
using (var file = Env.CreateOutputFile(binaryFormatPath))
188+
using (var stream = file.CreateWriteStream())
189+
model.WriteTo(stream);
190+
191+
if (textFormatPath != null)
192+
{
193+
using (var file = Env.CreateOutputFile(textFormatPath))
194+
using (var stream = file.CreateWriteStream())
195+
using (var writer = new StreamWriter(stream))
196+
{
197+
var parsedJson = JsonConvert.DeserializeObject(model.ToString());
198+
writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
199+
}
200+
201+
// Strip the version information.
202+
var fileText = File.ReadAllText(textFormatPath);
203+
fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\"");
204+
File.WriteAllText(textFormatPath, fileText);
205+
}
206+
}
82207
}
83208
}

0 commit comments

Comments
 (0)