diff --git a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs index c2c4b1eaea..693620a61d 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxUtils.cs @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List nodes, string producerName, s model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion; model.ModelVersion = modelVersion; model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 }); - model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 }); + model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 }); model.Graph = new GraphProto(); var graph = model.Graph; graph.Node.Add(nodes); diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index 8951773dc3..8860f4377e 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -11,6 +11,7 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Numeric; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -511,7 +512,7 @@ internal static void ValidatePcaInput(IExceptionContext ectx, string name, DataV throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "known-size vector of Single of two or more items", type.ToString()); } - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { public sealed class ColumnSchemaInfo { @@ -596,6 +597,73 @@ private static void TransformFeatures(IExceptionContext ectx, in VBuffer dst = editor.Commit(); } + + public bool CanSaveOnnx(OnnxContext ctx) => true; + + public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + + for (int i = 0; i < _numColumns; i++) + { + var colPair = _parent.ColumnPairs[i]; + var transformInfo = _parent._transformInfos[i]; + string inputColumnName = colPair.inputColumnName; + string outputColumnName = colPair.outputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) + { + ctx.RemoveColumn(colPair.outputColumnName, false); + continue; + } + + var dstVariableName = ctx.AddIntermediateVariable(transformInfo.OutputType, outputColumnName); + SaveAsOnnxCore(ctx, i, ctx.GetVariableName(inputColumnName), dstVariableName); + } + } + + private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) + { + Host.CheckValue(ctx, nameof(ctx)); + + TransformInfo transformInfo = _parent._transformInfos[iinfo]; + ColumnSchemaInfo schemaInfo = _parent._schemaInfos[iinfo]; + + float[] principalComponents = new float[transformInfo.Rank * transformInfo.Dimension]; + for (int i = 0; i < transformInfo.Rank; i++) + { + Array.Copy(transformInfo.Eigenvectors[i], 0, principalComponents, i * transformInfo.Dimension, transformInfo.Dimension); + } + long[] pcaDims = { transformInfo.Rank, transformInfo.Dimension }; + var pcaMatrix = ctx.AddInitializer(principalComponents, pcaDims, "principalComponents"); + + float[] zeroMean = new float[transformInfo.Rank]; + if (transformInfo.MeanProjected != null) + { + Array.Copy(transformInfo.MeanProjected, zeroMean, transformInfo.Rank); + } + + long[] meanDims = { transformInfo.Rank }; + var zeroMeanNode = ctx.AddInitializer(zeroMean, meanDims, "meanVector"); + + // NB: Hack + // Currently ML.NET persists ONNX graphs in proto-buf 3 format but the Onnx runtime uses the proto-buf 2 format + // There is an incompatibility between the two where proto-buf 3 does not include variables whose values are zero + // In the Gemm node below, we want the srcVariableName matrix to be sent in without a transpose, so transA has to be zero + // Due to the incompatibility, we get an exception from the Onnx runtime + // To workaround this, we transpose the input data first with the Transpose operator and then use the Gemm operator with transA=1 + // This should be removed once incompatibility is fixed. + string opType; + opType = "Transpose"; + var transposeOutput = ctx.AddIntermediateVariable(schemaInfo.InputType, "TransposeOutput", true); + var transposeNode = ctx.CreateNode(opType, srcVariableName, transposeOutput, ctx.GetNodeName(opType), ""); + + opType = "Gemm"; + var gemmNode = ctx.CreateNode(opType, new[] { transposeOutput, pcaMatrix, zeroMeanNode }, new[] { dstVariableName }, ctx.GetNodeName(opType), ""); + gemmNode.AddAttribute("alpha", 1.0); + gemmNode.AddAttribute("beta", -1.0); + gemmNode.AddAttribute("transA", 1); + gemmNode.AddAttribute("transB", 1); + } } [TlcModule.EntryPoint(Name = "Transforms.PcaCalculator", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt index f0795a1f13..ada4cef3c0 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt @@ -609,7 +609,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt index a6abf86b57..e43bb45cb9 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt @@ -526,7 +526,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt index 548e3564e3..2849b88f36 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LogisticRegressionSaveModelToOnnxTest.txt @@ -270,7 +270,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt index 22aee806af..320ddfb268 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt @@ -900,7 +900,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt index 68335b20ad..f118108b41 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt @@ -674,7 +674,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt index 44f74a7022..b385c683bf 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt @@ -349,7 +349,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt index a83f522dc5..d01a72b97a 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLogisticRegressionSaveModelToOnnxTest.txt @@ -426,7 +426,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt index 548e3564e3..2849b88f36 100644 --- a/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt +++ b/test/BaselineOutput/Common/Onnx/Regression/Adult/SimplePipeline.txt @@ -270,7 +270,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt index 5b1a98942b..97b44aea40 100644 --- a/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt +++ b/test/BaselineOutput/Common/Onnx/Transforms/Sentiment/SmallWordEmbed.txt @@ -1116,7 +1116,7 @@ "version": "1" }, { - "version": "7" + "version": "9" } ] } \ No newline at end of file diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index 19c5e083d0..47deee2fcc 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -571,6 +571,9 @@ private bool MatchNumberWithTolerance(MatchCollection firstCollection, MatchColl public bool CompareNumbersWithTolerance(double expected, double actual, int? iterationOnCollection = null, int digitsOfPrecision = DigitsOfPrecision) { + if (double.IsNaN(expected) && double.IsNaN(actual)) + return true; + // this follows the IEEE recommendations for how to compare floating point numbers double allowedVariance = Math.Pow(10, -digitsOfPrecision); double delta = Round(expected, digitsOfPrecision) - Round(actual, digitsOfPrecision); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 138d49f5d4..c80d14895e 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -738,6 +738,46 @@ public void OnnxTypeConversionTest() CompareResults(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult); } } + Done(); + } + + [Fact] + public void PcaOnnxConversionTest() + { + var dataSource = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + + var mlContext = new MLContext(seed: 1); + var dataView = mlContext.Data.LoadFromTextFile(dataSource, new[] { + new TextLoader.Column("label", DataKind.Single, 11), + new TextLoader.Column("features", DataKind.Single, 0, 10) + }, hasHeader: true, separatorChar: ';'); + + bool[] zeroMeans = { true, false }; + foreach (var zeroMean in zeroMeans) + { + var pipeline = ML.Transforms.ProjectToPrincipalComponents("pca", "features", rank: 5, seed: 1, ensureZeroMean: zeroMean); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = "pca.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4VectorColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); + } + } + + Done(); } private void CreateDummyExamplesToMakeComplierHappy() @@ -845,7 +885,14 @@ private void CompareSelectedR4VectorColumns(string leftColumnName, string rightC Assert.Equal(expected.Length, actual.Length); for (int i = 0; i < expected.Length; ++i) - Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision); + { + // We are using float values. But the Assert.Equal function takes doubles. + // And sometimes the converted doubles are different in their precision. + // So make sure we compare floats + float exp = expected.GetItemOrDefault(i); + float act = actual.GetItemOrDefault(i); + CompareNumbersWithTolerance(exp, act, null, precision); + } } } }