
Commit d5ed210

mrkm4ntr authored and gatorsmile committed
[SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations
## What changes were proposed in this pull request?

Murmur3 hash generates a different value from the original and from other implementations (such as the Scala standard library and Guava) when the length of a byte array is not a multiple of 4.

## How was this patch tested?

Added a unit test.

**Note: When we merge this PR, please give all the credits to Shintaro Murakami.**

Author: Shintaro Murakami <mrkm4ntr@gmail.com>
Author: gatorsmile <[email protected]>
Author: Shintaro Murakami <[email protected]>

Closes #20630 from gatorsmile/pr-20568.
1 parent 0a73aa3 commit d5ed210
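The root cause, in brief: for the trailing 1 to 3 bytes of an unaligned input, the legacy `hashUnsafeBytes` (kept in this commit for compatibility) mixes each trailing byte individually, sign-extended, as if it were a full 4-byte block, whereas the reference MurmurHash3_x86_32 packs the trailing bytes little-endian into a single word and mixes it once. Below is a minimal, self-contained sketch of the two tail treatments; this is not Spark's code, but the constants and mixing steps are the standard MurmurHash3 ones:

```java
public class TailDemo {
  private static final int C1 = 0xcc9e2d51;
  private static final int C2 = 0x1b873593;

  private static int mixK1(int k1) {
    k1 *= C1;
    k1 = Integer.rotateLeft(k1, 15);
    return k1 * C2;
  }

  private static int mixH1(int h1, int k1) {
    h1 ^= k1;
    h1 = Integer.rotateLeft(h1, 13);
    return h1 * 5 + 0xe6546b64;
  }

  private static int fmix(int h1, int length) {
    h1 ^= length;
    h1 ^= h1 >>> 16;
    h1 *= 0x85ebca6b;
    h1 ^= h1 >>> 13;
    h1 *= 0xc2b2ae35;
    return h1 ^ (h1 >>> 16);
  }

  // Legacy Spark behavior: each trailing byte is mixed as its own block.
  static int oldTail(byte[] tail, int h1) {
    for (byte b : tail) {
      h1 = mixH1(h1, mixK1(b));  // b is sign-extended to int here
    }
    return fmix(h1, tail.length);
  }

  // Reference behavior: trailing bytes are packed little-endian into one k1.
  static int newTail(byte[] tail, int h1) {
    int k1 = 0;
    for (int i = 0; i < tail.length; i++) {
      k1 ^= (tail[i] & 0xFF) << (8 * i);
    }
    h1 ^= mixK1(k1);
    return fmix(h1, tail.length);
  }

  public static void main(String[] args) {
    byte[] tes = {'t', 'e', 's'};  // 3 bytes, so the whole input is "tail"
    System.out.println(oldTail(tes, 42));  // differs from the line below
    System.out.println(newTail(tes, 42));
  }
}
```

Scala's `MurmurHash3.bytesHash` and Guava's `murmur3_32` both follow the second shape, which is why the new `hashUnsafeBytes2` in this commit packs `k1` the same way.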

File tree

7 files changed: +96 −5 lines

common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java

Lines changed: 16 additions & 0 deletions

```diff
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i
   }
 
   public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+    // This is not compatible with the original and other implementations,
+    // but it is kept for backward compatibility with components that existed before Spark 2.3.
     assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
     int lengthAligned = lengthInBytes - lengthInBytes % 4;
     int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i
     return fmix(h1, lengthInBytes);
   }
 
+  public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
+    // This is compatible with the original and other implementations.
+    // Use this method for new components after Spark 2.3.
+    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+    int lengthAligned = lengthInBytes - lengthInBytes % 4;
+    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+    int k1 = 0;
+    for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
+      k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
+    }
+    h1 ^= mixK1(k1);
+    return fmix(h1, lengthInBytes);
+  }
+
   private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
     assert (lengthInBytes % 4 == 0);
     int h1 = seed;
```
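One subtlety in the new tail loop: Java's `byte` is signed, so the `& 0xFF` mask is what prevents bytes at or above 0x80 from sign-extending before they are shifted into `k1`. (The legacy per-byte path mixed the sign-extended value directly, which is one reason the results disagree.) A tiny standalone illustration:

```java
public class MaskDemo {
  public static void main(String[] args) {
    byte b = (byte) 0x80;
    // Without the mask, widening sign-extends the byte:
    System.out.printf("0x%08x%n", (int) b);   // 0xffffff80
    // With the mask, only the low 8 bits survive:
    System.out.printf("0x%08x%n", b & 0xFF);  // 0x00000080
  }
}
```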

common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java

Lines changed: 16 additions & 0 deletions (the same change, applied to the second copy of this class)

```diff
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i
   }
 
   public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+    // This is not compatible with the original and other implementations,
+    // but it is kept for backward compatibility with components that existed before Spark 2.3.
     assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
     int lengthAligned = lengthInBytes - lengthInBytes % 4;
     int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i
     return fmix(h1, lengthInBytes);
   }
 
+  public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
+    // This is compatible with the original and other implementations.
+    // Use this method for new components after Spark 2.3.
+    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+    int lengthAligned = lengthInBytes - lengthInBytes % 4;
+    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+    int k1 = 0;
+    for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
+      k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
+    }
+    h1 ^= mixK1(k1);
+    return fmix(h1, lengthInBytes);
+  }
+
   private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
     assert (lengthInBytes % 4 == 0);
     int h1 = seed;
```

common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java

Lines changed: 19 additions & 0 deletions

```diff
@@ -22,6 +22,8 @@
 import java.util.Random;
 import java.util.Set;
 
+import scala.util.hashing.MurmurHash3$;
+
 import org.apache.spark.unsafe.Platform;
 import org.junit.Assert;
 import org.junit.Test;
@@ -51,6 +53,23 @@ public void testKnownLongInputs() {
     Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
   }
 
+  // SPARK-23381: check that the hash of a byte array matches other implementations.
+  @Test
+  public void testKnownBytesInputs() {
+    byte[] test = "test".getBytes(StandardCharsets.UTF_8);
+    Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0),
+      Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0));
+    byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8);
+    Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0),
+      Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0));
+    byte[] te = "te".getBytes(StandardCharsets.UTF_8);
+    Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0),
+      Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0));
+    byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);
+    Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0),
+      Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0));
+  }
+
   @Test
   public void randomizedStressTest() {
     int size = 65536;
```
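The commit message also names Guava as a matching implementation. A similar cross-check against Guava could look like the sketch below; it is not part of this commit, and it assumes Guava's `Hashing` is available on the classpath:

```java
import java.nio.charset.StandardCharsets;

import com.google.common.hash.Hashing;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

public class GuavaCrossCheck {
  public static void main(String[] args) {
    byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);
    // Guava's Murmur3 x86_32 over the same bytes, same seed:
    int guava = Hashing.murmur3_32(0).hashBytes(tes).asInt();
    int spark = Murmur3_x86_32.hashUnsafeBytes2(
        tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0);
    System.out.println(guava == spark);  // expected: true
  }
}
```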

mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala

Lines changed: 32 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
@@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
 
   @Since("2.3.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    val hashFunc: Any => Int = OldHashingTF.murmur3Hash
+    val hashFunc: Any => Int = FeatureHasher.murmur3Hash
     val n = $(numFeatures)
     val localInputCols = $(inputCols)
     val catCols = if (isSet(categoricalCols)) {
@@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {
 
   @Since("2.3.0")
   override def load(path: String): FeatureHasher = super.load(path)
+
+  private val seed = OldHashingTF.seed
+
+  /**
+   * Calculate a hash code value for the term object using
+   * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
+   * This is the default hash algorithm used from Spark 2.0 onwards.
+   * Strings are hashed with hashUnsafeBytes2 so that the value matches the
+   * original algorithm. See SPARK-23381.
+   */
+  @Since("2.3.0")
+  private[feature] def murmur3Hash(term: Any): Int = {
+    term match {
+      case null => seed
+      case b: Boolean => hashInt(if (b) 1 else 0, seed)
+      case b: Byte => hashInt(b, seed)
+      case s: Short => hashInt(s, seed)
+      case i: Int => hashInt(i, seed)
+      case l: Long => hashLong(l, seed)
+      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
+      case s: String =>
+        val utf8 = UTF8String.fromString(s)
+        hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
+      case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
+        s"support type ${term.getClass.getCanonicalName} of input data.")
+    }
+  }
 }
```
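For strings, the new `murmur3Hash` goes through `UTF8String` rather than a plain `byte[]`. A quick standalone check, not part of this commit, that the two views hash identically under `hashUnsafeBytes2`:

```java
import java.nio.charset.StandardCharsets;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.UTF8String;

public class Utf8HashCheck {
  public static void main(String[] args) {
    int seed = 42;  // the HashingTF seed this commit widens to private[spark]

    // Hash via UTF8String, as FeatureHasher.murmur3Hash now does:
    UTF8String utf8 = UTF8String.fromString("foo");
    int viaUtf8String = Murmur3_x86_32.hashUnsafeBytes2(
        utf8.getBaseObject(), utf8.getBaseOffset(), utf8.numBytes(), seed);

    // Hash the raw UTF-8 bytes directly:
    byte[] bytes = "foo".getBytes(StandardCharsets.UTF_8);
    int viaBytes = Murmur3_x86_32.hashUnsafeBytes2(
        bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed);

    System.out.println(viaUtf8String == viaBytes);  // expected: true
  }
}
```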

mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -135,7 +135,7 @@ object HashingTF {
 
   private[HashingTF] val Murmur3: String = "murmur3"
 
-  private val seed = 42
+  private[spark] val seed = 42
 
   /**
   * Calculate a hash code value for the term object using the native Scala implementation.
```

mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala

Lines changed: 10 additions & 1 deletion

```diff
@@ -27,14 +27,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 class FeatureHasherSuite extends SparkFunSuite
   with MLlibTestSparkContext
   with DefaultReadWriteTest {
 
   import testImplicits._
 
-  import HashingTFSuite.murmur3FeatureIdx
+  import FeatureHasherSuite.murmur3FeatureIdx
 
   implicit private val vectorEncoder = ExpressionEncoder[Vector]()
 
@@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite
     testDefaultReadWrite(t)
   }
 }
+
+object FeatureHasherSuite {
+
+  private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
+    Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
+  }
+
+}
```
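`murmur3FeatureIdx` maps a raw, possibly negative hash into `[0, numFeatures)` via Spark's `Utils.nonNegativeMod` helper in `org.apache.spark.util.Utils`. A standalone Java sketch of that mapping:

```java
public class NonNegativeModDemo {
  static int nonNegativeMod(int x, int mod) {
    int rawMod = x % mod;                   // Java's % can be negative
    return rawMod + (rawMod < 0 ? mod : 0); // fold into [0, mod)
  }

  public static void main(String[] args) {
    System.out.println(nonNegativeMod(-7, 5));  // 3, whereas -7 % 5 == -2
  }
}
```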

python/pyspark/ml/feature.py

Lines changed: 2 additions & 2 deletions

The expected doctest output changes because string-derived feature indices move under the corrected hash; only indices derived from strings whose UTF-8 length is a multiple of 4 (such as the 4-byte column name "real") stay put, exactly as the fix predicts.

```diff
@@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
 >>> df = spark.createDataFrame(data, cols)
 >>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
 >>> hasher.transform(df).head().features
-SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0})
+SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
 >>> hasher.setCategoricalCols(["real"]).transform(df).head().features
-SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0})
+SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
 >>> hasherPath = temp_path + "/hasher"
 >>> hasher.save(hasherPath)
 >>> loadedHasher = FeatureHasher.load(hasherPath)
```
