Skip to content

Commit 646a96d

Browse files
author
aokolnychyi
committed
[SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions
1 parent 0d89943 commit 646a96d

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1744,11 +1744,14 @@ class Analyzer(
17441744
* it into the plan tree.
17451745
*/
17461746
object ExtractWindowExpressions extends Rule[LogicalPlan] {
1747-
private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
1748-
projectList.exists(hasWindowFunction)
1747+
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
1748+
exprs.exists(hasWindowFunction)
17491749

1750-
private def hasWindowFunction(expr: NamedExpression): Boolean = {
1750+
private def hasWindowFunction(expr: Expression): Boolean = {
17511751
expr.find {
1752+
case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
1753+
failAnalysis("It is not allowed to use a window function inside an aggregate function. " +
1754+
"Please use the inner window function in a sub-query.")
17521755
case window: WindowExpression => true
17531756
case _ => false
17541757
}.isDefined

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@ package org.apache.spark.sql
1919

2020
import scala.util.Random
2121

22-
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
23-
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
22+
import org.scalatest.Matchers.the
23+
2424
import org.apache.spark.sql.execution.WholeStageCodegenExec
2525
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
2626
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -687,4 +687,29 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
687687
}
688688
}
689689
}
690+
691+
test("SPARK-21896: Window functions inside aggregate functions") {
692+
def checkWindowError(df: => DataFrame): Unit = {
693+
val thrownException = the [AnalysisException] thrownBy {
694+
df.queryExecution.analyzed
695+
}
696+
assert(thrownException.message.contains("not allowed to use a window function"))
697+
}
698+
699+
checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
700+
checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
701+
checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
702+
checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('b)))))
703+
704+
checkWindowError(
705+
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
706+
checkWindowError(
707+
sql("SELECT MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
708+
checkWindowError(
709+
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
710+
checkAnswer(
711+
sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
712+
Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)
713+
}
714+
690715
}

0 commit comments

Comments
 (0)