
Commit c33b8dc

larvaboy authored and rxin committed
Implement ApproximateCountDistinct for SparkSql
Add the implementation for ApproximateCountDistinct to SparkSql. We use the HyperLogLog algorithm implemented in stream-lib, and do the count in two phases: 1) counting the number of distinct elements in each partition, and 2) merging the HyperLogLog results from the different partitions. A simple serializer and test cases are added as well.

Author: larvaboy <[email protected]>

Closes apache#737 from larvaboy/master and squashes the following commits:

bd8ef3f [larvaboy] Add support of user-provided standard deviation to ApproxCountDistinct.
9ba8360 [larvaboy] Fix alignment and null handling issues.
95b4067 [larvaboy] Add a test case for count distinct and approximate count distinct.
f57917d [larvaboy] Add the parser for the approximate count.
a2d5d10 [larvaboy] Add ApproximateCountDistinct aggregates and functions.
7ad273a [larvaboy] Add SparkSql serializer for HyperLogLog.
1d9aacf [larvaboy] Fix a minor typo in the toString method of the Count case class.
653542b [larvaboy] Fix a couple of minor typos.
1 parent 92cebad commit c33b8dc
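The resulting SQL syntax, exercised by the new test cases below, takes an optional user-provided relative standard deviation (queries quoted from the test suite; testData2 is a suite fixture):

    SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2        -- relativeSD defaults to 0.05
    SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2  -- user-provided relativeSD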

File tree

5 files changed: +122 -7 lines changed


core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 3 additions & 3 deletions
@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {

@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */

@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
    */
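A minimal usage sketch of the RDD-level API whose scaladoc is fixed above (assumes a live SparkContext named sc; the collected output is illustrative, since results are approximate on large inputs):

    import org.apache.spark.SparkContext._  // implicit conversion to PairRDDFunctions

    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 2), ("b", 7)))
    // Lower relativeSD gives more accurate counts at the cost of a larger sketch.
    val approxCounts = pairs.countApproxDistinctByKey(0.05)
    approxCounts.collect()  // e.g. Array((a,2), (b,1)) on this tiny input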

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 7 additions & 0 deletions
@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val AND = Keyword("AND")
   protected val AS = Keyword("AS")
   protected val ASC = Keyword("ASC")
+  protected val APPROXIMATE = Keyword("APPROXIMATE")
   protected val AVG = Keyword("AVG")
   protected val BY = Keyword("BY")
   protected val CAST = Keyword("CAST")

@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
     COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+    APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
+      case exp => ApproxCountDistinct(exp)
+    } |
+    APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
+      case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
+    } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
     MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
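For readers new to the Scala parser combinators used here: p ~ q sequences two parsers and keeps both results, p ~> q keeps only q's result, p <~ q keeps only p's, and ^^ maps the parsed result. A simplified, self-contained analogue of the second rule above (a hypothetical mini-parser built on the standard library's JavaTokenParsers rather than SqlParser's Keyword machinery):

    import scala.util.parsing.combinator.JavaTokenParsers

    object ApproxMiniParser extends JavaTokenParsers {
      case class Approx(column: String, relativeSD: Double)

      // APPROXIMATE(<float>) COUNT(DISTINCT <ident>)
      def approx: Parser[Approx] =
        "APPROXIMATE" ~> "(" ~> floatingPointNumber ~
          (")" ~ "COUNT" ~ "(" ~ "DISTINCT" ~> ident) <~ ")" ^^ {
          case sd ~ col => Approx(col, sd.toDouble)
        }
    }

    // ApproxMiniParser.parseAll(ApproxMiniParser.approx, "APPROXIMATE(0.04) COUNT(DISTINCT a)")
    // yields a Success containing Approx(a,0.04)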

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 76 additions & 2 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException

@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: Row): Any = currentMax
 }
 
-
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false

@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
-  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
   override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
+case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
+  extends PartialAggregate with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialCount =
+      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
+
+    SplitEvaluation(
+      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
+      partialCount :: Nil)
+  }
+
+  override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+}
+
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false

@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: Row): Any = count
 }
 
+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: Row): Any = hyperLogLog
+}
+
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: Row): Any = hyperLogLog.cardinality()
+}
+
 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
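The partition/merge split above maps directly onto stream-lib's HyperLogLog operations. A standalone sketch of the two-phase scheme outside Catalyst, using only the stream-lib calls that appear in the functions above:

    import com.clearspring.analytics.stream.cardinality.HyperLogLog

    // Phase 1: one HyperLogLog per partition (ApproxCountDistinctPartitionFunction).
    val partitions = Seq(Seq("a", "b", "b"), Seq("b", "c"), Seq("c", "d"))
    val perPartition = partitions.map { rows =>
      val hll = new HyperLogLog(0.05)  // relativeSD = 0.05
      rows.foreach(hll.offer)          // update() above also skips nulls
      hll
    }

    // Phase 2: merge the per-partition sketches and read off the estimate
    // (ApproxCountDistinctMergeFunction).
    val merged = new HyperLogLog(0.05)
    perPartition.foreach(merged.addAll)
    println(merged.cardinality())      // expect 4 ("a", "b", "c", "d") on this tiny input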
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 17 additions & 0 deletions
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
 

@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
+      new HyperLogLogSerializer)
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)

@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
   }
 }
 
+private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
+  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
+    val bytes = hyperLogLog.getBytes()
+    output.writeInt(bytes.length)
+    output.writeBytes(bytes)
+  }
+
+  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
+    val length = input.readInt()
+    val bytes = input.readBytes(length)
+    HyperLogLog.Builder.build(bytes)
+  }
+}
+
 /**
  * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
  * them as `Array[(k,v)]`.
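A quick round-trip sanity check of the serializer (standalone Kryo usage; assumes HyperLogLogSerializer is visible from the call site, which the private[sql] modifier restricts in practice):

    import com.clearspring.analytics.stream.cardinality.HyperLogLog
    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input, Output}

    val kryo = new Kryo()
    kryo.register(classOf[HyperLogLog], new HyperLogLogSerializer)

    val hll = new HyperLogLog(0.05)
    Seq("x", "y", "x").foreach(hll.offer)

    val out = new Output(4096)  // the serializer writes length-prefixed bytes
    kryo.writeObject(out, hll)
    val back = kryo.readObject(new Input(out.toBytes), classOf[HyperLogLog])
    assert(back.cardinality() == hll.cardinality())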
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 19 additions & 2 deletions
@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest
   test("count") {
     checkAnswer(
       sql("SELECT COUNT(*) FROM testData2"),
-      testData2.count()
-    )
+      testData2.count())
+  }
+
+  test("count distinct") {
+    checkAnswer(
+      sql("SELECT COUNT(DISTINCT b) FROM testData2"),
+      2)
+  }
+
+  test("approximate count distinct") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+      3)
+  }
+
+  test("approximate count distinct with user provided standard deviation") {
+    checkAnswer(
+      sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+      3)
   }
 
   // No support for primitive nulls yet.
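The exact expected answers (2 and 3) are reasonable because testData2 contains only a handful of distinct values, and at such tiny cardinalities HyperLogLog's small-range estimate is typically exact. A quick illustration of that assumption:

    import com.clearspring.analytics.stream.cardinality.HyperLogLog

    val hll = new HyperLogLog(0.05)
    Seq(1, 1, 2, 2, 3, 3).foreach(v => hll.offer(v))
    println(hll.cardinality())  // 3 on this tiny input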
