Skip to content

Commit 2cb004f

Browse files
committed
Reuse existing code to do conversion and polish a test
1 parent 80c903a commit 2cb004f

File tree

3 files changed

+80
-85
lines changed

3 files changed

+80
-85
lines changed

src/Microsoft.ML.Onnx/SaveOnnxCommand.cs

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.EntryPoints;
1313
using Microsoft.ML.Internal.Utilities;
1414
using Microsoft.ML.Model.Onnx;
15+
using Microsoft.ML.UniversalModelFormat.Onnx;
1516
using Newtonsoft.Json;
1617

1718
[assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand),
@@ -113,9 +114,11 @@ public override void Run()
113114
}
114115
}
115116

116-
private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
117+
[BestFriend]
118+
internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
117119
{
118-
Host.AssertValue(end);
120+
Contracts.AssertValue(end);
121+
119122
source = trueEnd = (end as CompositeDataLoader)?.View ?? end;
120123
IDataTransform transform = source as IDataTransform;
121124
transforms = new LinkedList<ITransformCanSaveOnnx>();
@@ -134,7 +137,51 @@ private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataV
134137
transform = (source = transform.Source) as IDataTransform;
135138
}
136139

137-
Host.AssertValue(source);
140+
Contracts.AssertValue(source);
141+
}
142+
143+
/// <summary>
/// Assembles an ONNX <see cref="ModelProto"/> from <paramref name="transforms"/>: registers the columns of
/// <paramref name="inputData"/> as graph inputs, calls SaveAsOnnx on each transform, and exposes the non-hidden
/// columns of <paramref name="outputData"/> as graph outputs, each routed through an "Identity" node.
/// </summary>
/// <param name="ctx">Context that accumulates inputs, nodes and outputs, and produces the result via MakeModel.</param>
/// <param name="inputData">Data view whose schema columns become the graph inputs.</param>
/// <param name="outputData">Data view whose schema columns become the graph outputs.</param>
/// <param name="transforms">Transforms that are asked to write themselves into <paramref name="ctx"/>.</param>
/// <param name="inputColumnNamesToDrop">Column names excluded from both the graph inputs and the graph outputs;
/// null is treated as an empty set.</param>
/// <param name="outputColumnNamesToDrop">Column names excluded from the graph outputs; null is treated as an empty set.</param>
/// <returns>The value returned by ctx.MakeModel().</returns>
/// <remarks>NOTE(review): the local inputColumns set is filled below but never read inside this method.</remarks>
[BestFriend]
144+
internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IDataView inputData, IDataView outputData,
145+
LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null)
146+
{
147+
inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
148+
outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet<string>();
149+
HashSet<string> inputColumns = new HashSet<string>();
150+
// Create graph inputs: one input variable per input column that is not listed in inputColumnNamesToDrop.
151+
for (int i = 0; i < inputData.Schema.Count; i++)
152+
{
153+
string colName = inputData.Schema[i].Name;
154+
if(inputColumnNamesToDrop.Contains(colName))
155+
continue;
156+
157+
ctx.AddInputVariable(inputData.Schema[i].Type, colName);
158+
inputColumns.Add(colName);
159+
}
160+
161+
// Create graph nodes, outputs and intermediate values.
162+
foreach (var trans in transforms)
163+
trans.SaveAsOnnx(ctx);
164+
165+
// Add graph outputs: every non-hidden output column that is not dropped.
166+
for (int i = 0; i < outputData.Schema.Count; ++i)
167+
{
168+
if (outputData.Schema[i].IsHidden)
169+
continue;
170+
171+
var idataviewColumnName = outputData.Schema[i].Name;
172+
173+
// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
174+
// inputColumnNamesToDrop should be removed too.
175+
if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
176+
continue;
177+
178+
var variableName = ctx.TryGetVariableName(idataviewColumnName);
179+
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
180+
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
181+
ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
182+
}
183+
184+
return ctx.MakeModel();
138185
}
139186

140187
private void Run(IChannel ch)
@@ -210,45 +257,8 @@ private void Run(IChannel ch)
210257
nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
211258
}
212259

213-
HashSet<string> inputColumns = new HashSet<string>();
214-
//Create graph inputs.
215-
for (int i = 0; i < source.Schema.Count; i++)
216-
{
217-
string colName = source.Schema[i].Name;
218-
if(_inputsToDrop.Contains(colName))
219-
continue;
220-
221-
ctx.AddInputVariable(source.Schema[i].Type, colName);
222-
inputColumns.Add(colName);
223-
}
224-
225-
//Create graph nodes, outputs and intermediate values.
226-
foreach (var trans in transforms)
227-
{
228-
Host.Assert(trans.CanSaveOnnx(ctx));
229-
trans.SaveAsOnnx(ctx);
230-
}
231-
232-
//Add graph outputs.
233-
for (int i = 0; i < end.Schema.Count; ++i)
234-
{
235-
if (end.Schema[i].IsHidden)
236-
continue;
237-
238-
var idataviewColumnName = end.Schema[i].Name;
239-
240-
// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
241-
// _inputToDrop should be removed too.
242-
if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
243-
continue;
244-
245-
var variableName = ctx.TryGetVariableName(idataviewColumnName);
246-
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
247-
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
248-
ctx.AddOutputVariable(end.Schema[i].Type, trueVariableName);
249-
}
260+
var model = ConvertTransformListToOnnxModel(ctx, source, end, transforms, _inputsToDrop, _outputsToDrop);
250261

251-
var model = ctx.MakeModel();
252262
using (var file = Host.CreateOutputFile(_outputModelPath))
253263
using (var stream = file.CreateWriteStream())
254264
model.WriteTo(stream);
Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,24 @@
1-
using Microsoft.ML.Core.Data;
1+
using System.Collections.Generic;
2+
using Microsoft.ML.Core.Data;
23
using Microsoft.ML.Data;
34
using Microsoft.ML.UniversalModelFormat.Onnx;
45

56
namespace Microsoft.ML.Model.Onnx
67
{
78
public class TransformerChainOnnxConverter
89
{
9-
public static ModelProto Convert<T>(TransformerChain<T> chain, Schema inputSchema) where T : class, ITransformer
10+
public static ModelProto Convert<T>(TransformerChain<T> chain, IDataView inputData, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null) where T : class, ITransformer
1011
{
1112
var env = new MLContext();
12-
var onnxContext = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Model.Onnx.OnnxVersion.Stable);
13+
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", OnnxVersion.Stable);
14+
var outputData = chain.Transform(inputData);
15+
IDataView source = null;
16+
IDataView trueEnd = null;
17+
LinkedList<ITransformCanSaveOnnx> transforms = null;
18+
using (var ch = (env as IChannelProvider).Start("ONNX conversion"))
19+
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out source, out trueEnd, out transforms);
1320

14-
for (int i = 0; i < inputSchema.Count; i++)
15-
{
16-
string colName = inputSchema[i].Name;
17-
onnxContext.AddInputVariable(inputSchema[i].Type, colName);
18-
}
19-
20-
foreach (var t in chain)
21-
{
22-
var mapper = t.GetRowToRowMapper(inputSchema);
23-
inputSchema = t.GetOutputSchema(inputSchema);
24-
(mapper as ISaveAsOnnx).SaveAsOnnx(onnxContext);
25-
}
26-
27-
for (int i = 0; i < inputSchema.Count; ++i)
28-
{
29-
if (inputSchema[i].IsHidden)
30-
continue;
31-
32-
var idataviewColumnName = inputSchema[i].Name;
33-
34-
var variableName = onnxContext.TryGetVariableName(idataviewColumnName);
35-
var trueVariableName = onnxContext.AddIntermediateVariable(null, idataviewColumnName, true);
36-
onnxContext.CreateNode("Identity", variableName, trueVariableName, onnxContext.GetNodeName("Identity"), "");
37-
onnxContext.AddOutputVariable(inputSchema[i].Type, trueVariableName);
38-
}
39-
return onnxContext.MakeModel();
21+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, source, trueEnd, transforms, inputColumnNamesToDrop, outputColumnNamesToDrop);
4022
}
4123
}
4224
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System.Collections.Generic;
2-
using System.Linq;
1+
using System.Linq;
32
using Google.Protobuf;
43
using Microsoft.ML.Data;
54
using Microsoft.ML.Model.Onnx;
@@ -25,41 +24,45 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
2524
{
2625
}
2726

27+
/// <summary>
28+
/// In this test, we convert a trained <see cref="TransformerChain"/> into an ONNX <see cref="UniversalModelFormat.Onnx.ModelProto"/> file and then
29+
/// call <see cref="OnnxScoringEstimator"/> to evaluate that file. The outputs of <see cref="OnnxScoringEstimator"/> are checked against the original
30+
/// ML.NET model's outputs.
31+
/// </summary>
2832
[Fact]
29-
public void SimplePipelineOnnxConversionTest()
33+
public void SimpleEndToEndOnnxConversionTest()
3034
{
35+
// Step 1: Create and train an ML.NET pipeline.
3136
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
3237
var mlContext = new MLContext();
33-
34-
var trainData = mlContext.Data.ReadFromTextFile<AdultData>(trainDataPath,
38+
var data = mlContext.Data.ReadFromTextFile<AdultData>(trainDataPath,
3539
hasHeader: true,
3640
separatorChar: ';'
3741
);
38-
39-
var cachedTrainData = mlContext.Data.Cache(trainData);
40-
42+
var cachedTrainData = mlContext.Data.Cache(data);
4143
var dynamicPipeline =
4244
mlContext.Transforms.Normalize("FeatureVector")
4345
.AppendCacheCheckpoint(mlContext)
4446
.Append(mlContext.Regression.Trainers.StochasticDualCoordinateAscent(labelColumn: "Target", featureColumn: "FeatureVector"));
47+
var model = dynamicPipeline.Fit(data);
48+
var transformedData = model.Transform(data);
4549

46-
var model = dynamicPipeline.Fit(trainData);
47-
var transformedData = model.Transform(trainData);
48-
49-
var onnxModel = TransformerChainOnnxConverter.Convert(model, trainData.Schema);
50-
50+
// Step 2: Convert ML.NET model to ONNX format and save it as a file.
51+
var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
5152
var onnxFileName = "model.onnx";
5253
var onnxFilePath = GetOutputPath(onnxFileName);
5354
using (var file = (mlContext as IHostEnvironment).CreateOutputFile(onnxFilePath))
5455
using (var stream = file.CreateWriteStream())
5556
onnxModel.WriteTo(stream);
5657

58+
// Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
5759
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
5860
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
5961
var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxFilePath, inputNames, outputNames);
60-
var onnxTransformer = onnxEstimator.Fit(trainData);
61-
var onnxResult = onnxTransformer.Transform(trainData);
62+
var onnxTransformer = onnxEstimator.Fit(data);
63+
var onnxResult = onnxTransformer.Transform(data);
6264

65+
// Step 4: Compare ONNX and ML.NET results.
6366
using (var expectedCursor = transformedData.GetRowCursor(columnIndex => columnIndex == transformedData.Schema["Score"].Index))
6467
using (var actualCursor = onnxResult.GetRowCursor(columnIndex => columnIndex == onnxResult.Schema["Score0"].Index))
6568
{

0 commit comments

Comments
 (0)