From cb19ad614b7b757b71e6995aa0c5d8295d8f2485 Mon Sep 17 00:00:00 2001
From: aokolnychyi
Date: Sun, 10 Sep 2017 21:04:38 +0200
Subject: [PATCH] [SPARK-21896][SQL] Fix Stack Overflow when a window function
 is nested inside an aggregate function

---
 .../sql/catalyst/analysis/Analyzer.scala      | 42 +++++++++++++++++--
 .../execution/SQLWindowFunctionSuite.scala    |  8 ++++
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0d5e866c0683e..c83d3efa8982d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1727,16 +1727,23 @@ class Analyzer(
    * it into the plan tree.
    */
   object ExtractWindowExpressions extends Rule[LogicalPlan] {
-    private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
-      projectList.exists(hasWindowFunction)
+    private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
+      exprs.exists(hasWindowFunction)
 
-    private def hasWindowFunction(expr: NamedExpression): Boolean = {
+    private def hasWindowFunction(expr: Expression): Boolean = {
       expr.find {
         case window: WindowExpression => true
         case _ => false
       }.isDefined
     }
 
+    private def containsAggregateFunctionWithWindowExpression(exprs: Seq[Expression]): Boolean = {
+      exprs.exists(expr => expr.find {
+        case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => true
+        case _ => false
+      }.isDefined)
+    }
+
     /**
      * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
      * other regular expressions that do not contain any window expression. For example, for
@@ -1920,7 +1927,34 @@ class Analyzer(
 
       case p: LogicalPlan if !p.childrenResolved => p
 
-      // Aggregate without Having clause.
+      // Extract window expressions from aggregate functions. There might be an aggregate whose
+      // aggregate function contains a window expression as a child, which we need to extract.
+      // e.g., df.groupBy().agg(max(rank().over(window)))
+      case a @ Aggregate(groupingExprs, aggregateExprs, child)
+        if containsAggregateFunctionWithWindowExpression(aggregateExprs) &&
+           a.expressions.forall(_.resolved) =>
+
+        val windowExprAliases = new ArrayBuffer[NamedExpression]()
+        val newAggregateExprs = aggregateExprs.map { expr =>
+          expr.transform {
+            case aggExpr @ AggregateExpression(func, _, _, _) if hasWindowFunction(func.children) =>
+              val newFuncChildren = func.children.map { funcExpr =>
+                funcExpr.transform {
+                  case we: WindowExpression =>
+                    // Replace window expressions with aliases to them
+                    val windowExprAlias = Alias(we, s"_we${windowExprAliases.length}")()
+                    windowExprAliases += windowExprAlias
+                    windowExprAlias.toAttribute
+                }
+              }
+              val newFunc = func.withNewChildren(newFuncChildren).asInstanceOf[AggregateFunction]
+              aggExpr.copy(aggregateFunction = newFunc)
+          }.asInstanceOf[NamedExpression]
+        }
+        val window = addWindow(windowExprAliases, child)
+        // TODO do we also need a projection here?
+        Aggregate(groupingExprs, newAggregateExprs, window)
+
       case a @ Aggregate(groupingExprs, aggregateExprs, child)
         if hasWindowFunction(aggregateExprs) &&
            a.expressions.forall(_.resolved) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index 1c6fc3530cbe1..c89d1516e4903 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.TestUtils.assertSpilled
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{max, rank}
 import org.apache.spark.sql.test.SharedSQLContext
 
 case class WindowData(month: Int, area: String, product: Int)
@@ -486,4 +488,10 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
 
     spark.catalog.dropTempView("nums")
   }
+
+  test("[SPARK-21896] nested window functions in aggregations") {
+    val df = Seq((1, 2), (1, 3), (2, 4), (5, 5)).toDF("a", "b")
+    val window = Window.orderBy('a)
+    checkAnswer(df.groupBy().agg(max(rank().over(window))), Seq(Row(4)))
+  }
 }
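
For context, below is a minimal standalone sketch of the query shape this patch targets, assuming a Spark 2.x build with the patch applied; the object name and the local SparkSession setup are illustrative and not part of the patch itself.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, rank}

// Hypothetical driver that reproduces the query from the new test case.
object NestedWindowInAggregateRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-21896").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2), (1, 3), (2, 4), (5, 5)).toDF("a", "b")
    val window = Window.orderBy($"a")

    // Before this patch, a window function nested inside an aggregate function failed
    // analysis with a StackOverflowError (SPARK-21896). With the new analyzer case, the
    // window expression is aliased and computed by a Window node first, and the aggregate
    // then runs over its output, so this prints a single row with the value 4.
    df.groupBy().agg(max(rank().over(window))).show()

    spark.stop()
  }
}

Pulling the WindowExpression out into an alias before the Aggregate mirrors how ExtractWindowExpressions already rewrites projects and aggregates, which lets the existing addWindow helper build the Window node below the rewritten Aggregate.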