Skip to content

Commit 80c903a

Browse files
committed
Prototype of new ONNX converter and an end-to-end test
1 parent 8a951c5 commit 80c903a

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Microsoft.ML.Core.Data;
2+
using Microsoft.ML.Data;
3+
using Microsoft.ML.UniversalModelFormat.Onnx;
4+
5+
namespace Microsoft.ML.Model.Onnx
6+
{
7+
public class TransformerChainOnnxConverter
8+
{
9+
public static ModelProto Convert<T>(TransformerChain<T> chain, Schema inputSchema) where T : class, ITransformer
10+
{
11+
var env = new MLContext();
12+
var onnxContext = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Model.Onnx.OnnxVersion.Stable);
13+
14+
for (int i = 0; i < inputSchema.Count; i++)
15+
{
16+
string colName = inputSchema[i].Name;
17+
onnxContext.AddInputVariable(inputSchema[i].Type, colName);
18+
}
19+
20+
foreach (var t in chain)
21+
{
22+
var mapper = t.GetRowToRowMapper(inputSchema);
23+
inputSchema = t.GetOutputSchema(inputSchema);
24+
(mapper as ISaveAsOnnx).SaveAsOnnx(onnxContext);
25+
}
26+
27+
for (int i = 0; i < inputSchema.Count; ++i)
28+
{
29+
if (inputSchema[i].IsHidden)
30+
continue;
31+
32+
var idataviewColumnName = inputSchema[i].Name;
33+
34+
var variableName = onnxContext.TryGetVariableName(idataviewColumnName);
35+
var trueVariableName = onnxContext.AddIntermediateVariable(null, idataviewColumnName, true);
36+
onnxContext.CreateNode("Identity", variableName, trueVariableName, onnxContext.GetNodeName("Identity"), "");
37+
onnxContext.AddOutputVariable(inputSchema[i].Type, trueVariableName);
38+
}
39+
return onnxContext.MakeModel();
40+
}
41+
}
42+
}

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
1515
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
1616
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
17+
<ProjectReference Include="..\..\src\Microsoft.ML.OnnxTransform\Microsoft.ML.OnnxTransform.csproj" />
1718
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
1819
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
1920
<ProjectReference Include="..\..\src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj" />
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using System.Collections.Generic;
2+
using System.Linq;
3+
using Google.Protobuf;
4+
using Microsoft.ML.Data;
5+
using Microsoft.ML.Model.Onnx;
6+
using Microsoft.ML.RunTests;
7+
using Microsoft.ML.Transforms;
8+
using Xunit;
9+
using Xunit.Abstractions;
10+
11+
namespace Microsoft.ML.Tests
12+
{
13+
public class OnnxConversionTest : BaseTestBaseline
14+
{
15+
private class AdultData
16+
{
17+
[LoadColumn(0, 10), ColumnName("FeatureVector")]
18+
public float Features { get; set; }
19+
20+
[LoadColumn(11)]
21+
public float Target { get; set; }
22+
}
23+
24+
public OnnxConversionTest(ITestOutputHelper output) : base(output)
25+
{
26+
}
27+
28+
[Fact]
29+
public void SimplePipelineOnnxConversionTest()
30+
{
31+
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
32+
var mlContext = new MLContext();
33+
34+
var trainData = mlContext.Data.ReadFromTextFile<AdultData>(trainDataPath,
35+
hasHeader: true,
36+
separatorChar: ';'
37+
);
38+
39+
var cachedTrainData = mlContext.Data.Cache(trainData);
40+
41+
var dynamicPipeline =
42+
mlContext.Transforms.Normalize("FeatureVector")
43+
.AppendCacheCheckpoint(mlContext)
44+
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(labelColumn: "Target", featureColumn: "FeatureVector"));
45+
46+
var model = dynamicPipeline.Fit(trainData);
47+
var transformedData = model.Transform(trainData);
48+
49+
var onnxModel = TransformerChainOnnxConverter.Convert(model, trainData.Schema);
50+
51+
var onnxFileName = "model.onnx";
52+
var onnxFilePath = GetOutputPath(onnxFileName);
53+
using (var file = (mlContext as IHostEnvironment).CreateOutputFile(onnxFilePath))
54+
using (var stream = file.CreateWriteStream())
55+
onnxModel.WriteTo(stream);
56+
57+
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
58+
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
59+
var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxFilePath, inputNames, outputNames);
60+
var onnxTransformer = onnxEstimator.Fit(trainData);
61+
var onnxResult = onnxTransformer.Transform(trainData);
62+
63+
using (var expectedCursor = transformedData.GetRowCursor(columnIndex => columnIndex == transformedData.Schema["Score"].Index))
64+
using (var actualCursor = onnxResult.GetRowCursor(columnIndex => columnIndex == onnxResult.Schema["Score0"].Index))
65+
{
66+
float expected = default;
67+
VBuffer<float> actual = default;
68+
var expectedGetter = expectedCursor.GetGetter<float>(transformedData.Schema["Score"].Index);
69+
var actualGetter = actualCursor.GetGetter<VBuffer<float>>(onnxResult.Schema["Score0"].Index);
70+
while(expectedCursor.MoveNext() && actualCursor.MoveNext())
71+
{
72+
expectedGetter(ref expected);
73+
actualGetter(ref actual);
74+
75+
Assert.Equal(expected, actual.GetValues()[0], 1);
76+
}
77+
}
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)