-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-12934][SQL] Count-min sketch serialization #10893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,30 @@ | |
|
|
||
| package org.apache.spark.util.sketch; | ||
|
|
||
| import java.io.DataInputStream; | ||
| import java.io.DataOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.InputStream; | ||
| import java.io.OutputStream; | ||
| import java.io.UnsupportedEncodingException; | ||
| import java.util.Arrays; | ||
| import java.util.Random; | ||
|
|
||
| /* | ||
| * Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian | ||
| * order): | ||
| * | ||
| * - Version number, always 1 (32 bit) | ||
| * - Total count of added items (64 bit) | ||
| * - Depth (32 bit) | ||
| * - Width (32 bit) | ||
| * - Hash functions (depth * 64 bit) | ||
| * - Count table | ||
| * - Row 0 (width * 64 bit) | ||
| * - Row 1 (width * 64 bit) | ||
| * - ... | ||
| * - Row depth - 1 (width * 64 bit) | ||
| */ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we move this comment near
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I couldn't decide whether to put it before |
||
| class CountMinSketchImpl extends CountMinSketch { | ||
| public static final long PRIME_MODULUS = (1L << 31) - 1; | ||
|
|
||
|
|
@@ -33,15 +52,15 @@ class CountMinSketchImpl extends CountMinSketch { | |
| private double eps; | ||
| private double confidence; | ||
|
|
||
/**
 * Builds an empty sketch with an explicitly chosen table shape.
 *
 * @param depth number of rows in the count table (one hash value per row)
 * @param width number of counters in each row
 * @param seed  forwarded, together with the dimensions, to {@code initTablesWith}
 */
CountMinSketchImpl(int depth, int width, int seed) {
  // Accuracy parameters are derived from the shape: eps = 2 / width,
  // confidence = 1 - 1 / 2^depth.
  this.eps = 2.0 / width;
  this.confidence = 1 - 1 / Math.pow(2, depth);

  this.depth = depth;
  this.width = width;

  initTablesWith(depth, width, seed);
}
|
|
||
| public CountMinSketchImpl(double eps, double confidence, int seed) { | ||
| CountMinSketchImpl(double eps, double confidence, int seed) { | ||
| // 2/w = eps ; w = 2/eps | ||
| // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) | ||
| this.eps = eps; | ||
|
|
@@ -51,6 +70,53 @@ public CountMinSketchImpl(double eps, double confidence, int seed) { | |
| initTablesWith(depth, width, seed); | ||
| } | ||
|
|
||
/**
 * Rebuilds a sketch from state that has already been populated, e.g. the values
 * produced by {@code readFrom}.
 *
 * @param depth      number of rows in {@code table}
 * @param width      number of counters per row
 * @param totalCount total count of added items
 * @param hashA      hash values, stored by reference (not copied)
 * @param table      count table, stored by reference (not copied)
 */
CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
  // Caller-supplied state is adopted as-is.
  this.totalCount = totalCount;
  this.hashA = hashA;
  this.table = table;

  // Shape, plus the accuracy parameters computed from it.
  this.depth = depth;
  this.width = width;
  this.eps = 2.0 / width;
  this.confidence = 1 - 1 / Math.pow(2, depth);
}
|
|
||
| @Override | ||
| public boolean equals(Object other) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also override hashcode?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added, thanks. |
||
| if (other == this) { | ||
| return true; | ||
| } | ||
|
|
||
| if (other == null || !(other instanceof CountMinSketchImpl)) { | ||
| return false; | ||
| } | ||
|
|
||
| CountMinSketchImpl that = (CountMinSketchImpl) other; | ||
|
|
||
| return | ||
| this.depth == that.depth && | ||
| this.width == that.width && | ||
| this.totalCount == that.totalCount && | ||
| Arrays.equals(this.hashA, that.hashA) && | ||
| Arrays.deepEquals(this.table, that.table); | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| int hash = depth; | ||
|
|
||
| hash = hash * 31 + width; | ||
| hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32)); | ||
| hash = hash * 31 + Arrays.hashCode(hashA); | ||
| hash = hash * 31 + Arrays.deepHashCode(table); | ||
|
|
||
| return hash; | ||
| } | ||
|
|
||
| @Override | ||
| public Version version() { | ||
| return Version.V1; | ||
| } | ||
|
|
||
| private void initTablesWith(int depth, int width, int seed) { | ||
| this.table = new long[depth][width]; | ||
| this.hashA = new long[depth]; | ||
|
|
@@ -221,27 +287,29 @@ private long estimateCountForStringItem(String item) { | |
| } | ||
|
|
||
| @Override | ||
| public CountMinSketch mergeInPlace(CountMinSketch other) { | ||
| public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { | ||
| if (other == null) { | ||
| throw new CMSMergeException("Cannot merge null estimator"); | ||
| throw new IncompatibleMergeException("Cannot merge null estimator"); | ||
| } | ||
|
|
||
| if (!(other instanceof CountMinSketchImpl)) { | ||
| throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName()); | ||
| throw new IncompatibleMergeException( | ||
| "Cannot merge estimator of class " + other.getClass().getName() | ||
| ); | ||
| } | ||
|
|
||
| CountMinSketchImpl that = (CountMinSketchImpl) other; | ||
|
|
||
| if (this.depth != that.depth) { | ||
| throw new CMSMergeException("Cannot merge estimators of different depth"); | ||
| throw new IncompatibleMergeException("Cannot merge estimators of different depth"); | ||
| } | ||
|
|
||
| if (this.width != that.width) { | ||
| throw new CMSMergeException("Cannot merge estimators of different width"); | ||
| throw new IncompatibleMergeException("Cannot merge estimators of different width"); | ||
| } | ||
|
|
||
| if (!Arrays.equals(this.hashA, that.hashA)) { | ||
| throw new CMSMergeException("Cannot merge estimators of different seed"); | ||
| throw new IncompatibleMergeException("Cannot merge estimators of different seed"); | ||
| } | ||
|
|
||
| for (int i = 0; i < this.table.length; ++i) { | ||
|
|
@@ -256,13 +324,48 @@ public CountMinSketch mergeInPlace(CountMinSketch other) { | |
| } | ||
|
|
||
| @Override | ||
| public void writeTo(OutputStream out) { | ||
| throw new UnsupportedOperationException("Not implemented yet"); | ||
| public void writeTo(OutputStream out) throws IOException { | ||
| DataOutputStream dos = new DataOutputStream(out); | ||
|
|
||
| dos.writeInt(version().getVersionNumber()); | ||
|
|
||
| dos.writeLong(this.totalCount); | ||
| dos.writeInt(this.depth); | ||
| dos.writeInt(this.width); | ||
|
|
||
| for (int i = 0; i < this.depth; ++i) { | ||
| dos.writeLong(this.hashA[i]); | ||
| } | ||
|
|
||
| for (int i = 0; i < this.depth; ++i) { | ||
| for (int j = 0; j < this.width; ++j) { | ||
| dos.writeLong(table[i][j]); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| protected static class CMSMergeException extends RuntimeException { | ||
| public CMSMergeException(String message) { | ||
| super(message); | ||
| public static CountMinSketchImpl readFrom(InputStream in) throws IOException { | ||
| DataInputStream dis = new DataInputStream(in); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be closed before returning from the method, right ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But |
||
|
|
||
| // Ignores version number | ||
| dis.readInt(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should add some check here to throw an exception if the version number is not 1. |
||
|
|
||
| long totalCount = dis.readLong(); | ||
| int depth = dis.readInt(); | ||
| int width = dis.readInt(); | ||
|
|
||
| long hashA[] = new long[depth]; | ||
| for (int i = 0; i < depth; ++i) { | ||
| hashA[i] = dis.readLong(); | ||
| } | ||
|
|
||
| long table[][] = new long[depth][width]; | ||
| for (int i = 0; i < depth; ++i) { | ||
| for (int j = 0; j < width; ++j) { | ||
| table[i][j] = dis.readLong(); | ||
| } | ||
| } | ||
|
|
||
| return new CountMinSketchImpl(depth, width, totalCount, hashA, table); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.util.sketch; | ||
|
|
||
/**
 * Checked exception signalling that two sketches cannot be merged, for example
 * because the other sketch is null, is of a different class, or differs in depth,
 * width or seed.
 */
public class IncompatibleMergeException extends Exception {
  // Exception is Serializable; pin the stream version explicitly.
  private static final long serialVersionUID = 1L;

  /**
   * @param message description of why the merge was rejected
   */
  public IncompatibleMergeException(String message) {
    super(message);
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.util.sketch | ||
|
|
||
| import java.io.{ByteArrayInputStream, ByteArrayOutputStream} | ||
|
|
||
| import scala.reflect.ClassTag | ||
| import scala.util.Random | ||
|
|
||
|
|
@@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite | |
|
|
||
| private val seed = 42 | ||
|
|
||
// Round-trips `sketch` through `writeTo` and `CountMinSketch.readFrom` using in-memory
// byte streams, then asserts that the sketch read back equals the original.
private def checkSerDe(sketch: CountMinSketch): Unit = {
  val serialized = {
    val buffer = new ByteArrayOutputStream()
    sketch.writeTo(buffer)
    buffer.toByteArray
  }

  val roundTripped = CountMinSketch.readFrom(new ByteArrayInputStream(serialized))

  assert(sketch === roundTripped)
}
|
|
||
| def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { | ||
| test(s"accuracy - $typeName") { | ||
| val r = new Random() | ||
| // Uses fixed seed to ensure reproducible test execution | ||
| val r = new Random(31) | ||
|
|
||
| val numAllItems = 1000000 | ||
| val allItems = Array.fill(numAllItems)(itemGenerator(r)) | ||
|
|
@@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite | |
| } | ||
|
|
||
| val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) | ||
| checkSerDe(sketch) | ||
|
|
||
| sampledItemIndices.foreach(i => sketch.add(allItems(i))) | ||
| checkSerDe(sketch) | ||
|
|
||
| val probCorrect = { | ||
| val numErrors = allItems.map { item => | ||
|
|
@@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite | |
|
|
||
| def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { | ||
| test(s"mergeInPlace - $typeName") { | ||
| val r = new Random() | ||
| // Uses fixed seed to ensure reproducible test execution | ||
| val r = new Random(31) | ||
|
|
||
| val numToMerge = 5 | ||
| val numItemsPerSketch = 100000 | ||
| val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { | ||
|
|
@@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite | |
|
|
||
| val sketches = perSketchItems.map { items => | ||
| val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) | ||
| checkSerDe(sketch) | ||
|
|
||
| items.foreach(sketch.add) | ||
| checkSerDe(sketch) | ||
|
|
||
| sketch | ||
| } | ||
|
|
||
| val mergedSketch = sketches.reduce(_ mergeInPlace _) | ||
| checkSerDe(mergedSketch) | ||
|
|
||
| val expectedSketch = { | ||
| val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) | ||
|
|
@@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite | |
| testItemType[Long]("Long") { _.nextLong() } | ||
|
|
||
| testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } | ||
|
|
||
test("incompatible merge") {
  // Merging with null must be rejected.
  intercept[IncompatibleMergeException] {
    val sketch = CountMinSketch.create(10, 10, 1)
    sketch.mergeInPlace(null)
  }

  // Sketches created with the same first two arguments but a different third one.
  intercept[IncompatibleMergeException] {
    val left = CountMinSketch.create(10, 20, 1)
    val right = CountMinSketch.create(10, 20, 2)
    left.mergeInPlace(right)
  }

  // Sketches whose second and third creation arguments both differ.
  intercept[IncompatibleMergeException] {
    val left = CountMinSketch.create(10, 10, 1)
    val right = CountMinSketch.create(10, 20, 2)
    left.mergeInPlace(right)
  }
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this out so that we can also use it in the Bloom filter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be specific to the Count-Min Sketch, because the two binary protocols can and should evolve separately.