@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */
@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
    */
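For context, a minimal sketch of how the pair-RDD API documented above is driven (assumes a live SparkContext `sc`; the data and the relativeSD value are arbitrary):

```scala
import org.apache.spark.SparkContext._ // brings pair-RDD functions into scope

val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 1), ("b", 1)))

// Lower relativeSD means a more accurate estimate but a larger HyperLogLog per key.
val approxCounts = pairs.countApproxDistinctByKey(0.05)
approxCounts.collect() // roughly Array(("a", 2), ("b", 1))
```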
@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val AND = Keyword("AND")
   protected val AS = Keyword("AS")
   protected val ASC = Keyword("ASC")
+  protected val APPROXIMATE = Keyword("APPROXIMATE")
   protected val AVG = Keyword("AVG")
   protected val BY = Keyword("BY")
   protected val CAST = Keyword("CAST")
@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
     COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
+    APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
+      case exp => ApproxCountDistinct(exp)
+    } |
+    APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
+      case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
+    } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
     MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
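The two new parser alternatives correspond to two SQL surface forms: a bare APPROXIMATE prefix that uses the default relativeSD of 0.05, and an APPROXIMATE(<float>) form that passes an explicit relative standard deviation. A sketch of both, assuming the usual `sql(...)` entry point used in the test suite below:

```scala
// Default accuracy (relativeSD = 0.05):
sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2")

// Caller-supplied accuracy, parsed by the APPROXIMATE ~> "(" ~> floatLit alternative:
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2")
```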
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
   override def eval(input: Row): Any = currentMax
 }
 
-
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def references = expressions.flatMap(_.references).toSet
   override def nullable = false
   override def dataType = IntegerType
-  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
+  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
   override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
+case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = child.dataType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
+  extends AggregateExpression with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+}
+
+case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
+  extends PartialAggregate with trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = false
+  override def dataType = IntegerType
+  override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialCount =
+      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
+
+    SplitEvaluation(
+      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
+      partialCount :: Nil)
+  }
+
+  override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+}
+
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   override def eval(input: Row): Any = count
 }
 
+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: Row): Any = hyperLogLog
+}
+
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: Row): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: Row): Any = hyperLogLog.cardinality()
+}
+
 case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
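ApproxCountDistinct splits into the two physical functions above: the partition phase offers raw values into a per-partition HyperLogLog and emits the sketch itself, and the merge phase folds the sketches together with addAll and reads the final estimate with cardinality. A standalone sketch of that dataflow against the same stream-lib classes (the data here is made up):

```scala
import com.clearspring.analytics.stream.cardinality.HyperLogLog

// Partition phase: each "partition" offers its values into its own sketch,
// as ApproxCountDistinctPartitionFunction.update does.
val partition1 = new HyperLogLog(0.05)
val partition2 = new HyperLogLog(0.05)
Seq("a", "b", "c").foreach(partition1.offer)
Seq("b", "c", "d").foreach(partition2.offer)

// Merge phase: fold the partial sketches into one, then read the estimate,
// mirroring ApproxCountDistinctMergeFunction.update/eval.
val merged = new HyperLogLog(0.05)
merged.addAll(partition1)
merged.addAll(partition2)
println(merged.cardinality()) // ~4, the distinct values {a, b, c, d}
```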
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
 
@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
+    kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
+      new HyperLogLogSerializer)
     kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
   }
 }
 
+private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
+  def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
+    val bytes = hyperLogLog.getBytes()
+    output.writeInt(bytes.length)
+    output.writeBytes(bytes)
+  }
+
+  def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
+    val length = input.readInt()
+    val bytes = input.readBytes(length)
+    HyperLogLog.Builder.build(bytes)
+  }
+}
+
 /**
  * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
  * them as `Array[(k,v)]`.
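The serializer above just length-prefixes the sketch's byte representation; stripped of the Kryo plumbing, the round trip relies only on the two stream-lib calls already used in the diff:

```scala
import com.clearspring.analytics.stream.cardinality.HyperLogLog

val hll = new HyperLogLog(0.05)
(1 to 1000).foreach(i => hll.offer(i))

val bytes = hll.getBytes()                      // what write() streams out after the length
val restored = HyperLogLog.Builder.build(bytes) // what read() rebuilds on the other side
assert(restored.cardinality() == hll.cardinality())
```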
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (21 changes: 19 additions & 2 deletions)
@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest {
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
testData2.count()
)
testData2.count())
}

test("count distinct") {
checkAnswer(
sql("SELECT COUNT(DISTINCT b) FROM testData2"),
2)
}

test("approximate count distinct") {
checkAnswer(
sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
3)
}

test("approximate count distinct with user provided standard deviation") {
checkAnswer(
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
3)
}

// No support for primitive nulls yet.
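These tests can assert exact answers because testData2 has only a handful of distinct values, where the HyperLogLog estimate collapses to the true count. The relativeSD knob only becomes visible at larger cardinalities; a rough illustration (numbers are indicative, not from the test suite):

```scala
import com.clearspring.analytics.stream.cardinality.HyperLogLog

// With relativeSD = 0.05 the estimate's standard deviation is about 5% of
// the true distinct count, so 100000 distinct values typically land within
// a few percent of the truth.
val hll = new HyperLogLog(0.05)
(1 to 100000).foreach(i => hll.offer(i))
println(hll.cardinality()) // e.g. somewhere near 100000
```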