Skip to content

Commit 47c49ff

Browse files
authored
Double cast to float for some onnx estimators (#4745)
* double fix * added support for int64 keys * number change * modifying test case * modifying test case
1 parent 7db00a1 commit 47c49ff

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

src/Microsoft.ML.Data/Transforms/KeyToValue.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
520520
long[] values = Array.ConvertAll<TValue, long>(_values.GetValues().ToArray(), item => Convert.ToInt64(item));
521521
node.AddAttribute("values_int64s", values);
522522
}
523-
else if (TypeOutput == NumberDataViewType.Single)
523+
else if (TypeOutput == NumberDataViewType.Double || TypeOutput == NumberDataViewType.Single)
524524
{
525525
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
526526
node.AddAttribute("values_floats", values);

src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
786786
OnnxNode node;
787787
long[] termIds;
788788
string opType = "LabelEncoder";
789+
OnnxNode castNode;
789790
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);
790791

791792
if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
@@ -800,6 +801,26 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
800801
var terms = GetTermsAndIds<float>(iinfo, out termIds);
801802
node.AddAttribute("keys_floats", terms);
802803
}
804+
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double))
805+
{
806+
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
807+
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
808+
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
809+
castNode.AddAttribute("to", t);
810+
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
811+
var terms = GetTermsAndIds<double>(iinfo, out termIds);
812+
node.AddAttribute("keys_floats", terms);
813+
}
814+
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64))
815+
{
816+
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
817+
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
818+
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
819+
castNode.AddAttribute("to", t);
820+
node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
821+
var terms = GetTermsAndIds<long>(iinfo, out termIds);
822+
node.AddAttribute("keys_strings", terms.Select(item => item.ToString()));
823+
}
803824
else
804825
{
805826
// LabelEncoder-2 in ORT v1 only supports the following mappings
@@ -822,7 +843,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
822843
InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind);
823844

824845
opType = "Cast";
825-
var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
846+
castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
826847
castNode.AddAttribute("to", dataKind.ToType());
827848

828849
return true;

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,8 @@ public void IndicateMissingValuesOnnxConversionTest()
10921092

10931093
[Theory]
10941094
[InlineData(DataKind.Single)]
1095+
[InlineData(DataKind.Int64)]
1096+
[InlineData(DataKind.Double)]
10951097
[InlineData(DataKind.String)]
10961098
public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
10971099
{

0 commit comments

Comments
 (0)