Skip to content

Commit 549b389

Browse files
authored
Added onnx export support for CopyColumns (#4486)
1 parent 9af92a4 commit 549b389

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

src/Microsoft.ML.Data/Transforms/ColumnCopying.cs

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
197197
private readonly DataViewSchema _schema;
198198
private readonly (string outputColumnName, string inputColumnName)[] _columns;
199199

200-
public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
200+
public bool CanSaveOnnx(OnnxContext ctx) => true;
201201

202202
internal Mapper(ColumnCopyingTransformer parent, DataViewSchema inputSchema, (string outputColumnName, string inputColumnName)[] columns)
203203
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
@@ -233,15 +233,16 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
233233

234234
public void SaveAsOnnx(OnnxContext ctx)
235235
{
236-
var opType = "CSharp";
236+
var opType = "Identity";
237237

238238
foreach (var column in _columns)
239239
{
240240
var srcVariableName = ctx.GetVariableName(column.inputColumnName);
241+
if (!ctx.ContainsColumn(srcVariableName))
242+
continue;
241243
_schema.TryGetColumnIndex(column.inputColumnName, out int colIndex);
242244
var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.outputColumnName);
243-
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
244-
node.AddAttribute("type", LoaderSignature);
245+
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
245246
}
246247
}
247248
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1187,6 +1187,39 @@ void MulticlassTrainersOnnxConversionTest()
11871187
Done();
11881188
}
11891189

1190+
[Fact]
1191+
void CopyColumnsOnnxTest()
1192+
{
1193+
var mlContext = new MLContext(seed: 1);
1194+
1195+
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
1196+
var dataView = mlContext.Data.LoadFromTextFile<AdultData>(trainDataPath,
1197+
separatorChar: ';',
1198+
hasHeader: true);
1199+
1200+
var pipeline = mlContext.Transforms.CopyColumns("Target1", "Target");
1201+
var model = pipeline.Fit(dataView);
1202+
var transformedData = model.Transform(dataView);
1203+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
1204+
1205+
var onnxFileName = "copycolumns.onnx";
1206+
var onnxModelPath = GetOutputPath(onnxFileName);
1207+
1208+
SaveOnnxModel(onnxModel, onnxModelPath, null);
1209+
1210+
if (IsOnnxRuntimeSupported())
1211+
{
1212+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1213+
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1214+
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1215+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
1216+
var onnxTransformer = onnxEstimator.Fit(dataView);
1217+
var onnxResult = onnxTransformer.Transform(dataView);
1218+
CompareSelectedR4ScalarColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult);
1219+
}
1220+
Done();
1221+
}
1222+
11901223
private void CreateDummyExamplesToMakeComplierHappy()
11911224
{
11921225
var dummyExample = new BreastCancerFeatureVector() { Features = null };

0 commit comments

Comments
 (0)