Skip to content

Commit ce38a35

Browse files
liancheng authored and rxin committed
[SPARK-12935][SQL] DataFrame API for Count-Min Sketch
This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs. Author: Cheng Lian <[email protected]> Closes #10911 from liancheng/cms-df-api.
1 parent e7f9199 commit ce38a35

File tree

7 files changed

+205
-37
lines changed

7 files changed

+205
-37
lines changed

common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ public abstract class BloomFilter {
4747
public enum Version {
4848
/**
4949
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
50-
* - Version number, always 1 (32 bit)
51-
* - Total number of words of the underlying bit array (32 bit)
52-
* - The words/longs (numWords * 64 bit)
53-
* - Number of hash functions (32 bit)
50+
* <ul>
51+
* <li>Version number, always 1 (32 bit)</li>
52+
* <li>Total number of words of the underlying bit array (32 bit)</li>
53+
* <li>The words/longs (numWords * 64 bit)</li>
54+
* <li>Number of hash functions (32 bit)</li>
55+
* </ul>
5456
*/
5557
V1(1);
5658

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,22 @@ abstract public class CountMinSketch {
5959
public enum Version {
6060
/**
6161
* {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
62-
* - Version number, always 1 (32 bit)
63-
* - Total count of added items (64 bit)
64-
* - Depth (32 bit)
65-
* - Width (32 bit)
66-
* - Hash functions (depth * 64 bit)
67-
* - Count table
68-
* - Row 0 (width * 64 bit)
69-
* - Row 1 (width * 64 bit)
70-
* - ...
71-
* - Row depth - 1 (width * 64 bit)
62+
* <ul>
63+
* <li>Version number, always 1 (32 bit)</li>
64+
* <li>Total count of added items (64 bit)</li>
65+
* <li>Depth (32 bit)</li>
66+
* <li>Width (32 bit)</li>
67+
* <li>Hash functions (depth * 64 bit)</li>
68+
* <li>
69+
* Count table
70+
* <ul>
71+
* <li>Row 0 (width * 64 bit)</li>
72+
* <li>Row 1 (width * 64 bit)</li>
73+
* <li>...</li>
74+
* <li>Row {@code depth - 1} (width * 64 bit)</li>
75+
* </ul>
76+
* </li>
77+
* </ul>
7278
*/
7379
V1(1);
7480

common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
import java.io.DataOutputStream;
2222
import java.io.IOException;
2323
import java.io.InputStream;
24+
import java.io.ObjectInputStream;
25+
import java.io.ObjectOutputStream;
2426
import java.io.OutputStream;
27+
import java.io.Serializable;
2528
import java.io.UnsupportedEncodingException;
2629
import java.util.Arrays;
2730
import java.util.Random;
2831

29-
class CountMinSketchImpl extends CountMinSketch {
30-
public static final long PRIME_MODULUS = (1L << 31) - 1;
32+
class CountMinSketchImpl extends CountMinSketch implements Serializable {
33+
private static final long PRIME_MODULUS = (1L << 31) - 1;
3134

3235
private int depth;
3336
private int width;
@@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch {
3740
private double eps;
3841
private double confidence;
3942

43+
private CountMinSketchImpl() {
44+
}
45+
4046
CountMinSketchImpl(int depth, int width, int seed) {
4147
this.depth = depth;
4248
this.width = width;
@@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch {
5561
initTablesWith(depth, width, seed);
5662
}
5763

58-
CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
59-
this.depth = depth;
60-
this.width = width;
61-
this.eps = 2.0 / width;
62-
this.confidence = 1 - 1 / Math.pow(2, depth);
63-
this.hashA = hashA;
64-
this.table = table;
65-
this.totalCount = totalCount;
66-
}
67-
6864
@Override
6965
public boolean equals(Object other) {
7066
if (other == this) {
@@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException {
325321
}
326322

327323
public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
324+
CountMinSketchImpl sketch = new CountMinSketchImpl();
325+
sketch.readFrom0(in);
326+
return sketch;
327+
}
328+
329+
private void readFrom0(InputStream in) throws IOException {
328330
DataInputStream dis = new DataInputStream(in);
329331

330-
// Ignores version number
331-
dis.readInt();
332+
int version = dis.readInt();
333+
if (version != Version.V1.getVersionNumber()) {
334+
throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
335+
}
332336

333-
long totalCount = dis.readLong();
334-
int depth = dis.readInt();
335-
int width = dis.readInt();
337+
this.totalCount = dis.readLong();
338+
this.depth = dis.readInt();
339+
this.width = dis.readInt();
340+
this.eps = 2.0 / width;
341+
this.confidence = 1 - 1 / Math.pow(2, depth);
336342

337-
long hashA[] = new long[depth];
343+
this.hashA = new long[depth];
338344
for (int i = 0; i < depth; ++i) {
339-
hashA[i] = dis.readLong();
345+
this.hashA[i] = dis.readLong();
340346
}
341347

342-
long table[][] = new long[depth][width];
348+
this.table = new long[depth][width];
343349
for (int i = 0; i < depth; ++i) {
344350
for (int j = 0; j < width; ++j) {
345-
table[i][j] = dis.readLong();
351+
this.table[i][j] = dis.readLong();
346352
}
347353
}
354+
}
355+
356+
private void writeObject(ObjectOutputStream out) throws IOException {
357+
this.writeTo(out);
358+
}
348359

349-
return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
360+
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
361+
this.readFrom0(in);
350362
}
351363
}

sql/core/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
<version>1.5.6</version>
4343
<type>jar</type>
4444
</dependency>
45+
<dependency>
46+
<groupId>org.apache.spark</groupId>
47+
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
48+
<version>${project.version}</version>
49+
</dependency>
4550
<dependency>
4651
<groupId>org.apache.spark</groupId>
4752
<artifactId>spark-core_${scala.binary.version}</artifactId>

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._
2323

2424
import org.apache.spark.annotation.Experimental
2525
import org.apache.spark.sql.execution.stat._
26+
import org.apache.spark.sql.types._
27+
import org.apache.spark.util.sketch.CountMinSketch
2628

2729
/**
2830
* :: Experimental ::
@@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
309311
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
310312
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
311313
}
314+
315+
/**
316+
* Builds a Count-min Sketch over a specified column.
317+
*
318+
* @param colName name of the column over which the sketch is built
319+
* @param depth depth of the sketch
320+
* @param width width of the sketch
321+
* @param seed random seed
322+
* @return a [[CountMinSketch]] over column `colName`
323+
* @since 2.0.0
324+
*/
325+
def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
326+
countMinSketch(Column(colName), depth, width, seed)
327+
}
328+
329+
/**
330+
* Builds a Count-min Sketch over a specified column.
331+
*
332+
* @param colName name of the column over which the sketch is built
333+
* @param eps relative error of the sketch
334+
* @param confidence confidence of the sketch
335+
* @param seed random seed
336+
* @return a [[CountMinSketch]] over column `colName`
337+
* @since 2.0.0
338+
*/
339+
def countMinSketch(
340+
colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
341+
countMinSketch(Column(colName), eps, confidence, seed)
342+
}
343+
344+
/**
345+
* Builds a Count-min Sketch over a specified column.
346+
*
347+
* @param col the column over which the sketch is built
348+
* @param depth depth of the sketch
349+
* @param width width of the sketch
350+
* @param seed random seed
351+
* @return a [[CountMinSketch]] over column `col`
352+
* @since 2.0.0
353+
*/
354+
def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
355+
countMinSketch(col, CountMinSketch.create(depth, width, seed))
356+
}
357+
358+
/**
359+
* Builds a Count-min Sketch over a specified column.
360+
*
361+
* @param col the column over which the sketch is built
362+
* @param eps relative error of the sketch
363+
* @param confidence confidence of the sketch
364+
* @param seed random seed
365+
* @return a [[CountMinSketch]] over column `col`
366+
* @since 2.0.0
367+
*/
368+
def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
369+
countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
370+
}
371+
372+
private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
373+
val singleCol = df.select(col)
374+
val colType = singleCol.schema.head.dataType
375+
376+
require(
377+
colType == StringType || colType.isInstanceOf[IntegralType],
378+
s"Count-min Sketch only supports string type and integral types, " +
379+
s"and does not support type $colType."
380+
)
381+
382+
singleCol.rdd.aggregate(zero)(
383+
(sketch: CountMinSketch, row: Row) => {
384+
sketch.add(row.get(0))
385+
sketch
386+
},
387+
388+
(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
389+
sketch1.mergeInPlace(sketch2)
390+
}
391+
)
392+
}
312393
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
import org.apache.spark.api.java.JavaRDD;
3636
import org.apache.spark.api.java.JavaSparkContext;
3737
import org.apache.spark.sql.*;
38-
import static org.apache.spark.sql.functions.*;
3938
import org.apache.spark.sql.test.TestSQLContext;
4039
import org.apache.spark.sql.types.*;
40+
import org.apache.spark.util.sketch.CountMinSketch;
41+
import static org.apache.spark.sql.functions.*;
4142
import static org.apache.spark.sql.types.DataTypes.*;
4243

4344
public class JavaDataFrameSuite {
@@ -321,4 +322,29 @@ public void testTextLoad() {
321322
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
322323
Assert.assertEquals(5L, df2.count());
323324
}
325+
326+
@Test
327+
public void testCountMinSketch() {
328+
DataFrame df = context.range(1000);
329+
330+
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
331+
Assert.assertEquals(sketch1.totalCount(), 1000);
332+
Assert.assertEquals(sketch1.depth(), 10);
333+
Assert.assertEquals(sketch1.width(), 20);
334+
335+
CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
336+
Assert.assertEquals(sketch2.totalCount(), 1000);
337+
Assert.assertEquals(sketch2.depth(), 10);
338+
Assert.assertEquals(sketch2.width(), 20);
339+
340+
CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
341+
Assert.assertEquals(sketch3.totalCount(), 1000);
342+
Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4);
343+
Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3);
344+
345+
CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
346+
Assert.assertEquals(sketch4.totalCount(), 1000);
347+
Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
348+
Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
349+
}
324350
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ package org.apache.spark.sql
1919

2020
import java.util.Random
2121

22+
import org.scalatest.Matchers._
23+
2224
import org.apache.spark.sql.functions.col
2325
import org.apache.spark.sql.test.SharedSQLContext
26+
import org.apache.spark.sql.types.DoubleType
2427

2528
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
2629
import testImplicits._
@@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
210213
sampled.groupBy("key").count().orderBy("key"),
211214
Seq(Row(0, 6), Row(1, 11)))
212215
}
216+
217+
// This test case only verifies that `DataFrame.countMinSketch()` methods do return
218+
// `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in
219+
// `CountMinSketchSuite` in project spark-sketch.
220+
test("countMinSketch") {
221+
val df = sqlContext.range(1000)
222+
223+
val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
224+
assert(sketch1.totalCount() === 1000)
225+
assert(sketch1.depth() === 10)
226+
assert(sketch1.width() === 20)
227+
228+
val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42)
229+
assert(sketch2.totalCount() === 1000)
230+
assert(sketch2.depth() === 10)
231+
assert(sketch2.width() === 20)
232+
233+
val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
234+
assert(sketch3.totalCount() === 1000)
235+
assert(sketch3.relativeError() === 0.001)
236+
assert(sketch3.confidence() === 0.99 +- 5e-3)
237+
238+
val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42)
239+
assert(sketch4.totalCount() === 1000)
240+
assert(sketch4.relativeError() === 0.001 +- 1e-4)
241+
assert(sketch4.confidence() === 0.99 +- 5e-3)
242+
243+
intercept[IllegalArgumentException] {
244+
df.select('id cast DoubleType as 'id)
245+
.stat
246+
.countMinSketch('id, depth = 10, width = 20, seed = 42)
247+
}
248+
}
213249
}

0 commit comments

Comments
 (0)