From 9436426248917395d6565abf6ce7a43ae2c75169 Mon Sep 17 00:00:00 2001 From: batubond007 Date: Wed, 1 May 2024 13:43:57 +0300 Subject: [PATCH] [SPARK-47547][SQL] Fix inaccurate false positive rates when N is large for Bloom Filter ### What changes were proposed in this pull request? This PR introduces a new 128 bit hashing function based on Guava's solution. Bloom filter is now created based on hashing strategy, which is decided by the bit size of the Bloom Filter. ### Why are the changes needed? When the decided bit size by Bloom Filter is greater than Integer.MAX_VALUE, 32 bit hashing function wastes the bits which index is greater than Integer.MAX_VALUE. ### Does this PR introduce _any_ user-facing change? No, hashing function is decided by BloomFilter. ### How was this patch tested? Added UT for new hashing function. Manually tested the BloomFilter with N is near 500 million. ### Was this patch authored or co-authored using generative AI tooling? No. --- .../apache/spark/util/sketch/BloomFilter.java | 23 +- .../spark/util/sketch/BloomFilterImpl.java | 163 ++++-------- .../util/sketch/BloomFilterStrategies.java | 195 ++++++++++++++ .../util/sketch/BloomFilterStrategy.java | 34 +++ .../spark/util/sketch/Murmur3_x86_128.java | 248 ++++++++++++++++++ .../spark/unsafe/hash/Murmur3_x86_128.java | 245 +++++++++++++++++ .../unsafe/hash/Murmur3_x86_128Suite.java | 60 +++++ 7 files changed, 857 insertions(+), 111 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategies.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategy.java create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_128.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_128.java create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_128Suite.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 172b394689ca9..204f210147767 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -41,6 +41,8 @@ */ public abstract class BloomFilter { + protected BloomFilterStrategy strategy = BloomFilterStrategies.HASH_32; + public enum Version { /** * {@code BloomFilter} binary format version 1. All values written in big-endian order: @@ -223,6 +225,24 @@ public static long optimalNumOfBits(long expectedNumItems, long maxNumItems, lon return Math.min(optimalNumOfBits(expectedNumItems, fpp), maxNumOfBits); } + /** + * Returns the strategy that is used by the current implementation. + */ + public BloomFilterStrategy currentStrategy(){ + // Since the optimal strategy selection depends on the bit size, we can obtain it + // from the current bit size of the filter. + return optimalStrategy(this.bitSize()); + } + + /** + * Determines the most suitable strategy to use based on the bit size. + */ + public static BloomFilterStrategy optimalStrategy(long numBits){ + return numBits >= Integer.MAX_VALUE + ? BloomFilterStrategies.HASH_128 + : BloomFilterStrategies.HASH_32; + } + /** * Creates a {@link BloomFilter} with the expected number of insertions and a default expected * false positive probability of 3%. @@ -264,6 +284,7 @@ public static BloomFilter create(long expectedNumItems, long numBits) { throw new IllegalArgumentException("Number of bits must be positive"); } - return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), + numBits, optimalStrategy(numBits)); } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index 3bd04a531fe75..6f45b547ff107 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -25,13 +25,14 @@ class BloomFilterImpl extends BloomFilter implements Serializable { private BitArray bits; - BloomFilterImpl(int numHashFunctions, long numBits) { - this(new BitArray(numBits), numHashFunctions); + BloomFilterImpl(int numHashFunctions, long numBits, BloomFilterStrategy strategy) { + this(new BitArray(numBits), numHashFunctions, strategy); } - private BloomFilterImpl(BitArray bits, int numHashFunctions) { + private BloomFilterImpl(BitArray bits, int numHashFunctions, BloomFilterStrategy strategy) { this.bits = bits; this.numHashFunctions = numHashFunctions; + this.strategy = strategy; } private BloomFilterImpl() {} @@ -77,102 +78,17 @@ public boolean put(Object item) { @Override public boolean putString(String item) { - return putBinary(Utils.getBytesFromUTF8String(item)); - } - - @Override - public boolean putBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); - - long bitSize = bits.bitSize(); - boolean bitsChanged = false; - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - bitsChanged |= bits.set(combinedHash % bitSize); - } - return bitsChanged; - } - - @Override - public boolean mightContainString(String item) { - return mightContainBinary(Utils.getBytesFromUTF8String(item)); - } - - @Override - public boolean mightContainBinary(byte[] item) { - int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); - int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); - - long bitSize = bits.bitSize(); - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - if (!bits.get(combinedHash % bitSize)) { - return false; - } - } - return true; + return strategy.putString(item, bits, numHashFunctions); } @Override public boolean putLong(long item) { - // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n - // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. - // Note that `CountMinSketch` use a different strategy, it hash the input long element with - // every i to produce n hash values. - // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? - int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); - - long bitSize = bits.bitSize(); - boolean bitsChanged = false; - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - bitsChanged |= bits.set(combinedHash % bitSize); - } - return bitsChanged; + return strategy.putLong(item, bits, numHashFunctions); } @Override - public boolean mightContainLong(long item) { - int h1 = Murmur3_x86_32.hashLong(item, 0); - int h2 = Murmur3_x86_32.hashLong(item, h1); - - long bitSize = bits.bitSize(); - for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = h1 + (i * h2); - // Flip all the bits if it's negative (guaranteed positive number) - if (combinedHash < 0) { - combinedHash = ~combinedHash; - } - if (!bits.get(combinedHash % bitSize)) { - return false; - } - } - return true; - } - - @Override - public boolean mightContain(Object item) { - if (item instanceof String str) { - return mightContainString(str); - } else if (item instanceof byte[] bytes) { - return mightContainBinary(bytes); - } else { - return mightContainLong(Utils.integralToLong(item)); - } + public boolean putBinary(byte[] item) { + return strategy.putBinary(item, bits, numHashFunctions); } @Override @@ -204,13 +120,8 @@ public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeE return this; } - @Override - public long cardinality() { - return this.bits.cardinality(); - } - private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) - throws IncompatibleMergeException { + throws IncompatibleMergeException { // Duplicates the logic of `isCompatible` here to provide better error message. if (other == null) { throw new IncompatibleMergeException("Cannot merge null bloom filter"); @@ -234,6 +145,37 @@ private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) return that; } + @Override + public long cardinality() { + return this.bits.cardinality(); + } + + @Override + public boolean mightContain(Object item) { + if (item instanceof String str) { + return mightContainString(str); + } else if (item instanceof byte[] bytes) { + return mightContainBinary(bytes); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean mightContainString(String item) { + return strategy.mightContainString(item, bits, numHashFunctions); + } + + @Override + public boolean mightContainLong(long item) { + return strategy.mightContainLong(item, bits, numHashFunctions); + } + + @Override + public boolean mightContainBinary(byte[] item) { + return strategy.mightContainBinary(item, bits, numHashFunctions); + } + @Override public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); @@ -243,18 +185,6 @@ public void writeTo(OutputStream out) throws IOException { bits.writeTo(dos); } - private void readFrom0(InputStream in) throws IOException { - DataInputStream dis = new DataInputStream(in); - - int version = dis.readInt(); - if (version != Version.V1.getVersionNumber()) { - throw new IOException("Unexpected Bloom filter version number (" + version + ")"); - } - - this.numHashFunctions = dis.readInt(); - this.bits = BitArray.readFrom(dis); - } - public static BloomFilterImpl readFrom(InputStream in) throws IOException { BloomFilterImpl filter = new BloomFilterImpl(); filter.readFrom0(in); @@ -267,6 +197,19 @@ public static BloomFilterImpl readFrom(byte[] bytes) throws IOException { } } + private void readFrom0(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + this.numHashFunctions = dis.readInt(); + this.bits = BitArray.readFrom(dis); + this.strategy = currentStrategy(); + } + private void writeObject(ObjectOutputStream out) throws IOException { writeTo(out); } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategies.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategies.java new file mode 100644 index 0000000000000..30c9a4276682b --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategies.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +public enum BloomFilterStrategies implements BloomFilterStrategy { + HASH_32() { + @Override + public boolean putString(String item, BitArray bits, int numHashFunctions) { + return putBinary(Utils.getBytesFromUTF8String(item), bits, numHashFunctions); + } + + @Override + public boolean putLong(long item, BitArray bits, int numHashFunctions) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean putBinary(byte[] item, BitArray bits, int numHashFunctions) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContainString(String item, BitArray bits, int numHashFunctions) { + return mightContainBinary(Utils.getBytesFromUTF8String(item), bits, numHashFunctions); + } + + @Override + public boolean mightContainLong(long item, BitArray bits, int numHashFunctions) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean mightContainBinary(byte[] item, BitArray bits, int numHashFunctions) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + }, + HASH_128() { + @Override + public boolean putString(String item, BitArray bits, int numHashFunctions) { + return putBinary(Utils.getBytesFromUTF8String(item), bits, numHashFunctions); + } + + @Override + public boolean putLong(long item, BitArray bits, int numHashFunctions) { + long bitSize = bits.bitSize(); + Murmur3_x86_128.HashObject hashObject = Murmur3_x86_128.hashLong(item, 0); + long hash1 = hashObject.getHash1(); + long hash2 = hashObject.getHash2(); + + boolean bitsChanged = false; + long combinedHash = hash1; + for (int i = 0; i < numHashFunctions; i++) { + // Make the combined hash positive and indexable + bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize); + combinedHash += hash2; + } + return bitsChanged; + } + + @Override + public boolean putBinary(byte[] item, BitArray bits, int numHashFunctions) { + long bitSize = bits.bitSize(); + Murmur3_x86_128.HashObject hashObject = Murmur3_x86_128.hashUnsafeBytes( + item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + long hash1 = hashObject.getHash1(); + long hash2 = hashObject.getHash2(); + + boolean bitsChanged = false; + long combinedHash = hash1; + for (int i = 0; i < numHashFunctions; i++) { + // Make the combined hash positive and indexable + bitsChanged |= bits.set((combinedHash & Long.MAX_VALUE) % bitSize); + combinedHash += hash2; + } + return bitsChanged; + } + + @Override + public boolean mightContainString(String item, BitArray bits, int numHashFunctions) { + return mightContainBinary(Utils.getBytesFromUTF8String(item), bits, numHashFunctions); + } + + @Override + public boolean mightContainLong(long item, BitArray bits, int numHashFunctions) { + long bitSize = bits.bitSize(); + Murmur3_x86_128.HashObject hashObject = Murmur3_x86_128.hashLong(item, 0); + long hash1 = hashObject.getHash1(); + long hash2 = hashObject.getHash2(); + + long combinedHash = hash1; + for (int i = 0; i < numHashFunctions; i++) { + // Make the combined hash positive and indexable + if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) { + return false; + } + combinedHash += hash2; + } + return true; + } + + @Override + public boolean mightContainBinary(byte[] item, BitArray bits, int numHashFunctions) { + long bitSize = bits.bitSize(); + Murmur3_x86_128.HashObject hashObject = Murmur3_x86_128.hashUnsafeBytes( + item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + long hash1 = hashObject.getHash1(); + long hash2 = hashObject.getHash2(); + + long combinedHash = hash1; + for (int i = 0; i < numHashFunctions; i++) { + // Make the combined hash positive and indexable + if (!bits.get((combinedHash & Long.MAX_VALUE) % bitSize)) { + return false; + } + combinedHash += hash2; + } + return true; + } + }; +} + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategy.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategy.java new file mode 100644 index 0000000000000..0cd225a322372 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterStrategy.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +public interface BloomFilterStrategy { + int ordinal(); + + boolean putString(String item, BitArray bits, int numHashFunctions); + + boolean putLong(long item, BitArray bits, int numHashFunctions); + + boolean putBinary(byte[] item, BitArray bits, int numHashFunctions); + + boolean mightContainString(String item, BitArray bits, int numHashFunctions); + + boolean mightContainLong(long item, BitArray bits, int numHashFunctions); + + boolean mightContainBinary(byte[] item, BitArray bits, int numHashFunctions); +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_128.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_128.java new file mode 100644 index 0000000000000..b78571ca264de --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_128.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.nio.ByteOrder; + +/** + * 128-bit Murmur3 hasher. This is based on Guava's Murmur3_128HashFunction . + */ +// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_128` to make sure +// spark-sketch has no external dependencies. +public class Murmur3_x86_128 { + private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private static final long C1 = 0x87c37b91114253d5L; + private static final long C2 = 0x4cf5ad432745937fL; + + private final int seed; + + Murmur3_x86_128(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_128(seed=" + seed + ")"; + } + + public static class HashObject { + private final long h1; + private final long h2; + + public HashObject(long h1, long h2) { + this.h1 = h1; + this.h2 = h2; + } + + public long getHash1() { + return h1; + } + + public long getHash2() { + return h2; + } + } + + public HashObject hashInt(int input) { + return hashInt(input, seed); + } + + public static HashObject hashInt(int input, int seed) { + return hashLong(((long) input) << 32, seed); + } + + public HashObject hashLong(long input) { + return hashLong(input, seed); + } + + public static HashObject hashLong(long input, int seed) { + long k1 = mixK1(input); + // Since k2 is 0, h2 is also 0 + long h1 = bmix64H1(seed, k1, 0); + return hashResult(h1, 0, 8); + } + + public long hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static long hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long h1 = seed; + long h2 = seed; + long[] res = hashRemainingBytes(base, offset, lengthInBytes, h1, h2); + return fmix64(res[0]); + } + + public static HashObject hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. ` ` + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; + int remainingBytes = lengthInBytes % 16; + int lengthAligned = lengthInBytes - remainingBytes; + + long[] res = hashBytesBy2Long(base, offset, lengthAligned, seed); + long h1 = res[0]; + long h2 = res[1]; + offset = res[2]; + res = hashRemainingBytes(base, offset, remainingBytes, h1, h2); + h1 = res[0]; + h2 = res[1]; + + return hashResult(h1, h2, lengthInBytes); + } + + private static long[] hashBytesBy2Long(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 16 == 0); + long h1 = seed; + long h2 = seed; + for (int i = 0; i < lengthInBytes; i += 16, offset += 16) { + long word1 = Platform.getLong(base, offset); + long word2 = Platform.getLong(base, offset + 8); + if (isBigEndian) { + word1 = Long.reverseBytes(word1); + word2 = Long.reverseBytes(word2); + } + long[] res = bmix64(h1, word1, h2, word2); + h1 = res[0]; + h2 = res[1]; + } + return new long[]{h1, h2, offset}; + } + + // This is based on Guava's `Murmur3_128Hasher.processRemaining(ByteBuffer)` method. + private static long[] hashRemainingBytes( + Object base, + long offset, + int remainingBytes, + long h1, + long h2){ + long k1 = 0; + long k2 = 0; + switch (remainingBytes) { + case 15: + k2 ^= (long) toInt(Platform.getByte(base, offset + 14)) << 48; + remainingBytes--; // fallthru + case 14: + k2 ^= (long) toInt(Platform.getByte(base, offset + 13)) << 40; + remainingBytes--; // fallthru + case 13: + k2 ^= (long) toInt(Platform.getByte(base, offset + 12)) << 32; + remainingBytes--; // fallthru + case 12: + k2 ^= (long) toInt(Platform.getByte(base, offset + 11)) << 24; + remainingBytes--; // fallthru + case 11: + k2 ^= (long) toInt(Platform.getByte(base, offset + 10)) << 16; + remainingBytes--; // fallthru + case 10: + k2 ^= (long) toInt(Platform.getByte(base, offset + 9)) << 8; + remainingBytes--; // fallthru + case 9: + k2 ^= (long) toInt(Platform.getByte(base, offset + 8)); + remainingBytes--; // fallthru + case 8: + k1 ^= Platform.getLong(base, offset); + break; + case 7: + k1 ^= (long) toInt(Platform.getByte(base, offset + 6)) << 48; + remainingBytes--; // fallthru + case 6: + k1 ^= (long) toInt(Platform.getByte(base, offset + 5)) << 40; + remainingBytes--; // fallthru + case 5: + k1 ^= (long) toInt(Platform.getByte(base, offset + 4)) << 32; + remainingBytes--; // fallthru + case 4: + k1 ^= (long) toInt(Platform.getByte(base, offset + 3)) << 24; + remainingBytes--; // fallthru + case 3: + k1 ^= (long) toInt(Platform.getByte(base, offset + 2)) << 16; + remainingBytes--; // fallthru + case 2: + k1 ^= (long) toInt(Platform.getByte(base, offset + 1)) << 8; + remainingBytes--; // fallthru + case 1: + k1 ^= (long) toInt(Platform.getByte(base, offset)); + break; + } + h1 ^= mixK1(k1); + h2 ^= mixK2(k2); + return new long[]{h1, h2}; + } + + private static long[] bmix64(long h1, long k1, long h2, long k2) { + h1 = bmix64H1(h1, k1, h2); + h2 = bmix64H2(h2, k2, h1); + return new long[]{h1, h2}; + } + + private static HashObject hashResult(long h1, long h2, int lengthInBytes) { + h1 ^= lengthInBytes; + h2 ^= lengthInBytes; + h1 += h2; + h2 += h1; + h1 = fmix64(h1); + h2 = fmix64(h2); + h1 += h2; + h2 += h1; + return new HashObject(h1, h2); + } + + private static long bmix64H1(long h1, long k1, long h2) { + h1 ^= mixK1(k1); + h1 = Long.rotateLeft(h1, 27); + h1 += h2; + return h1 * 5 + 0x52dce729; + } + + private static long bmix64H2(long h2, long k2, long h1) { + h2 ^= mixK2(k2); + h2 = Long.rotateLeft(h2, 31); + h2 += h1; + return h2 * 5 + 0x38495ab5; + } + + private static long fmix64(long k) { + k ^= k >>> 33; + k *= 0xff51afd7ed558ccdL; + k ^= k >>> 33; + k *= 0xc4ceb9fe1a85ec53L; + k ^= k >>> 33; + return k; + } + + private static long mixK1(long k1) { + k1 *= C1; + k1 = Long.rotateLeft(k1, 31); + k1 *= C2; + return k1; + } + + private static long mixK2(long k2) { + k2 *= C2; + k2 = Long.rotateLeft(k2, 33); + k2 *= C1; + return k2; + } + + // This method is copied from Guava's UnsignedBytes class to reduce dependency. + private static int toInt(byte value) { + return value & 255; + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_128.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_128.java new file mode 100644 index 0000000000000..a0067629848fc --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_128.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.nio.ByteOrder; + +import org.apache.spark.unsafe.Platform; + +import static com.google.common.primitives.UnsignedBytes.toInt; + +/** + * 128-bit Murmur3 hasher. This is based on Guava's Murmur3_128HashFunction . + */ +public class Murmur3_x86_128 { + private static final boolean isBigEndian = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private static final long C1 = 0x87c37b91114253d5L; + private static final long C2 = 0x4cf5ad432745937fL; + + private final int seed; + + Murmur3_x86_128(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_128(seed=" + seed + ")"; + } + + public static class HashObject { + private final long h1; + private final long h2; + + public HashObject(long h1, long h2) { + this.h1 = h1; + this.h2 = h2; + } + + public long getHash1() { + return h1; + } + + public long getHash2() { + return h2; + } + } + + public HashObject hashInt(int input) { + return hashInt(input, seed); + } + + public static HashObject hashInt(int input, int seed) { + return hashLong(((long) input) << 32, seed); + } + + public HashObject hashLong(long input) { + return hashLong(input, seed); + } + + public static HashObject hashLong(long input, int seed) { + long k1 = mixK1(input); + // Since k2 is 0, h2 is also 0 + long h1 = bmix64H1(seed, k1, 0); + return hashResult(h1, 0, 8); + } + + public long hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static long hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long h1 = seed; + long h2 = seed; + long[] res = hashRemainingBytes(base, offset, lengthInBytes, h1, h2); + return fmix64(res[0]); + } + + public static HashObject hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. ` ` + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; + int remainingBytes = lengthInBytes % 16; + int lengthAligned = lengthInBytes - remainingBytes; + + long[] res = hashBytesBy2Long(base, offset, lengthAligned, seed); + long h1 = res[0]; + long h2 = res[1]; + offset = res[2]; + res = hashRemainingBytes(base, offset, remainingBytes, h1, h2); + h1 = res[0]; + h2 = res[1]; + + return hashResult(h1, h2, lengthInBytes); + } + + private static long[] hashBytesBy2Long(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 16 == 0); + long h1 = seed; + long h2 = seed; + for (int i = 0; i < lengthInBytes; i += 16, offset += 16) { + long word1 = Platform.getLong(base, offset); + long word2 = Platform.getLong(base, offset + 8); + if (isBigEndian) { + word1 = Long.reverseBytes(word1); + word2 = Long.reverseBytes(word2); + } + long[] res = bmix64(h1, word1, h2, word2); + h1 = res[0]; + h2 = res[1]; + } + return new long[]{h1, h2, offset}; + } + + // This is based on Guava's `Murmur3_128Hasher.processRemaining(ByteBuffer)` method. + private static long[] hashRemainingBytes( + Object base, + long offset, + int remainingBytes, + long h1, + long h2) { + long k1 = 0; + long k2 = 0; + switch (remainingBytes) { + case 15: + k2 ^= (long) toInt(Platform.getByte(base, offset + 14)) << 48; + remainingBytes--; // fallthru + case 14: + k2 ^= (long) toInt(Platform.getByte(base, offset + 13)) << 40; + remainingBytes--; // fallthru + case 13: + k2 ^= (long) toInt(Platform.getByte(base, offset + 12)) << 32; + remainingBytes--; // fallthru + case 12: + k2 ^= (long) toInt(Platform.getByte(base, offset + 11)) << 24; + remainingBytes--; // fallthru + case 11: + k2 ^= (long) toInt(Platform.getByte(base, offset + 10)) << 16; + remainingBytes--; // fallthru + case 10: + k2 ^= (long) toInt(Platform.getByte(base, offset + 9)) << 8; + remainingBytes--; // fallthru + case 9: + k2 ^= (long) toInt(Platform.getByte(base, offset + 8)); + remainingBytes--; // fallthru + case 8: + k1 ^= Platform.getLong(base, offset); + break; + case 7: + k1 ^= (long) toInt(Platform.getByte(base, offset + 6)) << 48; + remainingBytes--; // fallthru + case 6: + k1 ^= (long) toInt(Platform.getByte(base, offset + 5)) << 40; + remainingBytes--; // fallthru + case 5: + k1 ^= (long) toInt(Platform.getByte(base, offset + 4)) << 32; + remainingBytes--; // fallthru + case 4: + k1 ^= (long) toInt(Platform.getByte(base, offset + 3)) << 24; + remainingBytes--; // fallthru + case 3: + k1 ^= (long) toInt(Platform.getByte(base, offset + 2)) << 16; + remainingBytes--; // fallthru + case 2: + k1 ^= (long) toInt(Platform.getByte(base, offset + 1)) << 8; + remainingBytes--; // fallthru + case 1: + k1 ^= (long) toInt(Platform.getByte(base, offset)); + break; + } + h1 ^= mixK1(k1); + h2 ^= mixK2(k2); + return new long[]{h1, h2}; + } + + private static long[] bmix64(long h1, long k1, long h2, long k2) { + h1 = bmix64H1(h1, k1, h2); + h2 = bmix64H2(h2, k2, h1); + return new long[]{h1, h2}; + } + + private static HashObject hashResult(long h1, long h2, int lengthInBytes) { + h1 ^= lengthInBytes; + h2 ^= lengthInBytes; + h1 += h2; + h2 += h1; + h1 = fmix64(h1); + h2 = fmix64(h2); + h1 += h2; + h2 += h1; + return new HashObject(h1, h2); + } + + private static long bmix64H1(long h1, long k1, long h2) { + h1 ^= mixK1(k1); + h1 = Long.rotateLeft(h1, 27); + h1 += h2; + return h1 * 5 + 0x52dce729; + } + + private static long bmix64H2(long h2, long k2, long h1) { + h2 ^= mixK2(k2); + h2 = Long.rotateLeft(h2, 31); + h2 += h1; + return h2 * 5 + 0x38495ab5; + } + + private static long fmix64(long k) { + k ^= k >>> 33; + k *= 0xff51afd7ed558ccdL; + k ^= k >>> 33; + k *= 0xc4ceb9fe1a85ec53L; + k ^= k >>> 33; + return k; + } + + private static long mixK1(long k1) { + k1 *= C1; + k1 = Long.rotateLeft(k1, 31); + k1 *= C2; + return k1; + } + + private static long mixK2(long k2) { + k2 *= C2; + k2 = Long.rotateLeft(k2, 33); + k2 *= C1; + return k2; + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_128Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_128Suite.java new file mode 100644 index 0000000000000..593fab4a49449 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_128Suite.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.Platform; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Test file largely based on Guava's Murmur3Hash128Test. + */ +public class Murmur3_x86_128Suite { + + private static final Murmur3_x86_128 hasher = new Murmur3_x86_128(0); + + @Test + public void testKnownValues() { + assertHash(0, "629942693e10f867", "92db0b82baeb5347", "hell"); + assertHash(1, "a78ddff5adae8d10", "128900ef20900135", "hello"); + assertHash(2, "8a486b23f422e826", "f962a2c58947765f", "hello "); + assertHash(3, "2ea59f466f6bed8c", "c610990acc428a17", "hello w"); + assertHash(4, "79f6305a386c572c", "46305aed3483b94e", "hello wo"); + assertHash(5, "c2219d213ec1f1b5", "a1d8e2e0a52785bd", "hello wor"); + assertHash(0, "e34bbc7bbc071b6c", "7a433ca9c49a9347", + "The quick brown fox jumps over the lazy dog"); + assertHash(0, "658ca970ff85269a", "43fee3eaa68e5c3e", + "The quick brown fox jumps over the lazy cog"); + } + + private void assertHash(int seed, String expectedH1, String expectedH2, String stringInput) { + byte[] in = ascii(stringInput); + Murmur3_x86_128.HashObject hash = Murmur3_x86_128.hashUnsafeBytes( + in, Platform.BYTE_ARRAY_OFFSET, in.length, seed); + Assertions.assertEquals(expectedH1, Long.toHexString(hash.getHash1())); + Assertions.assertEquals(expectedH2, Long.toHexString(hash.getHash2())); + } + + private byte[] ascii(String string) { + byte[] bytes = new byte[string.length()]; + for (int i = 0; i < string.length(); i++) { + bytes[i] = (byte) string.charAt(i); + } + return bytes; + } +}