diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index 058eda3bf0..c1863a32cf 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -520,7 +520,7 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds long[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToInt64(item)); node.AddAttribute("values_int64s", values); } - else if (TypeOutput == NumberDataViewType.Single) + else if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Single) { float[] values = Array.ConvertAll(_values.GetValues().ToArray(), item => Convert.ToSingle(item)); node.AddAttribute("values_floats", values); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 2d064943b2..4f39d5af0b 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -786,6 +786,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src OnnxNode node; long[] termIds; string opType = "LabelEncoder"; + OnnxNode castNode; var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true); if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance)) @@ -800,6 +801,26 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src var terms = GetTermsAndIds(iinfo, out termIds); node.AddAttribute("keys_floats", terms); } + else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double)) + { + var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); + castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var terms = GetTermsAndIds(iinfo, out termIds); + node.AddAttribute("keys_floats", terms); + } + else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64)) + { + var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); + castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); + castNode.AddAttribute("to", t); + node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType)); + var terms = GetTermsAndIds(iinfo, out termIds); + node.AddAttribute("keys_strings", terms.Select(item => item.ToString())); + } else { // LabelEncoder-2 in ORT v1 only supports the following mappings @@ -822,7 +843,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind); opType = "Cast"; - var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), ""); + castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), ""); castNode.AddAttribute("to", dataKind.ToType()); return true; diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index f138917912..e748d43e74 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1094,6 +1094,8 @@ public void IndicateMissingValuesOnnxConversionTest() [Theory] [InlineData(DataKind.Single)] + [InlineData(DataKind.Int64)] + [InlineData(DataKind.Double)] [InlineData(DataKind.String)] public void ValueToKeyMappingOnnxConversionTest(DataKind valueType) {