Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
70 changes: 69 additions & 1 deletion src/Microsoft.ML.PCA/PcaTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -511,7 +512,7 @@ internal static void ValidatePcaInput(IExceptionContext ectx, string name, DataV
throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "known-size vector of Single of two or more items", type.ToString());
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
public sealed class ColumnSchemaInfo
{
Expand Down Expand Up @@ -596,6 +597,73 @@ private static void TransformFeatures(IExceptionContext ectx, in VBuffer<float>

dst = editor.Commit();
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

for (int i = 0; i < _numColumns; i++)
{
var colPair = _parent.ColumnPairs[i];
var transformInfo = _parent._transformInfos[i];
string inputColumnName = colPair.inputColumnName;
string outputColumnName = colPair.outputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(colPair.outputColumnName, false);
continue;
}

var dstVariableName = ctx.AddIntermediateVariable(transformInfo.OutputType, outputColumnName);
SaveAsOnnxCore(ctx, i, ctx.GetVariableName(inputColumnName), dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
Host.CheckValue(ctx, nameof(ctx));

TransformInfo transformInfo = _parent._transformInfos[iinfo];
ColumnSchemaInfo schemaInfo = _parent._schemaInfos[iinfo];

float[] principalComponents = new float[transformInfo.Rank * transformInfo.Dimension];
for (int i = 0; i < transformInfo.Rank; i++)
{
Array.Copy(transformInfo.Eigenvectors[i], 0, principalComponents, i * transformInfo.Dimension, transformInfo.Dimension);
}
long[] pcaDims = { transformInfo.Rank, transformInfo.Dimension };
var pcaMatrix = ctx.AddInitializer(principalComponents, pcaDims, "principalComponents");

float[] zeroMean = new float[transformInfo.Rank];
if (transformInfo.MeanProjected != null)
{
Array.Copy(transformInfo.MeanProjected, zeroMean, transformInfo.Rank);
}

long[] meanDims = { transformInfo.Rank };
var zeroMeanNode = ctx.AddInitializer(zeroMean, meanDims, "meanVector");

// NB: Hack
// Currently ML.NET persists ONNX graphs in proto-buf 3 format but the Onnx runtime uses the proto-buf 2 format
// There is an incompatibility between the two where proto-buf 3 does not include variables whose values are zero
// In the Gemm node below, we want the srcVariableName matrix to be sent in without a transpose, so transA has to be zero
// Due to the incompatibility, we get an exception from the Onnx runtime
// To workaround this, we transpose the input data first with the Transpose operator and then use the Gemm operator with transA=1
// This should be removed once incompatibility is fixed.
string opType;
opType = "Transpose";
var transposeOutput = ctx.AddIntermediateVariable(schemaInfo.InputType, "TransposeOutput", true);
var transposeNode = ctx.CreateNode(opType, srcVariableName, transposeOutput, ctx.GetNodeName(opType), "");

opType = "Gemm";
var gemmNode = ctx.CreateNode(opType, new[] { transposeOutput, pcaMatrix, zeroMeanNode }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
gemmNode.AddAttribute("alpha", 1.0);
gemmNode.AddAttribute("beta", -1.0);
gemmNode.AddAttribute("transA", 1);
gemmNode.AddAttribute("transB", 1);
}
}

[TlcModule.EntryPoint(Name = "Transforms.PcaCalculator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
3 changes: 3 additions & 0 deletions test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ private bool MatchNumberWithTolerance(MatchCollection firstCollection, MatchColl

public bool CompareNumbersWithTolerance(double expected, double actual, int? iterationOnCollection = null, int digitsOfPrecision = DigitsOfPrecision)
{
if (double.IsNaN(expected) && double.IsNaN(actual))
return true;

// this follows the IEEE recommendations for how to compare floating point numbers
double allowedVariance = Math.Pow(10, -digitsOfPrecision);
double delta = Round(expected, digitsOfPrecision) - Round(actual, digitsOfPrecision);
Expand Down
49 changes: 48 additions & 1 deletion test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,46 @@ public void OnnxTypeConversionTest()
CompareResults(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult);
}
}
Done();
}

[Fact]
public void PcaOnnxConversionTest()
{
var dataSource = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);

var mlContext = new MLContext(seed: 1);
var dataView = mlContext.Data.LoadFromTextFile(dataSource, new[] {
new TextLoader.Column("label", DataKind.Single, 11),
new TextLoader.Column("features", DataKind.Single, 0, 10)
}, hasHeader: true, separatorChar: ';');

bool[] zeroMeans = { true, false };
foreach (var zeroMean in zeroMeans)
{
var pipeline = ML.Transforms.ProjectToPrincipalComponents("pca", "features", rank: 5, seed: 1, ensureZeroMean: zeroMean);
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var onnxFileName = "pca.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);

SaveOnnxModel(onnxModel, onnxModelPath, null);

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4VectorColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult);
}
}

Done();
}

private void CreateDummyExamplesToMakeComplierHappy()
Expand Down Expand Up @@ -845,7 +885,14 @@ private void CompareSelectedR4VectorColumns(string leftColumnName, string rightC

Assert.Equal(expected.Length, actual.Length);
for (int i = 0; i < expected.Length; ++i)
Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision);
{
// We are using float values. But the Assert.Equal function takes doubles.
// And sometimes the converted doubles are different in their precision.
// So make sure we compare floats
float exp = expected.GetItemOrDefault(i);
float act = actual.GetItemOrDefault(i);
CompareNumbersWithTolerance(exp, act, null, precision);
}
}
}
}
Expand Down