Skip to content

Commit e97d7f9

Browse files
committed
CountMinSketch serialization
1 parent e789b1d commit e97d7f9

File tree

4 files changed

+183
-16
lines changed

4 files changed

+183
-16
lines changed

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.util.sketch;
1919

20+
import java.io.DataInputStream;
21+
import java.io.IOException;
2022
import java.io.InputStream;
2123
import java.io.OutputStream;
2224

@@ -54,6 +56,25 @@
5456
* This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
5557
*/
5658
abstract public class CountMinSketch {
59+
/**
60+
* Version number of the serialized binary format.
61+
*/
62+
public enum Version {
63+
V1(1);
64+
65+
private final int versionNumber;
66+
67+
Version(int versionNumber) {
68+
this.versionNumber = versionNumber;
69+
}
70+
71+
public int getVersionNumber() {
72+
return versionNumber;
73+
}
74+
}
75+
76+
public abstract Version version();
77+
5778
/**
5879
* Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
5980
*/
@@ -99,19 +120,44 @@ abstract public class CountMinSketch {
99120
*
100121
* Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
101122
* can be merged.
123+
*
124+
* @exception CountMinSketchMergeException if the {@code other} {@link CountMinSketch} has
125+
* incompatible depth, width, relative-error, confidence, or random seed.
102126
*/
103-
public abstract CountMinSketch mergeInPlace(CountMinSketch other);
127+
public abstract CountMinSketch mergeInPlace(CountMinSketch other)
128+
throws CountMinSketchMergeException;
104129

105130
/**
106131
* Writes out this {@link CountMinSketch} to an output stream in binary format.
107132
*/
108-
public abstract void writeTo(OutputStream out);
133+
public abstract void writeTo(OutputStream out) throws IOException;
109134

110135
/**
111136
* Reads in a {@link CountMinSketch} from an input stream.
112137
*/
113-
public static CountMinSketch readFrom(InputStream in) {
114-
throw new UnsupportedOperationException("Not implemented yet");
138+
public static CountMinSketch readFrom(InputStream in) throws IOException {
139+
DataInputStream dis = new DataInputStream(in);
140+
141+
// Ignores version number
142+
dis.readInt();
143+
144+
long totalCount = dis.readLong();
145+
int depth = dis.readInt();
146+
int width = dis.readInt();
147+
148+
long hashA[] = new long[depth];
149+
for (int i = 0; i < depth; ++i) {
150+
hashA[i] = dis.readLong();
151+
}
152+
153+
long table[][] = new long[depth][width];
154+
for (int i = 0; i < depth; ++i) {
155+
for (int j = 0; j < width; ++j) {
156+
table[i][j] = dis.readLong();
157+
}
158+
}
159+
160+
return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
115161
}
116162

117163
/**

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.util.sketch;
1919

20+
import java.io.DataOutputStream;
21+
import java.io.IOException;
2022
import java.io.OutputStream;
2123
import java.io.UnsupportedEncodingException;
2224
import java.util.Arrays;
@@ -51,6 +53,53 @@ public CountMinSketchImpl(double eps, double confidence, int seed) {
5153
initTablesWith(depth, width, seed);
5254
}
5355

56+
public CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
57+
this.depth = depth;
58+
this.width = width;
59+
this.eps = 2.0 / width;
60+
this.confidence = 1 - 1 / Math.pow(2, depth);
61+
this.hashA = hashA;
62+
this.table = table;
63+
this.totalCount = totalCount;
64+
}
65+
66+
@Override
67+
public boolean equals(Object other) {
68+
if (other == this) {
69+
return true;
70+
}
71+
72+
if (!(other instanceof CountMinSketchImpl)) {
73+
return false;
74+
}
75+
76+
CountMinSketchImpl that = (CountMinSketchImpl) other;
77+
78+
if (this.depth == that.depth &&
79+
this.width == that.width &&
80+
this.totalCount == that.totalCount) {
81+
for (int i = 0; i < depth; ++i) {
82+
if (this.hashA[i] != that.hashA[i]) {
83+
return false;
84+
}
85+
86+
for (int j = 0; j < width; ++j) {
87+
if (this.table[i][j] != that.table[i][j]) {
88+
return false;
89+
}
90+
}
91+
}
92+
return true;
93+
} else {
94+
return false;
95+
}
96+
}
97+
98+
@Override
99+
public Version version() {
100+
return Version.V1;
101+
}
102+
54103
private void initTablesWith(int depth, int width, int seed) {
55104
this.table = new long[depth][width];
56105
this.hashA = new long[depth];
@@ -221,27 +270,27 @@ private long estimateCountForStringItem(String item) {
221270
}
222271

223272
@Override
224-
public CountMinSketch mergeInPlace(CountMinSketch other) {
273+
public CountMinSketch mergeInPlace(CountMinSketch other) throws CountMinSketchMergeException {
225274
if (other == null) {
226-
throw new CMSMergeException("Cannot merge null estimator");
275+
throw new CountMinSketchMergeException("Cannot merge null estimator");
227276
}
228277

229278
if (!(other instanceof CountMinSketchImpl)) {
230-
throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
279+
throw new CountMinSketchMergeException("Cannot merge estimator of class " + other.getClass().getName());
231280
}
232281

233282
CountMinSketchImpl that = (CountMinSketchImpl) other;
234283

235284
if (this.depth != that.depth) {
236-
throw new CMSMergeException("Cannot merge estimators of different depth");
285+
throw new CountMinSketchMergeException("Cannot merge estimators of different depth");
237286
}
238287

239288
if (this.width != that.width) {
240-
throw new CMSMergeException("Cannot merge estimators of different width");
289+
throw new CountMinSketchMergeException("Cannot merge estimators of different width");
241290
}
242291

243292
if (!Arrays.equals(this.hashA, that.hashA)) {
244-
throw new CMSMergeException("Cannot merge estimators of different seed");
293+
throw new CountMinSketchMergeException("Cannot merge estimators of different seed");
245294
}
246295

247296
for (int i = 0; i < this.table.length; ++i) {
@@ -256,13 +305,23 @@ public CountMinSketch mergeInPlace(CountMinSketch other) {
256305
}
257306

258307
@Override
259-
public void writeTo(OutputStream out) {
260-
throw new UnsupportedOperationException("Not implemented yet");
261-
}
308+
public void writeTo(OutputStream out) throws IOException {
309+
DataOutputStream dos = new DataOutputStream(out);
310+
311+
dos.writeInt(version().getVersionNumber());
312+
313+
dos.writeLong(this.totalCount);
314+
dos.writeInt(this.depth);
315+
dos.writeInt(this.width);
262316

263-
protected static class CMSMergeException extends RuntimeException {
264-
public CMSMergeException(String message) {
265-
super(message);
317+
for (int i = 0; i < this.depth; ++i) {
318+
dos.writeLong(this.hashA[i]);
319+
}
320+
321+
for (int i = 0; i < this.depth; ++i) {
322+
for (int j = 0; j < this.width; ++j) {
323+
dos.writeLong(table[i][j]);
324+
}
266325
}
267326
}
268327
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.sketch;
19+
20+
public class CountMinSketchMergeException extends Exception {
21+
public CountMinSketchMergeException(String message) {
22+
super(message);
23+
}
24+
}

common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.util.sketch
1919

20+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
21+
2022
import scala.reflect.ClassTag
2123
import scala.util.Random
2224

@@ -29,6 +31,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
2931

3032
private val seed = 42
3133

34+
private def checkSerDe(sketch: CountMinSketch): Unit = {
35+
val out = new ByteArrayOutputStream()
36+
sketch.writeTo(out)
37+
38+
val in = new ByteArrayInputStream(out.toByteArray)
39+
val deserialized = CountMinSketch.readFrom(in)
40+
41+
assert(sketch === deserialized)
42+
}
43+
3244
def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
3345
test(s"accuracy - $typeName") {
3446
val r = new Random()
@@ -45,7 +57,10 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
4557
}
4658

4759
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
60+
checkSerDe(sketch)
61+
4862
sampledItemIndices.foreach(i => sketch.add(allItems(i)))
63+
checkSerDe(sketch)
4964

5065
val probCorrect = {
5166
val numErrors = allItems.map { item =>
@@ -75,11 +90,16 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
7590

7691
val sketches = perSketchItems.map { items =>
7792
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
93+
checkSerDe(sketch)
94+
7895
items.foreach(sketch.add)
96+
checkSerDe(sketch)
97+
7998
sketch
8099
}
81100

82101
val mergedSketch = sketches.reduce(_ mergeInPlace _)
102+
checkSerDe(mergedSketch)
83103

84104
val expectedSketch = {
85105
val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -109,4 +129,22 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
109129
testItemType[Long]("Long") { _.nextLong() }
110130

111131
testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
132+
133+
test("incompatible merge") {
134+
intercept[CountMinSketchMergeException] {
135+
CountMinSketch.create(10, 10, 1).mergeInPlace(null)
136+
}
137+
138+
intercept[CountMinSketchMergeException] {
139+
val sketch1 = CountMinSketch.create(10, 20, 1)
140+
val sketch2 = CountMinSketch.create(10, 20, 2)
141+
sketch1.mergeInPlace(sketch2)
142+
}
143+
144+
intercept[CountMinSketchMergeException] {
145+
val sketch1 = CountMinSketch.create(10, 10, 1)
146+
val sketch2 = CountMinSketch.create(10, 20, 2)
147+
sketch1.mergeInPlace(sketch2)
148+
}
149+
}
112150
}

0 commit comments

Comments
 (0)