Skip to content

Commit 772e7c1

Browse files
yhuaimarmbrus
authored andcommitted
[SPARK-9592] [SQL] Fix Last function implemented based on AggregateExpression1.
https://issues.apache.org/jira/browse/SPARK-9592 #8113 has the fundamental fix. But, if we want to minimize the number of changed lines, we can go with this one. Then, in 1.6, we merge #8113. Author: Yin Huai <[email protected]> Closes #8172 from yhuai/lastFix and squashes the following commits: b28c42a [Yin Huai] Regression test. af87086 [Yin Huai] Fix last.
1 parent b265e28 commit 772e7c1

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
650650
var result: Any = null
651651

652652
override def update(input: InternalRow): Unit = {
653+
// We ignore null values.
653654
if (result == null) {
654655
result = expr.eval(input)
655656
}
@@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
679680
var result: Any = null
680681

681682
override def update(input: InternalRow): Unit = {
682-
result = input
683+
val value = expr.eval(input)
684+
// We ignore null values.
685+
if (value != null) {
686+
result = value
687+
}
683688
}
684689

685690
override def eval(input: InternalRow): Any = {
686-
if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null
691+
result
687692
}
688693
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
480480
Row(0, null, 1, 1, null, 0) :: Nil)
481481
}
482482

483+
test("test Last implemented based on AggregateExpression1") {
484+
// TODO: Remove this test once we remove AggregateExpression1.
485+
import org.apache.spark.sql.functions._
486+
val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
487+
withSQLConf(
488+
SQLConf.SHUFFLE_PARTITIONS.key -> "1",
489+
SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
490+
491+
checkAnswer(
492+
df.groupBy("i").agg(last("j")),
493+
df
494+
)
495+
}
496+
}
497+
483498
test("error handling") {
484499
withSQLConf("spark.sql.useAggregate2" -> "false") {
485500
val errorMessage = intercept[AnalysisException] {

0 commit comments

Comments
 (0)