Skip to content

Commit 6120c7f

Browse files
committed
1. Introduce ONNX conversion as an extention to MLContext
2. Address minor comments Remove two best friends
1 parent 54a8856 commit 6120c7f

File tree

5 files changed

+51
-48
lines changed

5 files changed

+51
-48
lines changed

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ public sealed class ModelOperationsCatalog
1717

1818
public ExplainabilityTransforms Explainability { get; }
1919

20+
public PortabilityTransforms Portability { get; }
21+
2022
internal ModelOperationsCatalog(IHostEnvironment env)
2123
{
2224
Contracts.AssertValue(env);
2325
Environment = env;
2426

2527
Explainability = new ExplainabilityTransforms(this);
28+
Portability = new PortabilityTransforms(this);
2629
}
2730

2831
public abstract class SubCatalogBase
@@ -33,7 +36,6 @@ protected SubCatalogBase(ModelOperationsCatalog owner)
3336
{
3437
Environment = owner.Environment;
3538
}
36-
3739
}
3840

3941
/// <summary>
@@ -60,6 +62,17 @@ internal ExplainabilityTransforms(ModelOperationsCatalog owner) : base(owner)
6062
}
6163
}
6264

65+
/// <summary>
66+
/// The catalog of model protability operations. Member function of this classes are able to convert the associated object to a protable format,
67+
/// so that the fitted pipeline can easily be depolyed to other platforms. Currently, the only supported format is ONNX (https://github.com/onnx/onnx).
68+
/// </summary>
69+
public sealed class PortabilityTransforms : SubCatalogBase
70+
{
71+
internal PortabilityTransforms(ModelOperationsCatalog owner) : base(owner)
72+
{
73+
}
74+
}
75+
6376
/// <summary>
6477
/// Create a prediction engine for one-time prediction.
6578
/// </summary>
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System.Collections.Generic;
2+
using Microsoft.ML.Core.Data;
3+
using Microsoft.ML.Data;
4+
using Microsoft.ML.Model.Onnx;
5+
using Microsoft.ML.UniversalModelFormat.Onnx;
6+
7+
namespace Microsoft.ML
8+
{
9+
public static class ProtabilityCatalog
10+
{
11+
/// <summary>
12+
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
13+
/// </summary>
14+
/// <param name="catalog">A field in <see cref="MLContext"/> which this function associated with.</param>
15+
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
16+
/// <param name="inputData">The input of the specified transform.</param>
17+
/// <returns></returns>
18+
public static ModelProto ConvertToOnnx(this ModelOperationsCatalog.PortabilityTransforms catalog, ITransformer transform, IDataView inputData)
19+
{
20+
var env = new MLContext(seed: 1);
21+
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.microsoft", OnnxVersion.Stable);
22+
var outputData = transform.Transform(inputData);
23+
IDataView root = null;
24+
IDataView sink = null;
25+
LinkedList<ITransformCanSaveOnnx> transforms = null;
26+
using (var ch = (env as IChannelProvider).Start("ONNX conversion"))
27+
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out root, out sink, out transforms);
28+
29+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, root, sink, transforms, null, null);
30+
}
31+
}
32+
}

src/Microsoft.ML.Onnx/SaveOnnxCommand.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ public override void Run()
114114
}
115115
}
116116

117-
[BestFriend]
118117
internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
119118
{
120119
Contracts.AssertValue(end);
@@ -140,7 +139,6 @@ internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, ou
140139
Contracts.AssertValue(source);
141140
}
142141

143-
[BestFriend]
144142
internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IDataView inputData, IDataView outputData,
145143
LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null)
146144
{

src/Microsoft.ML.Onnx/TransformerChainOnnxConverter.cs

Lines changed: 0 additions & 38 deletions
This file was deleted.

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
using System.Collections.Generic;
2-
using System.IO;
1+
using System.IO;
32
using System.Linq;
43
using System.Runtime.InteropServices;
54
using System.Text.RegularExpressions;
65
using Google.Protobuf;
76
using Microsoft.ML.Data;
8-
using Microsoft.ML.Model.Onnx;
97
using Microsoft.ML.RunTests;
108
using Microsoft.ML.Transforms;
119
using Microsoft.ML.UniversalModelFormat.Onnx;
@@ -43,7 +41,7 @@ public void SimpleEndToEndOnnxConversionTest()
4341

4442
// Step 1: Create and train a ML.NET pipeline.
4543
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
46-
var mlContext = new MLContext();
44+
var mlContext = new MLContext(seed: 1, conc: 1);
4745
var data = mlContext.Data.ReadFromTextFile<AdultData>(trainDataPath,
4846
hasHeader: true,
4947
separatorChar: ';'
@@ -57,7 +55,7 @@ public void SimpleEndToEndOnnxConversionTest()
5755
var transformedData = model.Transform(data);
5856

5957
// Step 2: Convert ML.NET model to ONNX format and save it as a file.
60-
var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
58+
var onnxModel = mlContext.Model.Portability.ConvertToOnnx(model, data);
6159
var onnxFileName = "model.onnx";
6260
var onnxModelPath = GetOutputPath(onnxFileName);
6361
SaveOnnxModel(onnxModel, onnxModelPath, null);
@@ -93,7 +91,7 @@ public void KmeansOnnxConversionTest()
9391

9492
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
9593
// as a catalog of available operations and as the source of randomness.
96-
var mlContext = new MLContext(seed: 1);
94+
var mlContext = new MLContext(seed: 1, conc: 1);
9795

9896
string dataPath = GetDataPath("breast-cancer.txt");
9997
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
@@ -113,7 +111,7 @@ public void KmeansOnnxConversionTest()
113111
var model = pipeline.Fit(data);
114112
var transformedData = model.Transform(data);
115113

116-
var onnxModel = TransformerChainOnnxConverter.Convert(model, data);
114+
var onnxModel = mlContext.Model.Portability.ConvertToOnnx(model, data);
117115

118116
var onnxFileName = "model.onnx";
119117
var onnxModelPath = GetOutputPath(onnxFileName);

0 commit comments

Comments
 (0)