From 4b05a35d58cdabccd915582894d303ba437bee0f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 23:23:51 -0800 Subject: [PATCH 1/3] bloom filter serialization --- .../apache/spark/util/sketch/BitArray.java | 46 +++++++++++---- .../apache/spark/util/sketch/BloomFilter.java | 20 ++++++- .../spark/util/sketch/BloomFilterImpl.java | 57 ++++++++++++++++++- .../spark/util/sketch/CountMinSketch.java | 23 +------- .../spark/util/sketch/CountMinSketchImpl.java | 7 +-- .../org/apache/spark/util/sketch/Version.java | 35 ++++++++++++ .../spark/util/sketch/BloomFilterSuite.scala | 20 +++++++ 7 files changed, 166 insertions(+), 42 deletions(-) create mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 1bc665ad54b7..0e5b6f5668c0 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -17,6 +17,9 @@ package org.apache.spark.util.sketch; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.Arrays; public final class BitArray { @@ -24,6 +27,9 @@ public final class BitArray { private long bitCount; static int numWords(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive"); + } long numWords = (long) Math.ceil(numBits / 64.0); if (numWords > Integer.MAX_VALUE) { throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); @@ -32,13 +38,14 @@ static int numWords(long numBits) { } BitArray(long numBits) { - if (numBits <= 0) { - throw new IllegalArgumentException("numBits must be positive"); - } - this.data = new long[numWords(numBits)]; + this(new long[numWords(numBits)]); + } + + private BitArray(long[] data) { + this.data = data; long bitCount = 0; - for (long value : data) { - bitCount += Long.bitCount(value); + for (long datum : data) { + bitCount += Long.bitCount(datum); } this.bitCount = bitCount; } @@ -78,13 +85,28 @@ void putAll(BitArray array) { this.bitCount = bitCount; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || !(o instanceof BitArray)) return false; + void writeTo(DataOutputStream out) throws IOException { + out.writeInt(data.length); + for (long datum : data) { + out.writeLong(datum); + } + } - BitArray bitArray = (BitArray) o; - return Arrays.equals(data, bitArray.data); + static BitArray readFrom(DataInputStream in) throws IOException { + int numWords = in.readInt(); + long[] data = new long[numWords]; + for (int i = 0; i < numWords; i++) { + data[i] = in.readLong(); + } + return new BitArray(data); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof BitArray)) return false; + BitArray that = (BitArray) other; + return Arrays.equals(data, that.data); } @Override diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 38949c6311df..de10c6a23c10 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -17,6 +17,10 @@ package org.apache.spark.util.sketch; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + /** * A Bloom filter is a space-efficient probabilistic data structure, that is used to test whether * an element is a member of a set. It returns false when the element is definitely not in the @@ -83,7 +87,7 @@ public abstract class BloomFilter { * bloom filters are appropriately sized to avoid saturating them. * * @param other The bloom filter to combine this bloom filter with. It is not mutated. - * @throws IllegalArgumentException if {@code isCompatible(that) == false} + * @throws IncompatibleMergeException if {@code isCompatible(that) == false} */ public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; @@ -93,6 +97,20 @@ public abstract class BloomFilter { */ public abstract boolean mightContain(Object item); + /** + * Writes out this {@link BloomFilter} to an output stream in binary format. + * It is the caller's responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * Reads in a {@link BloomFilter} from an input stream. + * It is the caller's responsibility to close the stream. + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + return BloomFilterImpl.readFrom(in); + } + /** * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the * expected insertions and total number of bits in the Bloom filter. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index bbd6cf719dc0..b97043686b33 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -17,16 +17,49 @@ package org.apache.spark.util.sketch; -import java.io.UnsupportedEncodingException; +import java.io.*; +/* + * Binary format of a serialized BloomFilterImpl, version 1 (all values written in big-endian + * order): + * + * - Version number, always 1 (32 bit) + * - Total number of words of the BitArray (32 bit) + * - Long array inside the BitArray (numWords * 64 bit) + * - Number of hash functions (32 bit) + */ public class BloomFilterImpl extends BloomFilter { private final int numHashFunctions; private final BitArray bits; BloomFilterImpl(int numHashFunctions, long numBits) { + this(new BitArray(numBits), numHashFunctions); + } + + private BloomFilterImpl(BitArray bits, int numHashFunctions) { + this.bits = bits; this.numHashFunctions = numHashFunctions; - this.bits = new BitArray(numBits); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; } @Override @@ -161,4 +194,24 @@ public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeExcep this.bits.putAll(that.bits); return this; } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + bits.writeTo(dos); + dos.writeInt(numHashFunctions); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom Filter version number (" + version + ")"); + } + + return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); + } } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 9f4ff42403c3..004fbbf3152f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -55,25 +55,6 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { - /** - * Version number of the serialized binary format. - */ - public enum Version { - V1(1); - - private final int versionNumber; - - Version(int versionNumber) { - this.versionNumber = versionNumber; - } - - public int getVersionNumber() { - return versionNumber; - } - } - - public abstract Version version(); - /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. */ @@ -128,13 +109,13 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) /** * Writes out this {@link CountMinSketch} to an output stream in binary format. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public abstract void writeTo(OutputStream out) throws IOException; /** * Reads in a {@link CountMinSketch} from an input stream. - * It is the caller's responsibility to close the stream + * It is the caller's responsibility to close the stream. */ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 0209446ea3b1..8f17ddb31011 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -112,11 +112,6 @@ public int hashCode() { return hash; } - @Override - public Version version() { - return Version.V1; - } - private void initTablesWith(int depth, int width, int seed) { this.table = new long[depth][width]; this.hashA = new long[depth]; @@ -327,7 +322,7 @@ public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMerg public void writeTo(OutputStream out) throws IOException { DataOutputStream dos = new DataOutputStream(out); - dos.writeInt(version().getVersionNumber()); + dos.writeInt(Version.V1.getVersionNumber()); dos.writeLong(this.totalCount); dos.writeInt(this.depth); diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java new file mode 100644 index 000000000000..40790c92f3ae --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +/** + * Version number of the serialized binary format for bloom filter or count-min sketch. + */ +public enum Version { + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala index d2de509f1951..a0408d2da4df 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.sketch +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import scala.reflect.ClassTag import scala.util.Random @@ -25,6 +27,20 @@ import org.scalatest.FunSuite // scalastyle:ignore funsuite class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite private final val EPSILON = 0.01 + // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(filter: BloomFilter): Unit = { + val out = new ByteArrayOutputStream() + filter.writeTo(out) + out.close() + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = BloomFilter.readFrom(in) + in.close() + + assert(filter == deserialized) + } + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { test(s"accuracy - $typeName") { // use a fixed seed to make the test predictable. @@ -51,6 +67,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite // Also check the actual fpp is not significantly higher than we expected. val actualFpp = errorCount.toDouble / (numItems - numInsertion) assert(actualFpp - fpp < EPSILON) + + checkSerDe(filter) } } @@ -76,6 +94,8 @@ class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite items1.foreach(i => assert(filter1.mightContain(i))) items2.foreach(i => assert(filter1.mightContain(i))) + + checkSerDe(filter1) } } From 38d674c99af18aaf807c120647efb16442b5a967 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 25 Jan 2016 23:58:00 -0800 Subject: [PATCH 2/3] address comments --- .../main/java/org/apache/spark/util/sketch/BitArray.java | 6 +++--- .../main/java/org/apache/spark/util/sketch/BloomFilter.java | 2 +- .../java/org/apache/spark/util/sketch/BloomFilterImpl.java | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java index 0e5b6f5668c0..2a0484e324b1 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -28,7 +28,7 @@ public final class BitArray { static int numWords(long numBits) { if (numBits <= 0) { - throw new IllegalArgumentException("numBits must be positive"); + throw new IllegalArgumentException("numBits must be positive, but got " + numBits); } long numWords = (long) Math.ceil(numBits / 64.0); if (numWords > Integer.MAX_VALUE) { @@ -44,8 +44,8 @@ static int numWords(long numBits) { private BitArray(long[] data) { this.data = data; long bitCount = 0; - for (long datum : data) { - bitCount += Long.bitCount(datum); + for (long word : data) { + bitCount += Long.bitCount(word); } this.bitCount = bitCount; } diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index de10c6a23c10..2301b4993532 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -87,7 +87,7 @@ public abstract class BloomFilter { * bloom filters are appropriately sized to avoid saturating them. * * @param other The bloom filter to combine this bloom filter with. It is not mutated. - * @throws IncompatibleMergeException if {@code isCompatible(that) == false} + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} */ public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index b97043686b33..aa6bae02e43a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -209,7 +209,7 @@ public static BloomFilterImpl readFrom(InputStream in) throws IOException { int version = dis.readInt(); if (version != Version.V1.getVersionNumber()) { - throw new IOException("Unexpected Bloom Filter version number (" + version + ")"); + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); } return new BloomFilterImpl(BitArray.readFrom(dis), dis.readInt()); From c9b29c94d6cbb6bde098e6d5b971de118be0218b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Jan 2016 00:18:16 -0800 Subject: [PATCH 3/3] move version back --- .../apache/spark/util/sketch/BloomFilter.java | 22 ++++++++++++ .../spark/util/sketch/BloomFilterImpl.java | 9 ----- .../spark/util/sketch/CountMinSketch.java | 28 +++++++++++++++ .../spark/util/sketch/CountMinSketchImpl.java | 15 -------- .../org/apache/spark/util/sketch/Version.java | 35 ------------------- 5 files changed, 50 insertions(+), 59 deletions(-) delete mode 100644 common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 2301b4993532..00378d58518f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -43,6 +43,28 @@ * The implementation is largely based on the {@code BloomFilter} class from guava. */ public abstract class BloomFilter { + + public enum Version { + /** + * {@code BloomFilter} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total number of words of the underlying bit array (32 bit) + * - The words/longs (numWords * 64 bit) + * - Number of hash functions (32 bit) + */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + /** * Returns the false positive probability, i.e. the probability that * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index aa6bae02e43a..1c08d07afaea 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -19,15 +19,6 @@ import java.io.*; -/* - * Binary format of a serialized BloomFilterImpl, version 1 (all values written in big-endian - * order): - * - * - Version number, always 1 (32 bit) - * - Total number of words of the BitArray (32 bit) - * - Long array inside the BitArray (numWords * 64 bit) - * - Number of hash functions (32 bit) - */ public class BloomFilterImpl extends BloomFilter { private final int numHashFunctions; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 004fbbf3152f..00c0b1b9e2db 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -55,6 +55,34 @@ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. */ abstract public class CountMinSketch { + + public enum Version { + /** + * {@code CountMinSketch} binary format version 1 (all values written in big-endian order): + * - Version number, always 1 (32 bit) + * - Total count of added items (64 bit) + * - Depth (32 bit) + * - Width (32 bit) + * - Hash functions (depth * 64 bit) + * - Count table + * - Row 0 (width * 64 bit) + * - Row 1 (width * 64 bit) + * - ... + * - Row depth - 1 (width * 64 bit) + */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + /** * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. */ diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 8f17ddb31011..d08809605a93 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -26,21 +26,6 @@ import java.util.Arrays; import java.util.Random; -/* - * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian - * order): - * - * - Version number, always 1 (32 bit) - * - Total count of added items (64 bit) - * - Depth (32 bit) - * - Width (32 bit) - * - Hash functions (depth * 64 bit) - * - Count table - * - Row 0 (width * 64 bit) - * - Row 1 (width * 64 bit) - * - ... - * - Row depth - 1 (width * 64 bit) - */ class CountMinSketchImpl extends CountMinSketch { public static final long PRIME_MODULUS = (1L << 31) - 1; diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java deleted file mode 100644 index 40790c92f3ae..000000000000 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Version.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.sketch; - -/** - * Version number of the serialized binary format for bloom filter or count-min sketch. - */ -public enum Version { - V1(1); - - private final int versionNumber; - - Version(int versionNumber) { - this.versionNumber = versionNumber; - } - - int getVersionNumber() { - return versionNumber; - } -}