Skip to content

Commit bd8ef3f

Browse files
larvaboyLiuyang Li
authored andcommitted
Add support of user-provided standard deviation to ApproxCountDistinct.
1 parent 9ba8360 commit bd8ef3f

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
319319
COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
320320
COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
321321
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
322-
APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => ApproxCountDistinct(exp) } |
322+
APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
323+
case exp => ApproxCountDistinct(exp)
324+
} |
325+
APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
326+
case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
327+
} |
323328
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
324329
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
325330
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -171,38 +171,38 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
171171
override def newInstance() = new CountDistinctFunction(expressions, this)
172172
}
173173

174-
case class ApproxCountDistinctPartition(child: Expression)
174+
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
175175
extends AggregateExpression with trees.UnaryNode[Expression] {
176176
override def references = child.references
177177
override def nullable = false
178178
override def dataType = child.dataType
179179
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
180-
override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this)
180+
override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
181181
}
182182

183-
case class ApproxCountDistinctMerge(child: Expression)
183+
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
184184
extends AggregateExpression with trees.UnaryNode[Expression] {
185185
override def references = child.references
186186
override def nullable = false
187187
override def dataType = IntegerType
188188
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
189-
override def newInstance() = new ApproxCountDistinctMergeFunction(child, this)
189+
override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
190190
}
191191

192-
object ApproxCountDistinct {
193-
val RelativeSD = 0.05
194-
}
195-
196-
case class ApproxCountDistinct(child: Expression)
192+
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
197193
extends PartialAggregate with trees.UnaryNode[Expression] {
198194
override def references = child.references
199195
override def nullable = false
200196
override def dataType = IntegerType
201197
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
202198

203199
override def asPartial: SplitEvaluation = {
204-
val partialCount = Alias(ApproxCountDistinctPartition(child), "PartialApproxCountDistinct")()
205-
SplitEvaluation(ApproxCountDistinctMerge(partialCount.toAttribute), partialCount :: Nil)
200+
val partialCount =
201+
Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
202+
203+
SplitEvaluation(
204+
ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
205+
partialCount :: Nil)
206206
}
207207

208208
override def newInstance() = new CountDistinctFunction(child :: Nil, this)
@@ -307,11 +307,14 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
307307
override def eval(input: Row): Any = count
308308
}
309309

310-
case class ApproxCountDistinctPartitionFunction(expr: Expression, base: AggregateExpression)
310+
case class ApproxCountDistinctPartitionFunction(
311+
expr: Expression,
312+
base: AggregateExpression,
313+
relativeSD: Double)
311314
extends AggregateFunction {
312-
def this() = this(null, null) // Required for serialization.
315+
def this() = this(null, null, 0) // Required for serialization.
313316

314-
private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)
317+
private val hyperLogLog = new HyperLogLog(relativeSD)
315318

316319
override def update(input: Row): Unit = {
317320
val evaluatedExpr = expr.eval(input)
@@ -323,11 +326,14 @@ case class ApproxCountDistinctPartitionFunction(expr: Expression, base: Aggregat
323326
override def eval(input: Row): Any = hyperLogLog
324327
}
325328

326-
case class ApproxCountDistinctMergeFunction(expr: Expression, base: AggregateExpression)
329+
case class ApproxCountDistinctMergeFunction(
330+
expr: Expression,
331+
base: AggregateExpression,
332+
relativeSD: Double)
327333
extends AggregateFunction {
328-
def this() = this(null, null) // Required for serialization.
334+
def this() = this(null, null, 0) // Required for serialization.
329335

330-
private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)
336+
private val hyperLogLog = new HyperLogLog(relativeSD)
331337

332338
override def update(input: Row): Unit = {
333339
val evaluatedExpr = expr.eval(input)

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ class SQLQuerySuite extends QueryTest {
111111
3)
112112
}
113113

114+
test("approximate count distinct with user provided standard deviation") {
115+
checkAnswer(
116+
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
117+
3)
118+
}
119+
114120
// No support for primitive nulls yet.
115121
ignore("null count") {
116122
checkAnswer(

0 commit comments

Comments
 (0)