Skip to content

Commit c023271

Browse files
authored
Support more types for HashEstimator (#5104)
* support more types but tests failed * fix bugs * bump to ort1.3 pre-release * correct/skip some tests * refactor tests * fix HashKey8V2 * This reverts commit e0c * add comments * revert changes on V1 * upgrade to ORT1.3 official * fix part of conflicts * update * update * update * fix test failure * fix another test failure
1 parent d682a60 commit c023271

File tree

7 files changed

+70
-49
lines changed

7 files changed

+70
-49
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
<GoogleProtobufPackageVersion>3.10.1</GoogleProtobufPackageVersion>
1717
<LightGBMPackageVersion>2.2.3</LightGBMPackageVersion>
1818
<MicrosoftExtensionsPackageVersion>2.1.0</MicrosoftExtensionsPackageVersion>
19-
<MicrosoftMLOnnxRuntimePackageVersion>1.2</MicrosoftMLOnnxRuntimePackageVersion>
19+
<MicrosoftMLOnnxRuntimePackageVersion>1.3.0</MicrosoftMLOnnxRuntimePackageVersion>
2020
<MlNetMklDepsPackageVersion>0.0.0.9</MlNetMklDepsPackageVersion>
2121
<ParquetDotNetPackageVersion>2.1.3</ParquetDotNetPackageVersion>
2222
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>

src/Microsoft.ML.Data/Transforms/Hashing.cs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ public uint HashCoreOld(uint seed, uint mask, in float value)
544544

545545
[MethodImpl(MethodImplOptions.AggressiveInlining)]
546546
public uint HashCore(uint seed, uint mask, in float value)
547-
=> float.IsNaN(value) ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value)), sizeof(uint)) & mask) + 1;
547+
=> float.IsNaN(value) ? 0 : (Hashing.MixHash(Hashing.MurmurRound(seed, FloatUtils.GetBits(value == 0 ? 0 : value)), sizeof(float)) & mask) + 1;
548548

549549
[MethodImpl(MethodImplOptions.AggressiveInlining)]
550550
public uint HashCore(uint seed, uint mask, in VBuffer<float> values)
@@ -578,7 +578,7 @@ public uint HashCore(uint seed, uint mask, in double value)
578578
if (double.IsNaN(value))
579579
return 0;
580580

581-
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
581+
return (Hashing.MixHash(HashRound(seed, value), sizeof(double)) & mask) + 1;
582582
}
583583

584584
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -600,8 +600,6 @@ private uint HashRound(uint seed, double value)
600600
ulong v = FloatUtils.GetBits(value == 0 ? 0 : value);
601601
var hash = Hashing.MurmurRound(seed, Utils.GetLo(v));
602602
var hi = Utils.GetHi(v);
603-
if (hi == 0)
604-
return hash;
605603
return Hashing.MurmurRound(hash, hi);
606604
}
607605
}
@@ -815,7 +813,7 @@ public uint HashCoreOld(uint seed, uint mask, in ulong value)
815813
[MethodImpl(MethodImplOptions.AggressiveInlining)]
816814
public uint HashCore(uint seed, uint mask, in ulong value)
817815
{
818-
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
816+
return (Hashing.MixHash(HashRound(seed, value), sizeof(ulong)) & mask) + 1;
819817
}
820818

821819
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -832,8 +830,6 @@ private uint HashRound(uint seed, ulong value)
832830
{
833831
var hash = Hashing.MurmurRound(seed, Utils.GetLo(value));
834832
var hi = Utils.GetHi(value);
835-
if (hi == 0)
836-
return hash;
837833
return Hashing.MurmurRound(hash, hi);
838834
}
839835
}
@@ -970,7 +966,7 @@ public uint HashCoreOld(uint seed, uint mask, in long value)
970966
[MethodImpl(MethodImplOptions.AggressiveInlining)]
971967
public uint HashCore(uint seed, uint mask, in long value)
972968
{
973-
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
969+
return (Hashing.MixHash(HashRound(seed, value), sizeof(long)) & mask) + 1;
974970
}
975971

976972
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -987,8 +983,6 @@ private uint HashRound(uint seed, long value)
987983
{
988984
var hash = Hashing.MurmurRound(seed, Utils.GetLo((ulong)value));
989985
var hi = Utils.GetHi((ulong)value);
990-
if (hi == 0)
991-
return hash;
992986
return Hashing.MurmurRound(hash, hi);
993987
}
994988
}
@@ -1378,8 +1372,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
13781372
castNode.AddAttribute("to", NumberDataViewType.UInt32.RawType);
13791373
murmurNode = ctx.CreateNode(opType, castOutput, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
13801374
}
1381-
else if (srcType == NumberDataViewType.UInt32 ||
1382-
srcType == NumberDataViewType.Int32 || srcType == TextDataViewType.Instance)
1375+
else if (srcType == NumberDataViewType.UInt32 || srcType == NumberDataViewType.Int32 || srcType == NumberDataViewType.UInt64 ||
1376+
srcType == NumberDataViewType.Int64 || srcType == NumberDataViewType.Single || srcType == NumberDataViewType.Double || srcType == TextDataViewType.Instance)
1377+
13831378
{
13841379
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
13851380
}

test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ public sealed partial class TestDataPipe : TestDataPipeBase
2727

2828
private static Double[] _dataDouble = new Double[] { -0.0, 0, 1, -1, 2, -2, Double.NaN, Double.MinValue,
2929
Double.MaxValue, Double.Epsilon, Double.NegativeInfinity, Double.PositiveInfinity };
30-
private static uint[] _resultsDouble = new uint[] { 16, 16, 25, 27, 12, 2, 0, 6, 17, 4, 11, 30 };
30+
private static uint[] _resultsDouble = new uint[] { 30, 30, 19, 24, 32, 25, 0, 2, 7, 30, 5, 3 };
3131

3232
private static VBuffer<Double> _dataDoubleSparse = new VBuffer<Double>(5, 3, new double[] { -0.0, 0, 1 }, new[] { 0, 3, 4 });
33-
private static uint[] _resultsDoubleSparse = new uint[] { 16,16,16,16, 25 };
33+
private static uint[] _resultsDoubleSparse = new uint[] { 30, 30, 30, 30, 19 };
3434

3535
[Fact()]
3636
public void SavePipeLabelParsers()

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,8 +1201,8 @@ public void OneHotHashEncodingOnnxConversionTest()
12011201
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
12021202
// when users try to convert the items mentioned above.
12031203
public void MurmurHashScalarTest(
1204-
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Byte,
1205-
DataKind.UInt16, DataKind.UInt32, DataKind.String, DataKind.Boolean)] DataKind type,
1204+
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
1205+
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
12061206
[CombinatorialValues(1, 5, 31)] int numberOfBits, bool useOrderedHashing)
12071207
{
12081208

@@ -1215,7 +1215,11 @@ public void MurmurHashScalarTest(
12151215
(type == DataKind.UInt16) ? 6 :
12161216
(type == DataKind.Int32) ? 8 :
12171217
(type == DataKind.UInt32) ? 10 :
1218-
(type == DataKind.String) ? 12 : 14;
1218+
(type == DataKind.Int64) ? 12 :
1219+
(type == DataKind.UInt64) ? 14 :
1220+
(type == DataKind.Single) ? 16 :
1221+
(type == DataKind.Double) ? 18 :
1222+
(type == DataKind.String) ? 20 : 22;
12191223

12201224
var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
12211225
new TextLoader.Column("Value", type, column),
@@ -1252,9 +1256,9 @@ public void MurmurHashScalarTest(
12521256
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
12531257
// when users try to convert the items mentioned above.
12541258
public void MurmurHashVectorTest(
1255-
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Byte,
1256-
DataKind.UInt16, DataKind.UInt32, DataKind.String, DataKind.Boolean)] DataKind type,
1257-
[CombinatorialValues(1, 5, 31)] int numberOfBits)
1259+
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
1260+
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
1261+
[CombinatorialValues(1, 5, 31)] int numberOfBits)
12581262
{
12591263

12601264
var mlContext = new MLContext();
@@ -1266,15 +1270,23 @@ public void MurmurHashVectorTest(
12661270
(type == DataKind.UInt16) ? 6 :
12671271
(type == DataKind.Int32) ? 8 :
12681272
(type == DataKind.UInt32) ? 10 :
1269-
(type == DataKind.String) ? 12 : 14;
1273+
(type == DataKind.Int64) ? 12 :
1274+
(type == DataKind.UInt64) ? 14 :
1275+
(type == DataKind.Single) ? 16 :
1276+
(type == DataKind.Double) ? 18 :
1277+
(type == DataKind.String) ? 20 : 22;
12701278

12711279
var columnEnd = (type == DataKind.SByte) ? 1 :
12721280
(type == DataKind.Byte) ? 3 :
12731281
(type == DataKind.Int16) ? 5 :
12741282
(type == DataKind.UInt16) ? 7 :
12751283
(type == DataKind.Int32) ? 9 :
12761284
(type == DataKind.UInt32) ? 11 :
1277-
(type == DataKind.String) ? 13 : 15;
1285+
(type == DataKind.Int64) ? 13 :
1286+
(type == DataKind.UInt64) ? 15 :
1287+
(type == DataKind.Single) ? 17 :
1288+
(type == DataKind.Double) ? 19 :
1289+
(type == DataKind.String) ? 21 : 23;
12781290

12791291
var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
12801292
new TextLoader.Column("Value", type, columnStart, columnEnd),

test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -416,18 +416,16 @@ public void TestTrainTestSplitWithStratification()
416416
Assert.Contains(4, ids);
417417
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeStrat));
418418
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
419-
Assert.Contains(0, ids);
420-
Assert.Contains(7, ids);
419+
Assert.Contains(5, ids);
420+
Assert.Contains(6, ids);
421421
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeOffsetStrat));
422422
ids = split.TrainSet.GetColumn<int>(split.TrainSet.Schema[nameof(Input.Id)]);
423-
Assert.Contains(1, ids);
424-
Assert.Contains(3, ids);
425-
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TimeSpanStrat));
426-
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
427423
Assert.Contains(4, ids);
428-
Assert.Contains(5, ids);
429-
Assert.Contains(6, ids);
430424
Assert.Contains(7, ids);
425+
split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TimeSpanStrat));
426+
ids = split.TestSet.GetColumn<int>(split.TestSet.Schema[nameof(Input.Id)]);
427+
Assert.Contains(1, ids);
428+
Assert.Contains(2, ids);
431429
}
432430
}
433431
}

test/Microsoft.ML.Tests/Transformers/HashTests.cs

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ ValueGetter<TType> hashGetter<TType>(HashingEstimator.ColumnOptions colInfo)
219219
Assert.Equal(expectedCombinedSparse, result);
220220
}
221221

222-
private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)
222+
private void HashTestPositiveIntegerCore32Bits(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)
223+
223224
{
224225
uint eKey = value == 0 ? 0 : expected;
225226
uint eoKey = value == 0 ? 0 : expectedOrdered;
@@ -241,29 +242,44 @@ private void HashTestPositiveIntegerCore(ulong value, uint expected, uint expect
241242
HashTestCore((uint)value, NumberDataViewType.UInt32, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
242243
HashTestCore((uint)value, new KeyDataViewType(typeof(uint), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);
243244
}
244-
HashTestCore(value, NumberDataViewType.UInt64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
245-
HashTestCore((ulong)value, new KeyDataViewType(typeof(ulong), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);
246245

247246
HashTestCore(new DataViewRowId(value, 0), RowIdDataViewType.Instance, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
247+
HashTestCore((ulong)value, new KeyDataViewType(typeof(ulong), int.MaxValue - 1), eKey, eoKey, e3Key, ecKey, 0);
248248

249249
// Next let's check signed numbers.
250-
251250
if (value <= (ulong)sbyte.MaxValue)
252251
HashTestCore((sbyte)value, NumberDataViewType.SByte, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
253252
if (value <= (ulong)short.MaxValue)
254253
HashTestCore((short)value, NumberDataViewType.Int16, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
255254
if (value <= int.MaxValue)
256255
HashTestCore((int)value, NumberDataViewType.Int32, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
256+
}
257+
258+
private void HashTestPositiveIntegerCore64Bits(ulong value, uint expected, uint expectedOrdered, uint expectedOrdered3, uint expectedCombined, uint expectedCombinedSparse)
259+
260+
{
261+
uint eKey = value == 0 ? 0 : expected;
262+
uint eoKey = value == 0 ? 0 : expectedOrdered;
263+
uint e3Key = value == 0 ? 0 : expectedOrdered3;
264+
uint ecKey = value == 0 ? 0 : expectedCombined;
265+
266+
HashTestCore(value, NumberDataViewType.UInt64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
267+
268+
// Next let's check signed numbers.
257269
if (value <= long.MaxValue)
258-
HashTestCore((long)value, NumberDataViewType.Int64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
270+
HashTestCore((long)value, NumberDataViewType.Int64, expected, expectedOrdered, expectedOrdered3, expectedCombined, expectedCombinedSparse);
259271
}
260272

261273
[Fact]
262274
public void TestHashIntegerNumbers()
263275
{
264-
HashTestPositiveIntegerCore(0, 842, 358, 20, 882, 1010);
265-
HashTestPositiveIntegerCore(1, 502, 537, 746, 588, 286);
266-
HashTestPositiveIntegerCore(2, 407, 801, 652, 696, 172);
276+
HashTestPositiveIntegerCore32Bits(0, 842, 358, 20, 882, 1010);
277+
HashTestPositiveIntegerCore32Bits(1, 502, 537, 746, 588, 286);
278+
HashTestPositiveIntegerCore32Bits(2, 407, 801, 652, 696, 172);
279+
280+
HashTestPositiveIntegerCore64Bits(0, 512, 851, 795, 1010, 620);
281+
HashTestPositiveIntegerCore64Bits(1, 329, 190, 574, 491, 805);
282+
HashTestPositiveIntegerCore64Bits(2, 484, 713, 128, 606, 326);
267283
}
268284

269285
[Fact]
@@ -279,10 +295,10 @@ public void TestHashFloatingPointNumbers()
279295
HashTestCore(1f, NumberDataViewType.Single, 463, 855, 732, 75, 487);
280296
HashTestCore(-1f, NumberDataViewType.Single, 252, 612, 780, 179, 80);
281297
HashTestCore(0f, NumberDataViewType.Single, 842, 358, 20, 882, 1010);
282-
// Note that while we have the hash for numeric types be equal, the same is not necessarily the case for floating point numbers.
283-
HashTestCore(1d, NumberDataViewType.Double, 937, 667, 424, 727, 510);
284-
HashTestCore(-1d, NumberDataViewType.Double, 930, 78, 813, 582, 179);
285-
HashTestCore(0d, NumberDataViewType.Double, 842, 358, 20, 882, 1010);
298+
299+
HashTestCore(1d, NumberDataViewType.Double, 188, 57, 690, 727, 36);
300+
HashTestCore(-1d, NumberDataViewType.Double, 885, 804, 22, 582, 346);
301+
HashTestCore(0d, NumberDataViewType.Double, 512, 851, 795, 1010, 620);
286302
}
287303

288304
[Fact]

test/data/type-samples.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
sbyte byte short ushort int uint strings boolean
2-
0 1 0 23 0 4554 0 53 0 25 0 35 0 rain 0 1
3-
2 3 2 13 2 455 2 63 2 63 2 63 djldaoiejffjauhglehdlgh pink 1 0
4-
127 23 127 65 127 93 127 99 127 69 127 91 alibaba bug
5-
-128 24 255 25 32767 325 65535 632 2147483647 34 4294967295 45 to mato monkey
6-
0 2 5 98 -32768 335 78 698 -2147483648 97 3 56 U+12w blue
1+
sbyte byte short ushort int uint long ulong float double strings boolean
2+
0 1 0 23 0 4554 0 53 0 25 0 35 0 -1 0 1 0 -1 0 -1 0 rain 0 1
3+
2 3 2 13 2 455 2 63 2 63 2 63 2 63 2 63 1 2 1 2 djldaoiejffjauhglehdlgh pink 1 0
4+
127 23 127 65 127 93 127 99 127 69 127 91 2147483647 34 2147483647 34 -2 300 -2 300 alibaba bug
5+
-128 24 255 25 32767 325 65535 632 2147483647 34 4294967295 45 9223372036854775807 97 9223372036854775807 97 355 4 355 4 to mato monkey
6+
0 2 5 98 -32768 335 78 698 -2147483648 97 3 56 -9223372036854775808 5 4 5 -4000 5 -4000 5 U+12w blue

0 commit comments

Comments
 (0)