Skip to content

Commit a2d5d10

Browse files
larvaboyLiuyang Li
authored andcommitted
Add ApproximateCountDistinct aggregates and functions.
We use stream-lib's HyperLogLog to approximately count the number of distinct elements in each partition, and merge the HyperLogLogs to compute the final result. If the expressions cannot be successfully broken apart, we fall back to the exact CountDistinct.
1 parent 7ad273a commit a2d5d10

File tree

1 file changed

+68
-0
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+68
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import com.clearspring.analytics.stream.cardinality.HyperLogLog
21+
2022
import org.apache.spark.sql.catalyst.types._
2123
import org.apache.spark.sql.catalyst.trees
2224
import org.apache.spark.sql.catalyst.errors.TreeNodeException
@@ -169,6 +171,44 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
169171
override def newInstance() = new CountDistinctFunction(expressions, this)
170172
}
171173

174+
/**
 * Partial stage of the approximate distinct count over `child`.
 *
 * Evaluation is delegated to [[ApproxCountDistinctPartitionFunction]], which is created by
 * `newInstance()` and accumulates the non-null values of `child` into a HyperLogLog sketch.
 *
 * @param child expression whose distinct values are being counted
 */
case class ApproxCountDistinctPartition(child: Expression)
  extends AggregateExpression with trees.UnaryNode[Expression] {

  override def nullable = false

  // Reported type simply mirrors the child's declared type.
  override def dataType = child.dataType

  override def references = child.references

  override def toString = s"APPROXIMATE COUNT(DISTINCT ${child})"

  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this)
}
182+
183+
/**
 * Final stage of the approximate distinct count: combines the partial results produced for
 * `child` into a single count.
 *
 * Evaluation is delegated to [[ApproxCountDistinctMergeFunction]], which is created by
 * `newInstance()`.
 *
 * @param child expression yielding the partial results to be merged
 */
case class ApproxCountDistinctMerge(child: Expression)
  extends AggregateExpression with trees.UnaryNode[Expression] {

  override def nullable = false

  override def dataType = IntegerType

  override def references = child.references

  override def toString = s"APPROXIMATE COUNT(DISTINCT ${child})"

  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this)
}
191+
192+
/** Shared settings for the approximate distinct-count aggregates. */
object ApproxCountDistinct {

  /**
   * Value handed to the `HyperLogLog` constructor by both the partition and the merge
   * aggregate functions, so that every sketch is built with the same setting.
   */
  val RelativeSD: Double = 0.05
}
195+
196+
/**
 * Approximate `COUNT(DISTINCT child)`.
 *
 * When the aggregate is split (see `asPartial`), an [[ApproxCountDistinctPartition]] is computed
 * first and its aliased output is then fed to an [[ApproxCountDistinctMerge]]. When it is
 * evaluated directly via `newInstance()`, it uses the exact [[CountDistinctFunction]] instead.
 *
 * @param child expression whose distinct values are being counted
 */
case class ApproxCountDistinct(child: Expression)
  extends PartialAggregate with trees.UnaryNode[Expression] {

  override def nullable = false

  override def dataType = IntegerType

  override def references = child.references

  override def toString = s"APPROXIMATE COUNT(DISTINCT ${child})"

  override def asPartial: SplitEvaluation = {
    // Partial stage, aliased so the final stage can refer to its output attribute.
    val partitionStage = ApproxCountDistinctPartition(child)
    val partialCount = Alias(partitionStage, "PartialApproxCountDistinct")()
    // Final stage consumes the attribute produced by the aliased partial stage.
    val mergeStage = ApproxCountDistinctMerge(partialCount.toAttribute)
    SplitEvaluation(mergeStage, List(partialCount))
  }

  // Non-split evaluation uses the exact distinct count over the single child expression.
  override def newInstance() = new CountDistinctFunction(List(child), this)
}
211+
172212
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
173213
override def references = child.references
174214
override def nullable = false
@@ -268,6 +308,34 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
268308
override def eval(input: Row): Any = count
269309
}
270310

311+
/**
 * Aggregate function that accumulates the non-null values of `expr` into a HyperLogLog sketch
 * and returns the sketch itself as its result.
 *
 * @param expr expression evaluated against every input row
 * @param base aggregate expression this function was created for
 */
case class ApproxCountDistinctPartitionFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)

  /** Evaluates `expr` on `input` and offers the result to the sketch; nulls are ignored. */
  override def update(input: Row): Unit = {
    val value = expr.eval(input)
    if (value != null) {
      hyperLogLog.offer(value)
    }
  }

  /** Returns the accumulated sketch object, not a count. */
  override def eval(input: Row): Any = hyperLogLog
}
324+
325+
/**
 * Aggregate function that merges HyperLogLog sketches produced by `expr` and reports the
 * estimated cardinality of the merged sketch.
 *
 * @param expr expression that evaluates to a `HyperLogLog` (or null) for every input row
 * @param base aggregate expression this function was created for
 */
case class ApproxCountDistinctMergeFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)

  /** Evaluates `expr` on `input` and folds the resulting sketch into ours; nulls are ignored. */
  override def update(input: Row): Unit = {
    val partialSketch = expr.eval(input).asInstanceOf[HyperLogLog]
    if (partialSketch != null) {
      hyperLogLog.addAll(partialSketch)
    }
  }

  /**
   * Returns the estimated distinct count as an `Int`.
   *
   * `HyperLogLog.cardinality()` yields a `Long`, whereas the aggregate expressions backed by
   * this function declare `IntegerType`. The estimate is therefore converted to `Int`,
   * saturating at `Int.MaxValue` rather than wrapping around, so that the runtime value
   * matches the declared data type.
   */
  override def eval(input: Row): Any =
    math.min(hyperLogLog.cardinality(), Int.MaxValue.toLong).toInt
}
338+
271339
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
272340
def this() = this(null, null) // Required for serialization.
273341

0 commit comments

Comments
 (0)