From 2afa2386b0f59d1d50f11de2d444f191335867ef Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 23 May 2014 19:15:44 +0900 Subject: [PATCH] Simplify CountFunction not to traverse to evaluate all child expressions. --- .../apache/spark/sql/catalyst/expressions/aggregates.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5dbaaa3b0ce35..a1e61f80bc692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag var count: Int = _ override def update(input: Row): Unit = { - val evaluatedExpr = expr.map(_.eval(input)) - if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { count += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index f43e98d614094..cb5387c1255d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -120,6 +120,11 @@ class DslQuerySuite extends QueryTest { Seq((1,0), (2, 1)) ) + checkAnswer( + testData3.groupBy('a)('a, Count('a + 'b)), + Seq((1,0), (2, 1)) + ) + checkAnswer( testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), (2, 1, 2, 2, 1) :: Nil