diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 7a81bee190..47c6ff8836 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke(DataViewSchema.Column public abstract Delegate GetGetter(DataViewRow input, int index); public abstract IDataView GetDataView(IHostEnvironment env); + public abstract TKey[] GetKeys(); + public abstract TValue[] GetValues(); } /// @@ -962,6 +965,16 @@ private static TValue GetVector(TValue value) } private static TValue GetValue(TValue value) => value; + + public override T[] GetKeys() + { + return _mapping.Keys.Cast().ToArray(); + } + public override T[] GetValues() + { + return _mapping.Values.Cast().ToArray(); + } + } /// @@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) return new Mapper(this, schema, _valueMap, ColumnPairs); } - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly DataViewSchema _inputSchema; private readonly ValueMap _valueMap; private readonly (string outputColumnName, string inputColumnName)[] _columns; private readonly ValueMappingTransformer _parent; + public bool CanSaveOnnx(OnnxContext ctx) => true; internal Mapper(ValueMappingTransformer transform, DataViewSchema inputSchema, @@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType) + { + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", itemType.RawType); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + if (itemType == TextDataViewType.Instance) + node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(item))); + else if (itemType == NumberDataViewType.Single) + node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToSingle(item))); + else if (itemType == NumberDataViewType.Int64) + node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToInt64(item))); + + } + + private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName) + { + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature); + OnnxNode node; + string opType = "LabelEncoder"; + var labelEncoderInput = srcVariableName; + var srcShape = ctx.RetrieveShapeOrNull(srcVariableName); + var typeValue = _valueMap.ValueColumn.Type; + var typeKey = _valueMap.KeyColumn.Type; + var kind = _valueMap.ValueColumn.Type.GetRawKind(); + + var labelEncoderOutput = (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64) ? dstVariableName : + (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) ? ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, (int)srcShape[1]), "LabelEncoderOutput") : + ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, (int) srcShape[1]), "LabelEncoderOutput"); + + // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings. + // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView + // to strings and values of NumberDataViewType to int64s. + // String -> String mappings can't be supported. + if (typeKey == NumberDataViewType.Int64) + { + // To avoid a int64 -> int64 mapping, we cast keys to strings + if (typeValue is NumberDataViewType) + { + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else + { + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_int64s", _valueMap.GetKeys()); + } + } + else if (typeKey == NumberDataViewType.Int32) + { + // To avoid a string -> string mapping, we cast keys to int64s + if (typeValue is TextDataViewType) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else if (typeKey == NumberDataViewType.Int16) + { + if (typeValue is TextDataViewType) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else if (typeKey == NumberDataViewType.UInt64) + { + if (typeValue is TextDataViewType) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else if (typeKey == NumberDataViewType.UInt32) + { + if (typeValue is TextDataViewType) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else if (typeKey == NumberDataViewType.UInt16) + { + if (typeValue is TextDataViewType) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Int64); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else if (typeKey == NumberDataViewType.Single) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + { + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + } + else + { + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_floats", _valueMap.GetKeys()); + } + } + else if (typeKey == NumberDataViewType.Double) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance); + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); + } + else if (typeKey == TextDataViewType.Instance) + { + if (typeValue == TextDataViewType.Instance) + return false; + node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType)); + node.AddAttribute("keys_strings", _valueMap.GetKeys>()); + } + else if (typeKey == BooleanDataViewType.Instance) + { + if (typeValue == NumberDataViewType.Single || typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance) + { + var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput"); + var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var values = Array.ConvertAll(_valueMap.GetKeys(), item => Convert.ToString(Convert.ToByte(item))); + node.AddAttribute("keys_strings", values); + } + else + CastInputTo(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single); + } + else + return false; + + if (typeValue == NumberDataViewType.Int64) + { + node.AddAttribute("values_int64s", _valueMap.GetValues()); + } + else if (typeValue == NumberDataViewType.Int32) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.Int16) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.UInt16) + { + node.AddAttribute("values_int64s", _valueMap.GetValues().Select(item => Convert.ToInt64(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == NumberDataViewType.Single) + { + node.AddAttribute("values_floats", _valueMap.GetValues()); + } + else if (typeValue == NumberDataViewType.Double) + { + node.AddAttribute("values_floats", _valueMap.GetValues().Select(item => Convert.ToSingle(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else if (typeValue == TextDataViewType.Instance) + { + node.AddAttribute("values_strings", _valueMap.GetValues>()); + } + else if (typeValue == BooleanDataViewType.Instance) + { + node.AddAttribute("values_floats", _valueMap.GetValues().Select(item => Convert.ToSingle(item))); + var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), ""); + castNode.AddAttribute("to", typeValue.RawType); + } + else + return false; + + //Unknown keys should map to 0 + node.AddAttribute("default_int64", 0); + node.AddAttribute("default_string", ""); + node.AddAttribute("default_float", 0f); + return true; + } + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var result = new DataViewSchema.DetachedColumn[_columns.Length]; diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index c67b19efa6..5457f3af20 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -674,9 +674,7 @@ private static double Round(double value, int digitsOfPrecision) public void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6, bool isRightColumnOnnxScalar = false) { var leftColumn = left.Schema[leftColumnName]; - var rightColumn = right.Schema[rightColumnName]; var leftType = leftColumn.Type.GetItemType(); - var rightType = rightColumn.Type.GetItemType(); if (leftType == NumberDataViewType.SByte) CompareSelectedColumns(leftColumnName, rightColumnName, left, right, isRightColumnOnnxScalar: isRightColumnOnnxScalar); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index f09d02a09a..5e97a970a6 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1209,6 +1209,165 @@ public void ValueToKeyMappingOnnxConversionTest( Done(); } + [Theory] + [CombinatorialData] + // Due to lack of support in OnnxRuntime, String => String mappings are not supported + public void ValueMappingOnnxConversionTest([CombinatorialValues(DataKind.Int64, DataKind.Int32, DataKind.UInt32, DataKind.UInt64, + DataKind.UInt16, DataKind.Int16, DataKind.Double, DataKind.String, DataKind.Boolean)] + DataKind keyType, [CombinatorialValues(true, false)] bool treatValuesAsKeyType) + { + var mlContext = new MLContext(seed: 1); + string filePath = (keyType == DataKind.Boolean) ? GetDataPath("type-conversion-boolean.txt") + : GetDataPath("type-conversion.txt"); + + TextLoader.Column[] columnsVector = new[] + { + new TextLoader.Column("Keys", keyType, 0, 2) + }; + TextLoader.Column[] columnsScalar = new[] + { + new TextLoader.Column("Keys", keyType, 0) + }; + IDataView[] dataViews = + { + mlContext.Data.LoadFromTextFile(filePath, columnsScalar, separatorChar: '\t'), //scalar + mlContext.Data.LoadFromTextFile(filePath, columnsVector , separatorChar: '\t') //vector + }; + List> pipelines = new List>(); + + if (keyType == DataKind.Single) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys", treatValuesAsKeyType)); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, false }, { 23, true } }, "Keys", treatValuesAsKeyType)); + } + else if (keyType == DataKind.Double) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 698 }, { 23, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, true }, { 23, false } }, "Keys")); + } + else if (keyType == DataKind.Boolean) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, "True" }, { false, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 6 }, { false, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 698 }, { false, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { true, 698 }, { false, 7908 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { false, true }, { true, false } }, "Keys")); + } + else if (keyType == DataKind.String) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 6 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", 3 }, { "23", 23 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { "3", true }, { "23", false } }, "Keys")); + } + else if (keyType == DataKind.Int32) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.Int16) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.Int64) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt32) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt16) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + else if (keyType == DataKind.UInt64) + { + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6 }, { 23, 46 } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, "True" }, { 23, "False" } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + pipelines.Add(mlContext.Transforms.Conversion.MapValue("Value", new Dictionary { { 3, 6.435f }, { 23, 23.534f } }, "Keys")); + } + foreach (IEstimator pipeline in pipelines) + { + for (int j = 0; j < dataViews.Length; j++) + { + var onnxFileName = "ValueMapping.onnx"; + TestPipeline(pipeline, dataViews[j], onnxFileName, new ColumnComparison[] { new ColumnComparison("Value") }); + } + } + Done(); + } + [Theory] [InlineData(DataKind.Single)] [InlineData(DataKind.Int64)] diff --git a/test/data/type-conversion-boolean.txt b/test/data/type-conversion-boolean.txt index c1f22fbc23..84b1b7522c 100644 --- a/test/data/type-conversion-boolean.txt +++ b/test/data/type-conversion-boolean.txt @@ -1 +1 @@ -False \ No newline at end of file +False True False \ No newline at end of file