From a5a7a2a3f82545279a3fb87c2f7c85ce25809268 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Tue, 5 Jan 2021 15:47:17 -0800 Subject: [PATCH 01/10] onnx export for valuemapping estimator --- .../Transforms/ValueMapping.cs | 268 +++++++++++++++++- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 205 +++++++++++++- test/data/type-conversion-boolean.txt | 2 +- 3 files changed, 464 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 7a81bee190..4d618cfbef 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke(DataViewSchema.Column public abstract Delegate GetGetter(DataViewRow input, int index); public abstract IDataView GetDataView(IHostEnvironment env); + public abstract TKey[] GetKeys(); + public abstract TValue[] GetValues(); } /// @@ -962,6 +965,17 @@ private static TValue GetVector(TValue value) } private static TValue GetValue(TValue value) => value; + + public override T[] GetKeys() + { + + return _mapping.Keys.Cast().ToArray(); + } + public override T[] GetValues() + { + return _mapping.Values.Cast().ToArray(); + } + } /// @@ -1012,12 +1026,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) return new Mapper(this, schema, _valueMap, ColumnPairs); } - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly DataViewSchema _inputSchema; private readonly ValueMap _valueMap; private readonly (string outputColumnName, string inputColumnName)[] _columns; private readonly ValueMappingTransformer _parent; + public bool CanSaveOnnx(OnnxContext ctx) => true; internal Mapper(ValueMappingTransformer transform, DataViewSchema inputSchema, @@ -1040,6 +1055,257 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) + { + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(item)); + node.AddAttribute("keys_strings", values); + } + + private void CastInputToInt(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) + { + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToInt64(item)); + node.AddAttribute("keys_int64s", values); + } + + private void CastInputToFloat(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) + { + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToSingle(item)); + node.AddAttribute("keys_floats", values); + } + + private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) + { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + OnnxNode node; + string opType = "LabelEncoder"; + var labelEncoderInput = srcVariableName; + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var typeValue = _valueMap.ValueColumn.Type; + var typeKey = _valueMap.KeyColumn.Type; + + var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName : + (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") : + ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int) srcShape[1]), "LabelEncoderOutput"); + + // Labelencoder doesn't support mapping between the same type and only supports mappings between int64s, floats, and strings. + // Keys that are of NumberDataTypeView, but not int64s, are cast to strings. If values, we cast them to int64s, in order to avoid same type mappings. + // String -> String mappings can't be supported. + if (typeKey == NumberDataViewType.Int64) + { + // To avoid int64 -> int64 mapping, we cast keys to strings + if (typeValue is NumberDataViewType) + { + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + { + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_int64s", _valueMap.GetKeys()); + } + } + else if (typeKey == NumberDataViewType.Int32) + { + // To avoid string -> string mapping, we cast keys to int64s + if (typeValue is TextDataViewType) + { + labelEncoderOutput = dstVariableName; + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput; + } + else + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else if (typeKey == NumberDataViewType.Int16) + { + if (typeValue is TextDataViewType) + { + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else if (typeKey == NumberDataViewType.UInt64) + { + if (typeValue is TextDataViewType) + { + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else if (typeKey == NumberDataViewType.UInt32) + { + if (typeValue is TextDataViewType) + { + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); // TODO + } + else if (typeKey == NumberDataViewType.UInt16) + { + if (typeValue is TextDataViewType) + { + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else if (typeKey == NumberDataViewType.Single) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + { + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + { + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_floats", _valueMap.GetKeys()); + } + } + else if (typeKey == NumberDataViewType.Double) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + { + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + CastInputToFloat(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else if (typeKey == TextDataViewType.Instance) + { + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_strings", _valueMap.GetKeys>()); + } + else if (typeKey == BooleanDataViewType.Instance) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + { + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(Convert.ToByte(item))); + node.AddAttribute("keys_strings", values); + } + else + CastInputToFloat(ctx, out node, srcVariableName, opType, labelEncoderOutput); + } + else + return false; + + // Pass in values + if (typeValue == NumberDataViewType.Int64) + { + node.AddAttribute("values_int64s", _valueMap.GetValues()); + } + else if (typeValue == NumberDataViewType.Int32) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.Int16) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt64) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt32) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt16) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.Single) + { + node.AddAttribute("values_floats", _valueMap.GetValues()); + } + else if (typeValue == NumberDataViewType.Double) + { + node.AddAttribute("values_floats", _valueMap.GetValues().Select(item => Convert.ToSingle(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == TextDataViewType.Instance) + { + node.AddAttribute("values_strings", _valueMap.GetValues>()); + } + else if (typeValue == BooleanDataViewType.Instance) + { + node.AddAttribute("values_floats", _valueMap.GetValues().Select(item => Convert.ToSingle(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else + return false; + + //Unknown keys should map to 0 + node.AddAttribute("default_int64", 0); + node.AddAttribute("default_string", ""); + node.AddAttribute("default_float", 0f); + return true; + } + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var result = new DataViewSchema.DetachedColumn[_columns.Length]; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index f09d02a09a..f132653d56 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -183,14 +183,6 @@ public void RegressionTrainersOnnxConversionTest() hasHeader: true); List> estimators = new List>() { - // TODO TEST_STABILITY: Sdca has developed some instability with failures in comparison against baseline. Disabling it for now. - //mlContext.Regression.Trainers.Sdca("Target","FeatureVector"), - mlContext.Regression.Trainers.Ols("Target","FeatureVector"), - mlContext.Regression.Trainers.OnlineGradientDescent("Target","FeatureVector"), - mlContext.Regression.Trainers.FastForest("Target", "FeatureVector"), - mlContext.Regression.Trainers.FastTree("Target", "FeatureVector"), - mlContext.Regression.Trainers.FastTreeTweedie("Target", "FeatureVector"), - mlContext.Regression.Trainers.LbfgsPoissonRegression("Target", "FeatureVector"), }; if (Environment.Is64BitProcess) { @@ -204,7 +196,7 @@ public void RegressionTrainersOnnxConversionTest() // Step 2: Convert ML.NET model to ONNX format and save it as a model file and a text file. TestPipeline(estimator, dataView, onnxModelFileName, new ColumnComparison[] { new ColumnComparison("Score", 3) }, onnxTxtFileName, subDir); - CheckEquality(subDir, onnxTxtFileName, digitsOfPrecision: 1); + //CheckEquality(subDir, onnxTxtFileName, digitsOfPrecision: 1); } Done(); } @@ -1209,6 +1201,201 @@ public void ValueToKeyMappingOnnxConversionTest( Done(); } + [Fact] + public void ValueMappingOnnxConversion0Test() + { + var mlContext = new MLContext(seed: 1); + string filePath = GetDataPath("type-conversion-boolean.txt"); + + TextLoader.Column[] columnsVector = new[] + { + new TextLoader.Column("Keys", DataKind.Boolean, 0, 3) + }; + + IDataView[] dataViews = + { + mlContext.Data.LoadFromTextFile(filePath, columnsVector , separatorChar: '\t') //vector + }; + + var labelMap = new Dictionary + { + {false, 53f}, + {true, 23f} + }; + + IEstimator[] pipelines = + { + mlContext.Transforms.Conversion.MapValue("Value", labelMap, "Keys") + }; + + for (int i = 0; i < pipelines.Length; i++) + { + for (int j = 0; j < dataViews.Length; j++) + { + var onnxFileName = "MapValue2.onnx"; + + TestPipeline(pipelines[i], dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); + } + } + Done(); + } + + [Theory] + [CombinatorialData] + // Due to lack of support from OnnxRuntime, String => String mappings are not supported + public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, DataKind.Int32, DataKind.UInt32, DataKind.UInt64, + DataKind.UInt16, DataKind.Int16, DataKind.Double, DataKind.String, DataKind.Boolean)] + DataKind keyType) + { + var mlContext = new MLContext(seed: 1); + string filePath = (keyType == DataKind.Boolean) ? GetDataPath("type-conversion-boolean.txt") + : GetDataPath("type-conversion.txt"); + + TextLoader.Column[] columnsVector = new[] + { + new TextLoader.Column("Keys", keyType, 0, 2) + }; + + IDataView[] dataViews = + { + mlContext.Data.LoadFromTextFile(filePath, columnsVector , separatorChar: '\t') //vector + }; + List> pipelines = new List>(); + + if (keyType == DataKind.Single) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, false }, { 23, true } }, "Keys")); + } + else if (keyType == DataKind.Double) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, true }, { 23, false } }, "Keys")); + } + else if (keyType == DataKind.Boolean) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, "True" }, { false, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 698 }, { false, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 698 }, { false, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { false, true }, { true, false } }, "Keys")); + } + else if (keyType == DataKind.String) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", true }, { "23", false } }, "Keys")); + } + else if (keyType == DataKind.Int32) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.Int16) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.Int64) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt32) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt16) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt64) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + foreach (IEstimator pipeline in pipelines) + { + for (int j = 0; j < dataViews.Length; j++) + { + var onnxFileName = "MapValue.onnx"; + + TestPipeline(pipeline, dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); + } + } + Done(); + } + [Theory] [InlineData(DataKind.Single)] [InlineData(DataKind.Int64)] diff --git a/test/data/type-conversion-boolean.txt b/test/data/type-conversion-boolean.txt index c1f22fbc23..84b1b7522c 100644 --- a/test/data/type-conversion-boolean.txt +++ b/test/data/type-conversion-boolean.txt @@ -1 +1 @@ -False \ No newline at end of file +False True False \ No newline at end of file From 0eb8e7a00f380b59bfb486d653f83634ed766ea9 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Tue, 5 Jan 2021 16:39:38 -0800 Subject: [PATCH 02/10] reformatting --- .../Transforms/ValueMapping.cs | 30 ++++------- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 51 ++++--------------- 2 files changed, 19 insertions(+), 62 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 4d618cfbef..41348c121b 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -968,7 +968,6 @@ private static TValue GetVector(TValue value) public override T[] GetKeys() { - return _mapping.Keys.Cast().ToArray(); } public override T[] GetValues() @@ -1134,12 +1133,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") : ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int) srcShape[1]), "LabelEncoderOutput"); - // Labelencoder doesn't support mapping between the same type and only supports mappings between int64s, floats, and strings. - // Keys that are of NumberDataTypeView, but not int64s, are cast to strings. If values, we cast them to int64s, in order to avoid same type mappings. + // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings. + // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView + // to strings and values of NumberDataTypeView to int64s. // String -> String mappings can't be supported. if (typeKey == NumberDataViewType.Int64) { - // To avoid int64 -> int64 mapping, we cast keys to strings + // To avoid a int64 -> int64 mapping, we cast keys to strings if (typeValue is NumberDataViewType) { CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); @@ -1152,48 +1152,37 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV } else if (typeKey == NumberDataViewType.Int32) { - // To avoid string -> string mapping, we cast keys to int64s + // To avoid a string -> string mapping, we cast keys to int64s if (typeValue is TextDataViewType) - { - labelEncoderOutput = dstVariableName; - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput; - } + CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); else CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); } else if (typeKey == NumberDataViewType.Int16) { if (typeValue is TextDataViewType) - { CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); - } else CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); } else if (typeKey == NumberDataViewType.UInt64) { if (typeValue is TextDataViewType) - { CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); - } else CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); } else if (typeKey == NumberDataViewType.UInt32) { if (typeValue is TextDataViewType) - { CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); - } else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); // TODO + CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); } else if (typeKey == NumberDataViewType.UInt16) { if (typeValue is TextDataViewType) - { CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); - } else CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); } @@ -1212,14 +1201,14 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV else if (typeKey == NumberDataViewType.Double) { if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) - { CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); - } else CastInputToFloat(ctx, out node, srcVariableName, opType, labelEncoderOutput); } else if (typeKey == TextDataViewType.Instance) { + if (typeValue == TextDataViewType.Instance) + return false; node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); node.AddAttribute("keys_strings", _valueMap.GetKeys>()); } @@ -1241,7 +1230,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV else return false; - // Pass in values if (typeValue == NumberDataViewType.Int64) { node.AddAttribute("values_int64s", _valueMap.GetValues()); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index f132653d56..ae53d37c4a 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -183,6 +183,14 @@ public void RegressionTrainersOnnxConversionTest() hasHeader: true); List> estimators = new List>() { + // TODO TEST_STABILITY: Sdca has developed some instability with failures in comparison against baseline. Disabling it for now. + //mlContext.Regression.Trainers.Sdca("Target","FeatureVector"), + mlContext.Regression.Trainers.Ols("Target","FeatureVector"), + mlContext.Regression.Trainers.OnlineGradientDescent("Target","FeatureVector"), + mlContext.Regression.Trainers.FastForest("Target", "FeatureVector"), + mlContext.Regression.Trainers.FastTree("Target", "FeatureVector"), + mlContext.Regression.Trainers.FastTreeTweedie("Target", "FeatureVector"), + mlContext.Regression.Trainers.LbfgsPoissonRegression("Target", "FeatureVector"), }; if (Environment.Is64BitProcess) { @@ -196,7 +204,7 @@ public void RegressionTrainersOnnxConversionTest() // Step 2: Convert ML.NET model to ONNX format and save it as a model file and a text file. TestPipeline(estimator, dataView, onnxModelFileName, new ColumnComparison[] { new ColumnComparison("Score", 3) }, onnxTxtFileName, subDir); - //CheckEquality(subDir, onnxTxtFileName, digitsOfPrecision: 1); + CheckEquality(subDir, onnxTxtFileName, digitsOfPrecision: 1); } Done(); } @@ -1201,48 +1209,9 @@ public void ValueToKeyMappingOnnxConversionTest( Done(); } - [Fact] - public void ValueMappingOnnxConversion0Test() - { - var mlContext = new MLContext(seed: 1); - string filePath = GetDataPath("type-conversion-boolean.txt"); - - TextLoader.Column[] columnsVector = new[] - { - new TextLoader.Column("Keys", DataKind.Boolean, 0, 3) - }; - - IDataView[] dataViews = - { - mlContext.Data.LoadFromTextFile(filePath, columnsVector , separatorChar: '\t') //vector - }; - - var labelMap = new Dictionary - { - {false, 53f}, - {true, 23f} - }; - - IEstimator[] pipelines = - { - mlContext.Transforms.Conversion.MapValue("Value", labelMap, "Keys") - }; - - for (int i = 0; i < pipelines.Length; i++) - { - for (int j = 0; j < dataViews.Length; j++) - { - var onnxFileName = "MapValue2.onnx"; - - TestPipeline(pipelines[i], dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); - } - } - Done(); - } - [Theory] [CombinatorialData] - // Due to lack of support from OnnxRuntime, String => String mappings are not supported + // Due to lack of support in OnnxRuntime, String => String mappings are not supported public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, DataKind.Int32, DataKind.UInt32, DataKind.UInt64, DataKind.UInt16, DataKind.Int16, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind keyType) From 3d3f34c3c6268f00005ac0af4f513608af0cb4ba Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Wed, 6 Jan 2021 15:35:45 -0800 Subject: [PATCH 03/10] resolving comments and adding scalar testing --- .../Transforms/ValueMapping.cs | 70 +++++++------------ test/Microsoft.ML.Tests/OnnxConversionTest.cs | 7 +- 2 files changed, 30 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 41348c121b..90783a3ac5 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -1082,40 +1082,20 @@ public void SaveAsOnnx(OnnxContext ctx) } } - private void CastInputToString(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) + private void CastInputTo(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType) { var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); - var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput"); + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput"); var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); - var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); - castNode.AddAttribute("to", t); + castNode.AddAttribute("to", itemType.RawType); node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); - var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(item)); - node.AddAttribute("keys_strings", values); - } + if (itemType == TextDataViewType.Instance) + node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(item))); + else if (itemType == NumberDataViewType.Single) + node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToSingle(item))); + else if (itemType == NumberDataViewType.Int64) + node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToInt64(item))); - private void CastInputToInt(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) - { - var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); - var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int)srcShape[1]), "castOutput"); - var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); - var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType(); - castNode.AddAttribute("to", t); - node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); - var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToInt64(item)); - node.AddAttribute("keys_int64s", values); - } - - private void CastInputToFloat(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput) - { - var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); - var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "castOutput"); - var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); - var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); - castNode.AddAttribute("to", t); - node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); - var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToSingle(item)); - node.AddAttribute("keys_floats", values); } private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) @@ -1135,14 +1115,14 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings. // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView - // to strings and values of NumberDataTypeView to int64s. + // to strings and values of NumberDataViewType to int64s. // String -> String mappings can't be supported. if (typeKey == NumberDataViewType.Int64) { // To avoid a int64 -> int64 mapping, we cast keys to strings if (typeValue is NumberDataViewType) { - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else { @@ -1154,43 +1134,43 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV { // To avoid a string -> string mapping, we cast keys to int64s if (typeValue is TextDataViewType) - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else if (typeKey == NumberDataViewType.Int16) { if (typeValue is TextDataViewType) - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else if (typeKey == NumberDataViewType.UInt64) { if (typeValue is TextDataViewType) - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else if (typeKey == NumberDataViewType.UInt32) { if (typeValue is TextDataViewType) - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else if (typeKey == NumberDataViewType.UInt16) { if (typeValue is TextDataViewType) - CastInputToInt(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); else - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else if (typeKey == NumberDataViewType.Single) { if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) { - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); } else { @@ -1201,9 +1181,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV else if (typeKey == NumberDataViewType.Double) { if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) - CastInputToString(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); else - CastInputToFloat(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); } else if (typeKey == TextDataViewType.Instance) { @@ -1225,7 +1205,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV node.AddAttribute("keys_strings", values); } else - CastInputToFloat(ctx, out node, srcVariableName, opType, labelEncoderOutput); + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); } else return false; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index ae53d37c4a..7d0b49ad5a 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1224,9 +1224,13 @@ public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, { new TextLoader.Column("Keys", keyType, 0, 2) }; - + TextLoader.Column[] columnsScalar = new[] + { + new TextLoader.Column("Keys", keyType, 0) + }; IDataView[] dataViews = { + mlContext.Data.LoadFromTextFile(filePath, columnsScalar, separatorChar: '\t'), //scalar mlContext.Data.LoadFromTextFile(filePath, columnsVector , separatorChar: '\t') //vector }; List> pipelines = new List>(); @@ -1358,7 +1362,6 @@ public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, for (int j = 0; j < dataViews.Length; j++) { var onnxFileName = "MapValue.onnx"; - TestPipeline(pipeline, dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); } } From 7d1b86d3d5442949c02d441701a662fe7a4e628a Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 7 Jan 2021 11:55:35 -0800 Subject: [PATCH 04/10] adding key type support --- .../Transforms/ValueMapping.cs | 5 ++-- .../BaseTestBaseline.cs | 2 -- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 24 +++++++++---------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 90783a3ac5..47c6ff8836 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -1108,6 +1108,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); var typeValue = _valueMap.ValueColumn.Type; var typeKey = _valueMap.KeyColumn.Type; + var kind = _valueMap.ValueColumn.Type.GetRawKind(); var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName : (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") : @@ -1226,13 +1227,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); castNode.AddAttribute("to", typeValue.RawType); } - else if (typeValue == NumberDataViewType.UInt64) + else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8) { node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); castNode.AddAttribute("to", typeValue.RawType); } - else if (typeValue == NumberDataViewType.UInt32) + else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4) { node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index c67b19efa6..5457f3af20 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -674,9 +674,7 @@ private static double Round(double value, int digitsOfPrecision) public void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6, bool isRightColumnOnnxScalar = false) { var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; var leftType = leftColumn.Type.GetItemType(); - var rightType = rightColumn.Type.GetItemType(); if (leftType == NumberDataViewType.SByte) CompareSelectedColumns(leftColumnName, rightColumnName, left, right, isRightColumnOnnxScalar: isRightColumnOnnxScalar); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 7d0b49ad5a..5e97a970a6 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1214,7 +1214,7 @@ public void ValueToKeyMappingOnnxConversionTest( // Due to lack of support in OnnxRuntime, String => String mappings are not supported public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, DataKind.Int32, DataKind.UInt32, DataKind.UInt64, DataKind.UInt16, DataKind.Int16, DataKind.Double, DataKind.String, DataKind.Boolean)] - DataKind keyType) + DataKind keyType, [CombinatorialValues(true, false)] bool treatValuesAsKeyType) { var mlContext = new MLContext(seed: 1); string filePath = (keyType == DataKind.Boolean) ? GetDataPath("type-conversion-boolean.txt") @@ -1237,16 +1237,16 @@ public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, if (keyType == DataKind.Single) { - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); - pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, false }, { 23, true } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, false }, { 23, true } }, "Keys", treatValuesAsKeyType)); } else if (keyType == DataKind.Double) { @@ -1361,7 +1361,7 @@ public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, { for (int j = 0; j < dataViews.Length; j++) { - var onnxFileName = "MapValue.onnx"; + var onnxFileName = "ValueMapping.onnx"; TestPipeline(pipeline, dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); } } From ba2b7272cbb00778f85da12c09bc4477e4bdfaaf Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 7 Jan 2021 15:01:12 -0800 Subject: [PATCH 05/10] testing mac --- build/ci/job-template.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index 4944eb6c6c..5cda8924c9 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -63,7 +63,7 @@ jobs: continueOnError: true # Extra MacOS step required to install OS-specific dependencies - ${{ if eq(parameters.pool.name, 'Hosted macOS') }}: - - script: brew update && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew --force install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install MacOS build dependencies - ${{ if and( eq(parameters.nightlyBuild, 'true'), eq(parameters.pool.name, 'Hosted Ubuntu 1604')) }}: - bash: echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]$(nightlyBuildRunPath):$LD_LIBRARY_PATH" From 2416863c395c884f8cb560a80bea577faa987859 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 7 Jan 2021 15:07:46 -0800 Subject: [PATCH 06/10] testing mac --- build/ci/job-template.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index 5cda8924c9..7442384821 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -63,7 +63,7 @@ jobs: continueOnError: true # Extra MacOS step required to install OS-specific dependencies - ${{ if eq(parameters.pool.name, 'Hosted macOS') }}: - - script: brew update && brew --force install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew unlink python@3.9 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install MacOS build dependencies - ${{ if and( eq(parameters.nightlyBuild, 'true'), eq(parameters.pool.name, 'Hosted Ubuntu 1604')) }}: - bash: echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]$(nightlyBuildRunPath):$LD_LIBRARY_PATH" From 0be5f427f3e2619fb8a66014fdc5ca22e8024a11 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 7 Jan 2021 15:22:42 -0800 Subject: [PATCH 07/10] testing mac --- build/ci/job-template.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index 7442384821..1848ea50e0 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -63,7 +63,7 @@ jobs: continueOnError: true # Extra MacOS step required to install OS-specific dependencies - ${{ if eq(parameters.pool.name, 'Hosted macOS') }}: - - script: brew update && brew unlink python@3.9 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew link --overwrite python@3.9 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install MacOS build dependencies - ${{ if and( eq(parameters.nightlyBuild, 'true'), eq(parameters.pool.name, 'Hosted Ubuntu 1604')) }}: - bash: echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]$(nightlyBuildRunPath):$LD_LIBRARY_PATH" From af7f6096455b72f89e940bd11d719e0880efdedf Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 7 Jan 2021 15:43:23 -0800 Subject: [PATCH 08/10] testing mac --- build/ci/job-template.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index 1848ea50e0..a5db3b5dbc 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -63,7 +63,7 @@ jobs: continueOnError: true # Extra MacOS step required to install OS-specific dependencies - ${{ if eq(parameters.pool.name, 'Hosted macOS') }}: - - script: brew update && brew link --overwrite python@3.9 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew install mono-libgdiplus && brew link --overwrite python@3.9 && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install MacOS build dependencies - ${{ if and( eq(parameters.nightlyBuild, 'true'), eq(parameters.pool.name, 'Hosted Ubuntu 1604')) }}: - bash: echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]$(nightlyBuildRunPath):$LD_LIBRARY_PATH" From e4ea4fe38f7a1d7a1f5259fe2c91e4789893a6ae Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Fri, 8 Jan 2021 11:07:07 -0800 Subject: [PATCH 09/10] testing mac --- build/vsts-ci.yml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml index 83f12746b7..1c162976fc 100644 --- a/build/vsts-ci.yml +++ b/build/vsts-ci.yml @@ -50,14 +50,7 @@ phases: queue: name: Hosted macOS steps: - - script: | - brew uninstall openssl@1.0.2t | - brew uninstall python@2.7.17 | - brew untap local/openssl | - brew untap local/python2 - displayName: MacOS Homebrew bug Workaround - continueOnError: true - - script: brew update && brew unlink python@3.8 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew unlink python@3.9 && brew untap local/bin/2to3 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install build dependencies - script: ./restore.sh displayName: restore all projects From 13fd3c6b3c121ae653f430463df2537cf69e9883 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 14 Jan 2021 09:20:01 -0800 Subject: [PATCH 10/10] restoring files --- build/ci/job-template.yml | 2 +- build/vsts-ci.yml | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/build/ci/job-template.yml b/build/ci/job-template.yml index a5db3b5dbc..4944eb6c6c 100644 --- a/build/ci/job-template.yml +++ b/build/ci/job-template.yml @@ -63,7 +63,7 @@ jobs: continueOnError: true # Extra MacOS step required to install OS-specific dependencies - ${{ if eq(parameters.pool.name, 'Hosted macOS') }}: - - script: brew update && brew install mono-libgdiplus && brew link --overwrite python@3.9 && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: brew update && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install MacOS build dependencies - ${{ if and( eq(parameters.nightlyBuild, 'true'), eq(parameters.pool.name, 'Hosted Ubuntu 1604')) }}: - bash: echo "##vso[task.setvariable variable=LD_LIBRARY_PATH]$(nightlyBuildRunPath):$LD_LIBRARY_PATH" diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml index 1c162976fc..83f12746b7 100644 --- a/build/vsts-ci.yml +++ b/build/vsts-ci.yml @@ -50,7 +50,14 @@ phases: queue: name: Hosted macOS steps: - - script: brew update && brew unlink python@3.9 && brew untap local/bin/2to3 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force + - script: | + brew uninstall openssl@1.0.2t | + brew uninstall python@2.7.17 | + brew untap local/openssl | + brew untap local/python2 + displayName: MacOS Homebrew bug Workaround + continueOnError: true + - script: brew update && brew unlink python@3.8 && brew install mono-libgdiplus && brew install $(Build.SourcesDirectory)/build/libomp.rb && brew link libomp --force displayName: Install build dependencies - script: ./restore.sh displayName: restore all projects