|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions |
19 | 19 |
|
| 20 | +import com.clearspring.analytics.stream.cardinality.HyperLogLog |
| 21 | + |
20 | 22 | import org.apache.spark.sql.catalyst.types._ |
21 | 23 | import org.apache.spark.sql.catalyst.trees |
22 | 24 | import org.apache.spark.sql.catalyst.errors.TreeNodeException |
@@ -169,6 +171,44 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi |
169 | 171 | override def newInstance() = new CountDistinctFunction(expressions, this) |
170 | 172 | } |
171 | 173 |
|
/**
 * Partial (per-partition) half of the approximate distinct count.
 *
 * Instances are evaluated by [[ApproxCountDistinctPartitionFunction]], which offers every
 * non-null value of `child` to a HyperLogLog sketch and yields that sketch as its result.
 *
 * NOTE(review): `dataType` reports the child's type even though the paired function evaluates
 * to a HyperLogLog object -- confirm that nothing downstream relies on the declared type.
 *
 * @param child expression whose values are counted
 */
case class ApproxCountDistinctPartition(child: Expression)
  extends AggregateExpression with trees.UnaryNode[Expression] {

  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"

  override def dataType = child.dataType
  override def nullable: Boolean = false
  override def references = child.references

  override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this)
}
| 182 | + |
/**
 * Final (merge) half of the approximate distinct count.
 *
 * Instances are evaluated by [[ApproxCountDistinctMergeFunction]], which merges the partial
 * HyperLogLog sketches produced for `child` and evaluates to the cardinality estimate of the
 * merged sketch.
 *
 * @param child expression that evaluates to a partial HyperLogLog sketch
 */
case class ApproxCountDistinctMerge(child: Expression)
  extends AggregateExpression with trees.UnaryNode[Expression] {

  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"

  override def dataType = IntegerType
  override def nullable: Boolean = false
  override def references = child.references

  override def newInstance() = new ApproxCountDistinctMergeFunction(child, this)
}
| 191 | + |
/** Constants shared by the approximate distinct-count aggregate functions. */
object ApproxCountDistinct {
  /**
   * Value handed to the HyperLogLog constructor by both the partition and merge functions.
   * Presumably the target relative standard deviation of the estimate -- confirm against the
   * HyperLogLog documentation.
   */
  val RelativeSD: Double = 0.05
}
| 195 | + |
/**
 * Approximate `COUNT(DISTINCT child)` aggregate.
 *
 * When split for partial aggregation, each partition evaluates an
 * [[ApproxCountDistinctPartition]] and the results are combined by an
 * [[ApproxCountDistinctMerge]]. When evaluated directly (no split), `newInstance` delegates to
 * `CountDistinctFunction` over the single child expression.
 *
 * @param child expression whose distinct values are counted
 */
case class ApproxCountDistinct(child: Expression)
  extends PartialAggregate with trees.UnaryNode[Expression] {

  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"

  override def dataType = IntegerType
  override def nullable: Boolean = false
  override def references = child.references

  /** Splits the aggregate into a per-partition sketch and a final merge over that sketch. */
  override def asPartial: SplitEvaluation = {
    val partitionSketch =
      Alias(ApproxCountDistinctPartition(child), "PartialApproxCountDistinct")()
    // The merge step reads the partial result through the alias's attribute.
    val mergedEstimate = ApproxCountDistinctMerge(partitionSketch.toAttribute)

    SplitEvaluation(mergedEstimate, List(partitionSketch))
  }

  override def newInstance() = new CountDistinctFunction(List(child), this)
}
| 211 | + |
172 | 212 | case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { |
173 | 213 | override def references = child.references |
174 | 214 | override def nullable = false |
@@ -268,6 +308,34 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag |
268 | 308 | override def eval(input: Row): Any = count |
269 | 309 | } |
270 | 310 |
|
/**
 * Aggregate function for the per-partition step of the approximate distinct count.
 *
 * Every non-null value of `expr` is offered to a private HyperLogLog sketch; `eval` returns the
 * sketch object itself (not a number) so that it can be merged later.
 *
 * @param expr expression evaluated against each input row
 * @param base the aggregate expression this function was created for
 */
case class ApproxCountDistinctPartitionFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)

  /** Adds the row's value to the sketch; rows for which `expr` evaluates to null are ignored. */
  override def update(input: Row): Unit = {
    val value = expr.eval(input)
    if (value != null) {
      hyperLogLog.offer(value)
    }
  }

  /** Returns the accumulated sketch; `input` is not consulted. */
  override def eval(input: Row): Any = hyperLogLog
}
| 324 | + |
/**
 * Aggregate function for the merge step of the approximate distinct count.
 *
 * Each input row is expected to carry a partial HyperLogLog sketch (the result of
 * [[ApproxCountDistinctPartitionFunction]]); all sketches are folded into one private sketch
 * and `eval` reports the cardinality estimate of the merged sketch.
 *
 * @param expr expression that evaluates to a partial HyperLogLog sketch (or null)
 * @param base the aggregate expression this function was created for
 */
case class ApproxCountDistinctMergeFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  private val hyperLogLog = new HyperLogLog(ApproxCountDistinct.RelativeSD)

  /** Merges the row's partial sketch into the running sketch; null partials are ignored. */
  override def update(input: Row): Unit = {
    val partialSketch = expr.eval(input).asInstanceOf[HyperLogLog]
    if (partialSketch != null) {
      hyperLogLog.addAll(partialSketch)
    }
  }

  /**
   * Returns the estimated distinct count as an `Int`; `input` is not consulted.
   *
   * `HyperLogLog.cardinality()` yields a `Long`, but [[ApproxCountDistinctMerge]] declares
   * `IntegerType`. Returning the raw `Long` would hand consumers a boxed `java.lang.Long`
   * where the declared type promises an `Int`, so the estimate is converted here, saturating
   * at `Int.MaxValue` rather than wrapping around on overflow.
   */
  override def eval(input: Row): Any = {
    val estimate: Long = hyperLogLog.cardinality()
    math.min(estimate, Int.MaxValue.toLong).toInt
  }
}
| 338 | + |
271 | 339 | case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { |
272 | 340 | def this() = this(null, null) // Required for serialization. |
273 | 341 |
|
|
0 commit comments