Skip to content

Commit d58e8d1

Browse files
authored
Adding support for MurmurHash KeyDataTypes (#5138)
* merging * removed some outdated comments * update
1 parent c023271 commit d58e8d1

File tree

2 files changed

+82
-17
lines changed

2 files changed

+82
-17
lines changed

src/Microsoft.ML.Data/Transforms/Hashing.cs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,20 +1349,38 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)
13491349
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable)
13501350
{
13511351
string castOutput;
1352+
string isGreaterThanZeroOutput = "";
13521353
OnnxNode castNode;
13531354
OnnxNode murmurNode;
1355+
OnnxNode isZeroNode;
13541356

13551357
var srcType = _srcTypes[iinfo].GetItemType();
1356-
if (srcType is KeyDataViewType)
1357-
return false;
13581358
if (_parent._columns[iinfo].Combine)
13591359
return false;
13601360

13611361
var opType = "MurmurHash3";
13621362
string murmurOutput = ctx.AddIntermediateVariable(_dstTypes[iinfo], "MurmurOutput");
13631363

1364-
// Numeric input types are limited to those supported by the Onnxruntime MurmurHash operator, which currently only supports
1365-
// uints and ints. Thus, ulongs, longs, doubles and floats are not supported.
1364+
// Get zero value indeces
1365+
if (_srcTypes[iinfo] is KeyDataViewType)
1366+
{
1367+
var optType2 = "Cast";
1368+
castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastOutput", true);
1369+
isZeroNode = ctx.CreateNode(optType2, srcVariable, castOutput, ctx.GetNodeName(optType2), "");
1370+
isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType);
1371+
1372+
var zero = ctx.AddInitializer(0);
1373+
var isGreaterThanZeroOutputBool = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "isGreaterThanZeroOutputBool");
1374+
optType2 = "Greater";
1375+
ctx.CreateNode(optType2, new[] { castOutput, zero }, new[] { isGreaterThanZeroOutputBool }, ctx.GetNodeName(optType2), "");
1376+
1377+
isGreaterThanZeroOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "isGreaterThanZeroOutput");
1378+
optType2 = "Cast";
1379+
isZeroNode = ctx.CreateNode(optType2, isGreaterThanZeroOutputBool, isGreaterThanZeroOutput, ctx.GetNodeName(optType2), "");
1380+
isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType);
1381+
}
1382+
1383+
// Since these numeric types are not supported by Onnxruntime, we cast them to UInt32.
13661384
if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 ||
13671385
srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte ||
13681386
srcType == BooleanDataViewType.Instance)
@@ -1372,15 +1390,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
13721390
castNode.AddAttribute("to", NumberDataViewType.UInt32.RawType);
13731391
murmurNode = ctx.CreateNode(opType, castOutput, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
13741392
}
1375-
else if (srcType == NumberDataViewType.UInt32 || srcType == NumberDataViewType.Int32 || srcType == NumberDataViewType.UInt64 ||
1376-
srcType == NumberDataViewType.Int64 || srcType == NumberDataViewType.Single || srcType == NumberDataViewType.Double || srcType == TextDataViewType.Instance)
1377-
1378-
{
1379-
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
1380-
}
13811393
else
13821394
{
1383-
return false;
1395+
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
13841396
}
13851397

13861398
murmurNode.AddAttribute("positive", 1);
@@ -1417,10 +1429,17 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
14171429
string one = ctx.AddInitializer(1);
14181430
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");
14191431

1432+
string mulOutput = ctx.AddIntermediateVariable(vectorShape, "MulOutput");
1433+
if (_srcTypes[iinfo] is KeyDataViewType)
1434+
{
1435+
opType = "Mul";
1436+
ctx.CreateNode(opType, new[] { isGreaterThanZeroOutput, addOutput }, new[] { mulOutput }, ctx.GetNodeName(opType), "");
1437+
}
1438+
14201439
opType = "Cast";
1421-
var castNodeFinal = ctx.CreateNode(opType, addOutput, dstVariable, ctx.GetNodeName(opType), "");
1440+
var input = (_srcTypes[iinfo] is KeyDataViewType) ? mulOutput: addOutput;
1441+
var castNodeFinal = ctx.CreateNode(opType, input, dstVariable, ctx.GetNodeName(opType), "");
14221442
castNodeFinal.AddAttribute("to", _dstTypes[iinfo].GetItemType().RawType);
1423-
14241443
return true;
14251444
}
14261445

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,11 +1195,57 @@ public void OneHotHashEncodingOnnxConversionTest()
11951195
Done();
11961196
}
11971197

1198+
private class HashData
1199+
{
1200+
public uint Value { get; set; }
1201+
}
1202+
1203+
[Fact]
1204+
public void MurmurHashKeyTest()
1205+
{
1206+
var mlContext = new MLContext();
1207+
1208+
var samples = new[]
1209+
{
1210+
new HashData {Value = 232},
1211+
new HashData {Value = 42},
1212+
new HashData {Value = 0},
1213+
};
1214+
1215+
IDataView data = mlContext.Data.LoadFromEnumerable(samples);
1216+
1217+
var hashEstimator = mlContext.Transforms.Conversion.MapValueToKey("Value").Append(mlContext.Transforms.Conversion.Hash(new[]
1218+
{
1219+
new HashingEstimator.ColumnOptions(
1220+
"ValueHashed",
1221+
"Value")
1222+
}));
1223+
var model = hashEstimator.Fit(data);
1224+
var transformedData = model.Transform(data);
1225+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
1226+
1227+
var onnxFileName = "MurmurHashV2.onnx";
1228+
var onnxTextName = "MurmurHashV2.txt";
1229+
var onnxModelPath = GetOutputPath(onnxFileName);
1230+
var onnxTextPath = GetOutputPath(onnxTextName);
1231+
1232+
SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);
1233+
1234+
if (IsOnnxRuntimeSupported())
1235+
{
1236+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1237+
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1238+
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
1239+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
1240+
var onnxTransformer = onnxEstimator.Fit(data);
1241+
var onnxResult = onnxTransformer.Transform(data);
1242+
CompareSelectedColumns<uint>("ValueHashed", "ValueHashed", transformedData, onnxResult);
1243+
}
1244+
Done();
1245+
}
1246+
11981247
[Theory]
11991248
[CombinatorialData]
1200-
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
1201-
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
1202-
// when users try to convert the items mentioned above.
12031249
public void MurmurHashScalarTest(
12041250
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
12051251
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
@@ -1252,7 +1298,7 @@ public void MurmurHashScalarTest(
12521298

12531299
[Theory]
12541300
[CombinatorialData]
1255-
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
1301+
// Due to lack of Onnxruntime support, OrderedHashing is not supported.
12561302
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
12571303
// when users try to convert the items mentioned above.
12581304
public void MurmurHashVectorTest(

0 commit comments

Comments
 (0)