Skip to content

Commit d6395d8

Browse files
ueshin authored and rxin committed
[SPARK-1914] [SQL] Simplify CountFunction not to traverse to evaluate all child expressions.
`CountFunction` should count up only if the child's evaluated value is not null.

Before this change, it traversed the expression tree and evaluated every child expression, so it counted up whenever any one of those children was not null, even if the child expression as a whole evaluated to null.

Author: Takuya UESHIN <[email protected]>

Closes #861 from ueshin/issues/SPARK-1914 and squashes the following commits:

3b37315 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-1914
2afa238 [Takuya UESHIN] Simplify CountFunction not to traverse to evaluate all child expressions.
1 parent b6d22af commit d6395d8

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
298  298      var count: Long = _
299  299
300  300      override def update(input: Row): Unit = {
301       -     val evaluatedExpr = expr.map(_.eval(input))
302       -     if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
     301  +     val evaluatedExpr = expr.eval(input)
     302  +     if (evaluatedExpr != null) {
303  303          count += 1L
304  304        }
305  305      }

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,11 @@ class DslQuerySuite extends QueryTest {
125  125          Seq((1,0), (2, 1))
126  126        )
127  127
     128  +     checkAnswer(
     129  +       testData3.groupBy('a)('a, Count('a + 'b)),
     130  +       Seq((1,0), (2, 1))
     131  +     )
     132  +
128  133        checkAnswer(
129  134          testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
130  135          (2, 1, 2, 2, 1) :: Nil

0 commit comments

Comments
 (0)