Skip to content

Commit 6f0f1d9

Browse files
lianchengrxin
authored andcommitted
[SPARK-12934][SQL] Count-min sketch serialization
This PR adds serialization support for `CountMinSketch`. A version number is added to version the serialized binary format. Author: Cheng Lian <[email protected]> Closes #10893 from liancheng/cms-serialization.
1 parent dcae355 commit 6f0f1d9

File tree

4 files changed

+213
-19
lines changed

4 files changed

+213
-19
lines changed

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.util.sketch;
1919

20+
import java.io.IOException;
2021
import java.io.InputStream;
2122
import java.io.OutputStream;
2223

@@ -54,6 +55,25 @@
5455
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
5556
*/
5657
abstract public class CountMinSketch {
58+
/**
59+
* Version number of the serialized binary format.
60+
*/
61+
public enum Version {
62+
V1(1);
63+
64+
private final int versionNumber;
65+
66+
Version(int versionNumber) {
67+
this.versionNumber = versionNumber;
68+
}
69+
70+
public int getVersionNumber() {
71+
return versionNumber;
72+
}
73+
}
74+
75+
public abstract Version version();
76+
5777
/**
5878
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
5979
*/
@@ -99,19 +119,23 @@ abstract public class CountMinSketch {
99119
*
100120
* Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
101121
* can be merged.
122+
*
123+
* @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has
124+
* incompatible depth, width, relative-error, confidence, or random seed.
102125
*/
103-
public abstract CountMinSketch mergeInPlace(CountMinSketch other);
126+
public abstract CountMinSketch mergeInPlace(CountMinSketch other)
127+
throws IncompatibleMergeException;
104128

105129
/**
106130
* Writes out this {@link CountMinSketch} to an output stream in binary format.
107131
*/
108-
public abstract void writeTo(OutputStream out);
132+
public abstract void writeTo(OutputStream out) throws IOException;
109133

110134
/**
111135
* Reads in a {@link CountMinSketch} from an input stream.
112136
*/
113-
public static CountMinSketch readFrom(InputStream in) {
114-
throw new UnsupportedOperationException("Not implemented yet");
137+
public static CountMinSketch readFrom(InputStream in) throws IOException {
138+
return CountMinSketchImpl.readFrom(in);
115139
}
116140

117141
/**

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java

Lines changed: 116 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,30 @@
1717

1818
package org.apache.spark.util.sketch;
1919

20+
import java.io.DataInputStream;
21+
import java.io.DataOutputStream;
22+
import java.io.IOException;
23+
import java.io.InputStream;
2024
import java.io.OutputStream;
2125
import java.io.UnsupportedEncodingException;
2226
import java.util.Arrays;
2327
import java.util.Random;
2428

29+
/*
30+
* Binary format of a serialized CountMinSketchImpl, version 1 (all values written in big-endian
31+
* order):
32+
*
33+
* - Version number, always 1 (32 bit)
34+
* - Total count of added items (64 bit)
35+
* - Depth (32 bit)
36+
* - Width (32 bit)
37+
* - Hash functions (depth * 64 bit)
38+
* - Count table
39+
* - Row 0 (width * 64 bit)
40+
* - Row 1 (width * 64 bit)
41+
* - ...
42+
* - Row depth - 1 (width * 64 bit)
43+
*/
2544
class CountMinSketchImpl extends CountMinSketch {
2645
public static final long PRIME_MODULUS = (1L << 31) - 1;
2746

@@ -33,15 +52,15 @@ class CountMinSketchImpl extends CountMinSketch {
3352
private double eps;
3453
private double confidence;
3554

36-
public CountMinSketchImpl(int depth, int width, int seed) {
55+
CountMinSketchImpl(int depth, int width, int seed) {
3756
this.depth = depth;
3857
this.width = width;
3958
this.eps = 2.0 / width;
4059
this.confidence = 1 - 1 / Math.pow(2, depth);
4160
initTablesWith(depth, width, seed);
4261
}
4362

44-
public CountMinSketchImpl(double eps, double confidence, int seed) {
63+
CountMinSketchImpl(double eps, double confidence, int seed) {
4564
// 2/w = eps ; w = 2/eps
4665
// 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
4766
this.eps = eps;
@@ -51,6 +70,53 @@ public CountMinSketchImpl(double eps, double confidence, int seed) {
5170
initTablesWith(depth, width, seed);
5271
}
5372

73+
CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
74+
this.depth = depth;
75+
this.width = width;
76+
this.eps = 2.0 / width;
77+
this.confidence = 1 - 1 / Math.pow(2, depth);
78+
this.hashA = hashA;
79+
this.table = table;
80+
this.totalCount = totalCount;
81+
}
82+
83+
@Override
84+
public boolean equals(Object other) {
85+
if (other == this) {
86+
return true;
87+
}
88+
89+
if (other == null || !(other instanceof CountMinSketchImpl)) {
90+
return false;
91+
}
92+
93+
CountMinSketchImpl that = (CountMinSketchImpl) other;
94+
95+
return
96+
this.depth == that.depth &&
97+
this.width == that.width &&
98+
this.totalCount == that.totalCount &&
99+
Arrays.equals(this.hashA, that.hashA) &&
100+
Arrays.deepEquals(this.table, that.table);
101+
}
102+
103+
@Override
104+
public int hashCode() {
105+
int hash = depth;
106+
107+
hash = hash * 31 + width;
108+
hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32));
109+
hash = hash * 31 + Arrays.hashCode(hashA);
110+
hash = hash * 31 + Arrays.deepHashCode(table);
111+
112+
return hash;
113+
}
114+
115+
@Override
116+
public Version version() {
117+
return Version.V1;
118+
}
119+
54120
private void initTablesWith(int depth, int width, int seed) {
55121
this.table = new long[depth][width];
56122
this.hashA = new long[depth];
@@ -221,27 +287,29 @@ private long estimateCountForStringItem(String item) {
221287
}
222288

223289
@Override
224-
public CountMinSketch mergeInPlace(CountMinSketch other) {
290+
public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException {
225291
if (other == null) {
226-
throw new CMSMergeException("Cannot merge null estimator");
292+
throw new IncompatibleMergeException("Cannot merge null estimator");
227293
}
228294

229295
if (!(other instanceof CountMinSketchImpl)) {
230-
throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
296+
throw new IncompatibleMergeException(
297+
"Cannot merge estimator of class " + other.getClass().getName()
298+
);
231299
}
232300

233301
CountMinSketchImpl that = (CountMinSketchImpl) other;
234302

235303
if (this.depth != that.depth) {
236-
throw new CMSMergeException("Cannot merge estimators of different depth");
304+
throw new IncompatibleMergeException("Cannot merge estimators of different depth");
237305
}
238306

239307
if (this.width != that.width) {
240-
throw new CMSMergeException("Cannot merge estimators of different width");
308+
throw new IncompatibleMergeException("Cannot merge estimators of different width");
241309
}
242310

243311
if (!Arrays.equals(this.hashA, that.hashA)) {
244-
throw new CMSMergeException("Cannot merge estimators of different seed");
312+
throw new IncompatibleMergeException("Cannot merge estimators of different seed");
245313
}
246314

247315
for (int i = 0; i < this.table.length; ++i) {
@@ -256,13 +324,48 @@ public CountMinSketch mergeInPlace(CountMinSketch other) {
256324
}
257325

258326
@Override
259-
public void writeTo(OutputStream out) {
260-
throw new UnsupportedOperationException("Not implemented yet");
327+
public void writeTo(OutputStream out) throws IOException {
328+
DataOutputStream dos = new DataOutputStream(out);
329+
330+
dos.writeInt(version().getVersionNumber());
331+
332+
dos.writeLong(this.totalCount);
333+
dos.writeInt(this.depth);
334+
dos.writeInt(this.width);
335+
336+
for (int i = 0; i < this.depth; ++i) {
337+
dos.writeLong(this.hashA[i]);
338+
}
339+
340+
for (int i = 0; i < this.depth; ++i) {
341+
for (int j = 0; j < this.width; ++j) {
342+
dos.writeLong(table[i][j]);
343+
}
344+
}
261345
}
262346

263-
protected static class CMSMergeException extends RuntimeException {
264-
public CMSMergeException(String message) {
265-
super(message);
347+
public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
348+
DataInputStream dis = new DataInputStream(in);
349+
350+
// Ignores version number
351+
dis.readInt();
352+
353+
long totalCount = dis.readLong();
354+
int depth = dis.readInt();
355+
int width = dis.readInt();
356+
357+
long hashA[] = new long[depth];
358+
for (int i = 0; i < depth; ++i) {
359+
hashA[i] = dis.readLong();
360+
}
361+
362+
long table[][] = new long[depth][width];
363+
for (int i = 0; i < depth; ++i) {
364+
for (int j = 0; j < width; ++j) {
365+
table[i][j] = dis.readLong();
366+
}
266367
}
368+
369+
return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
267370
}
268371
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.sketch;
19+
20+
public class IncompatibleMergeException extends Exception {
21+
public IncompatibleMergeException(String message) {
22+
super(message);
23+
}
24+
}

common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.util.sketch
1919

20+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
21+
2022
import scala.reflect.ClassTag
2123
import scala.util.Random
2224

@@ -29,9 +31,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
2931

3032
private val seed = 42
3133

34+
// Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized
35+
// version is equivalent to the original one.
36+
private def checkSerDe(sketch: CountMinSketch): Unit = {
37+
val out = new ByteArrayOutputStream()
38+
sketch.writeTo(out)
39+
40+
val in = new ByteArrayInputStream(out.toByteArray)
41+
val deserialized = CountMinSketch.readFrom(in)
42+
43+
assert(sketch === deserialized)
44+
}
45+
3246
def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
3347
test(s"accuracy - $typeName") {
34-
val r = new Random()
48+
// Uses fixed seed to ensure reproducible test execution
49+
val r = new Random(31)
3550

3651
val numAllItems = 1000000
3752
val allItems = Array.fill(numAllItems)(itemGenerator(r))
@@ -45,7 +60,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
4560
}
4661

4762
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
63+
checkSerDe(sketch)
64+
4865
sampledItemIndices.foreach(i => sketch.add(allItems(i)))
66+
checkSerDe(sketch)
4967

5068
val probCorrect = {
5169
val numErrors = allItems.map { item =>
@@ -66,7 +84,9 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
6684

6785
def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
6886
test(s"mergeInPlace - $typeName") {
69-
val r = new Random()
87+
// Uses fixed seed to ensure reproducible test execution
88+
val r = new Random(31)
89+
7090
val numToMerge = 5
7191
val numItemsPerSketch = 100000
7292
val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
@@ -75,11 +95,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
7595

7696
val sketches = perSketchItems.map { items =>
7797
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
98+
checkSerDe(sketch)
99+
78100
items.foreach(sketch.add)
101+
checkSerDe(sketch)
102+
79103
sketch
80104
}
81105

82106
val mergedSketch = sketches.reduce(_ mergeInPlace _)
107+
checkSerDe(mergedSketch)
83108

84109
val expectedSketch = {
85110
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -109,4 +134,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
109134
testItemType[Long]("Long") { _.nextLong() }
110135

111136
testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
137+
138+
test("incompatible merge") {
139+
intercept[IncompatibleMergeException] {
140+
CountMinSketch.create(10, 10, 1).mergeInPlace(null)
141+
}
142+
143+
intercept[IncompatibleMergeException] {
144+
val sketch1 = CountMinSketch.create(10, 20, 1)
145+
val sketch2 = CountMinSketch.create(10, 20, 2)
146+
sketch1.mergeInPlace(sketch2)
147+
}
148+
149+
intercept[IncompatibleMergeException] {
150+
val sketch1 = CountMinSketch.create(10, 10, 1)
151+
val sketch2 = CountMinSketch.create(10, 20, 2)
152+
sketch1.mergeInPlace(sketch2)
153+
}
154+
}
112155
}

0 commit comments

Comments
 (0)