From 646a96d75c12b2e3c6886bc0cc9743e7ba838c8a Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Thu, 31 May 2018 12:58:29 +0200 Subject: [PATCH 1/3] [SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions --- .../sql/catalyst/analysis/Analyzer.scala | 9 ++++-- .../spark/sql/DataFrameAggregateSuite.scala | 29 +++++++++++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3eaa9ecf5d075..cd657dbd7a3d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1744,11 +1744,14 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate function. " + + "Please use the inner window function in a sub-query.") case window: WindowExpression => true case _ => false }.isDefined diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 96c28961e5aaf..b23a9e71bc2e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -687,4 +687,29 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('b))))) + + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + } From 185aba12a66db07fcc1a88a4197b59786def5182 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Sat, 2 Jun 2018 12:11:04 +0200 Subject: [PATCH 2/3] Add test cases and move the check --- .../spark/sql/catalyst/analysis/Analyzer.scala | 7 ++++--- .../spark/sql/DataFrameAggregateSuite.scala | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cd657dbd7a3d6..f9947d1fa6c78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1749,9 +1749,6 @@ class Analyzer( private def hasWindowFunction(expr: Expression): Boolean = { expr.find { - case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => - failAnalysis("It is not allowed to use a window function inside an aggregate function. " + - "Please use the inner window function in a sub-query.") case window: WindowExpression => true case _ => false }.isDefined @@ -1833,6 +1830,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b23a9e71bc2e0..9e2f93ecfda98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -699,14 +699,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) - checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) checkWindowError( - sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) checkWindowError( - sql("SELECT MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) checkWindowError( - sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) checkAnswer( sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) From d06f282541930ea1b6bbb3176bf1f3e67664c05f Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Sat, 2 Jun 2018 12:24:15 +0200 Subject: [PATCH 3/3] Inline some calls --- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9e2f93ecfda98..f495a949ebc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -706,14 +706,10 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) - checkWindowError( - sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) - checkWindowError( - sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) - checkWindowError( - sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) - checkWindowError( - sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) checkWindowError( sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) checkAnswer(