common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with the original MurmurHash3 and other implementations.
// It is kept for backward compatibility with components that existed before Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with the original MurmurHash3 and other implementations.
// Use this method for new components in Spark 2.3 and later.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
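
The only difference between hashUnsafeBytes and hashUnsafeBytes2 is the treatment of the trailing 1 to 3 bytes that do not fill a 4-byte word: the old method (in the collapsed context above) runs a full mix round per sign-extended tail byte, while hashUnsafeBytes2 packs the tail into a single little-endian word, masks each byte with 0xFF, and mixes it once, which is what the reference MurmurHash3 does. A minimal standalone Scala sketch of the two tail treatments, using the mix constants from the public MurmurHash3_x86_32 specification (illustrative helpers, not Spark's actual code):

// Illustrative-only port of the two tail treatments (SPARK-23381).
object TailMixSketch {
  private val C1 = 0xcc9e2d51
  private val C2 = 0x1b873593

  private def mixK1(k: Int): Int = Integer.rotateLeft(k * C1, 15) * C2

  private def mixH1(h: Int, k: Int): Int =
    Integer.rotateLeft(h ^ k, 13) * 5 + 0xe6546b64

  private def fmix(h0: Int, length: Int): Int = {
    var h = h0 ^ length
    h ^= h >>> 16; h *= 0x85ebca6b
    h ^= h >>> 13; h *= 0xc2b2ae35
    h ^ (h >>> 16)
  }

  // Pre-2.3 behavior: one full mix round per sign-extended tail byte.
  def oldTail(tail: Array[Byte], h1: Int, totalLen: Int): Int = {
    var h = h1
    for (b <- tail) h = mixH1(h, mixK1(b.toInt))
    fmix(h, totalLen)
  }

  // hashUnsafeBytes2 behavior: pack the tail little-endian with masked
  // bytes and mix once, matching the reference MurmurHash3.
  def newTail(tail: Array[Byte], h1: Int, totalLen: Int): Int = {
    var k1 = 0
    var shift = 0
    for (b <- tail) { k1 ^= (b & 0xff) << shift; shift += 8 }
    fmix(h1 ^ mixK1(k1), totalLen)
  }
}

The same change is applied to the second, identical copy of this class below.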
common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
}

public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
// This is not compatible with the original MurmurHash3 and other implementations.
// It is kept for backward compatibility with components that existed before Spark 2.3.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
@@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
return fmix(h1, lengthInBytes);
}

public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) {
// This is compatible with the original MurmurHash3 and other implementations.
// Use this method for new components in Spark 2.3 and later.
assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
int lengthAligned = lengthInBytes - lengthInBytes % 4;
int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
int k1 = 0;
for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) {
k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift;
}
h1 ^= mixK1(k1);
return fmix(h1, lengthInBytes);
}

private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
assert (lengthInBytes % 4 == 0);
int h1 = seed;
common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
@@ -22,6 +22,8 @@
import java.util.Random;
import java.util.Set;

import scala.util.hashing.MurmurHash3$;

import org.apache.spark.unsafe.Platform;
import org.junit.Assert;
import org.junit.Test;
@@ -51,6 +53,22 @@ public void testKnownLongInputs() {
Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
}

@Test // SPARK-23381 Check that the hash of a byte array matches other implementations.
public void testKnownBytesInputs() {
byte[] test = "test".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0),
Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0));
byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0),
Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0));
byte[] te = "te".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0),
Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0));
byte[] tes = "tes".getBytes(StandardCharsets.UTF_8);
Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0),
Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0));
}
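
For contrast, the pre-2.3 hashUnsafeBytes does not agree with scala.util.hashing.MurmurHash3 for any length that is not a multiple of four. A sketch of that check, assuming the spark-unsafe artifact is on the classpath (the object name is illustrative):

import java.nio.charset.StandardCharsets
import scala.util.hashing.MurmurHash3
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32

object TailMismatchDemo {
  def main(args: Array[String]): Unit = {
    val bytes = "tes".getBytes(StandardCharsets.UTF_8) // 3 bytes, unaligned tail
    val old = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0)
    val fixed = Murmur3_x86_32.hashUnsafeBytes2(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0)
    val ref = MurmurHash3.bytesHash(bytes, 0)
    // Expected: fixed == ref, while old differs because of its per-byte tail rounds.
    println(s"old=$old fixed=$fixed reference=$ref")
  }
}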

@Test
public void randomizedStressTest() {
int size = 65536;
mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -17,6 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
@@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

@@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transformer

@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- val hashFunc: Any => Int = OldHashingTF.murmur3Hash
+ val hashFunc: Any => Int = FeatureHasher.murmur3Hash
val n = $(numFeatures)
val localInputCols = $(inputCols)
val catCols = if (isSet(categoricalCols)) {
@@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {

@Since("2.3.0")
override def load(path: String): FeatureHasher = super.load(path)

private val seed = OldHashingTF.seed

/**
* Calculate a hash code value for the term object using
* Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
* This is the default hash algorithm used from Spark 2.0 onwards.
* It uses hashUnsafeBytes2 so the result matches the original MurmurHash3
* implementation. See SPARK-23381.
*/
@Since("2.3.0")
def murmur3Hash(term: Any): Int = {

[Review comment, Member] Maybe private[feature]?

[Review comment, Member] I would also address this comment.

term match {
case null => seed
case b: Boolean => hashInt(if (b) 1 else 0, seed)
case b: Byte => hashInt(b, seed)
case s: Short => hashInt(s, seed)
case i: Int => hashInt(i, seed)
case l: Long => hashLong(l, seed)
case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
case s: String =>
val utf8 = UTF8String.fromString(s)
hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
s"support type ${term.getClass.getCanonicalName} of input data.")
}
}
}
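
As a usage sketch, the transformer can be exercised end to end; the column names and rows below are illustrative, and a local SparkSession is assumed:

import org.apache.spark.ml.feature.FeatureHasher
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("FeatureHasherDemo").getOrCreate()
import spark.implicits._

val df = Seq(
  (2.0, true, "1", "foo"),
  (3.0, false, "2", "bar")
).toDF("real", "bool", "stringNum", "string")

val hasher = new FeatureHasher()
  .setInputCols("real", "bool", "stringNum", "string")
  .setOutputCol("features")

// Each value is passed through murmur3Hash and mapped into [0, numFeatures).
hasher.transform(df).select("features").show(truncate = false)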
mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -135,7 +135,7 @@ object HashingTF {

private[HashingTF] val Murmur3: String = "murmur3"

- private val seed = 42
+ private[spark] val seed = 42

/**
* Calculate a hash code value for the term object using the native Scala implementation.
mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
@@ -27,14 +27,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class FeatureHasherSuite extends SparkFunSuite
with MLlibTestSparkContext
with DefaultReadWriteTest {

import testImplicits._

- import HashingTFSuite.murmur3FeatureIdx
+ import FeatureHasherSuite.murmur3FeatureIdx

implicit private val vectorEncoder = ExpressionEncoder[Vector]()

@@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite
testDefaultReadWrite(t)
}
}

object FeatureHasherSuite {

private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
}

}
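
Since murmur3Hash can return negative values, the helper maps hashes into [0, numFeatures) with Utils.nonNegativeMod rather than a bare %. A short sketch of that behavior (a reimplementation for illustration, assumed equivalent to org.apache.spark.util.Utils.nonNegativeMod):

def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

// A negative hash still lands in [0, numFeatures):
assert(nonNegativeMod(-7, 4) == 1) // -7 % 4 == -3 in Scala; -3 + 4 == 1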