Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1727,16 +1727,23 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)
/** Returns true if at least one of the given expressions contains a window expression. */
private def hasWindowFunction(exprs: Seq[Expression]): Boolean = {
  exprs.exists(candidate => hasWindowFunction(candidate))
}

private def hasWindowFunction(expr: NamedExpression): Boolean = {
/** Returns true if the expression tree rooted at `expr` contains a [[WindowExpression]]. */
private def hasWindowFunction(expr: Expression): Boolean = {
  // Predicate applied to each node visited by `find`.
  def isWindow(node: Expression): Boolean = node match {
    case _: WindowExpression => true
    case _ => false
  }
  expr.find(isWindow).nonEmpty
}

/**
 * Returns true if any of the given expressions contains an [[AggregateExpression]] whose
 * aggregate function has a window expression among (or beneath) its children.
 */
private def containsAggregateFunctionWithWindowExpression(exprs: Seq[Expression]): Boolean = {
  // True only for aggregate expressions whose function arguments contain a window expression.
  def aggregatesOverWindow(node: Expression): Boolean = node match {
    case AggregateExpression(aggFunc, _, _, _) => hasWindowFunction(aggFunc.children)
    case _ => false
  }
  exprs.exists(_.find(aggregatesOverWindow).nonEmpty)
}

/**
* From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
* other regular expressions that do not contain any window expression. For example, for
Expand Down Expand Up @@ -1920,7 +1927,34 @@ class Analyzer(

case p: LogicalPlan if !p.childrenResolved => p

// Aggregate without Having clause.
// Extract window expressions from aggregate functions. There might be an aggregate whose
// aggregate function contains a window expression as a child, which we need to extract.
// e.g., df.groupBy().agg(max(rank().over(window)))
case a @ Aggregate(groupingExprs, aggregateExprs, child)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be sure: what happens if there is a window function on top of the aggregate function? This gets resolved in two passes, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case for this scenario?

if containsAggregateFunctionWithWindowExpression(aggregateExprs) &&
a.expressions.forall(_.resolved) =>

val windowExprAliases = new ArrayBuffer[NamedExpression]()
val newAggregateExprs = aggregateExprs.map { expr =>
expr.transform {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code below assumes that there are no window aggregates on top of a regular aggregate, and it will push the regular aggregate into the underlying window. An example of this:
df.groupBy(a).agg(max(rank().over(window1)), sum(sum(c)).over(window2))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this. I am not sure I fully understood "it will push the regular aggregate into the underlying window". Could you please elaborate?

This is what I tried:

    val df = Seq((1, 2), (1, 3), (2, 4), (5, 5)).toDF("a", "b")
    val window1 = Window.orderBy('a)
    val window2 = Window.orderBy('a.desc)

    df.groupBy('a).agg(max(rank().over(window1)), sum('b), sum(sum('b)).over(window2)).explain(true)
    df.groupBy('a).agg(max(rank().over(window1)), sum('b), sum(sum('b)).over(window2)).show(false)

It produced the following plans:

== Analyzed Logical Plan ==
a: int, max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$())): int, sum(b): bigint, sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$()): bigint
Project [a#5, max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#19, sum(b)#20L, sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L]
+- Project [a#5, max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#19, sum(b)#20L, _w0#40L, sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L, sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L]
   +- Window [sum(_w0#40L) windowspecdefinition(a#5 DESC NULLS LAST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L], [a#5 DESC NULLS LAST]
      +- Aggregate [a#5], [a#5, max(_we0#36) AS max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#19, sum(cast(b#6 as bigint)) AS sum(b)#20L, sum(cast(b#6 as bigint)) AS _w0#40L]
         +- Project [a#5, b#6, _we0#36, _we0#36]
            +- Window [rank(a#5) windowspecdefinition(a#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _we0#36], [a#5 ASC NULLS FIRST]
               +- Project [_1#2 AS a#5, _2#3 AS b#6]
                  +- LocalRelation [_1#2, _2#3]

== Optimized Logical Plan ==
Project [a#5, max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#19, sum(b)#20L, sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L]
+- Window [sum(_w0#40L) windowspecdefinition(a#5 DESC NULLS LAST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())#21L], [a#5 DESC NULLS LAST]
   +- Aggregate [a#5], [a#5, max(_we0#36) AS max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#19, sum(cast(b#6 as bigint)) AS sum(b)#20L, sum(cast(b#6 as bigint)) AS _w0#40L]
      +- Project [a#5, b#6, _we0#36, _we0#36]
         +- Window [rank(a#5) windowspecdefinition(a#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _we0#36], [a#5 ASC NULLS FIRST]
            +- LocalRelation [a#5, b#6]

The result was:

+---+-----------------------------------------------------------------+------+-----------------------------------------------------------------+
|a  |max(RANK() OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))|sum(b)|sum(sum(b)) OVER (ORDER BY a DESC NULLS LAST unspecifiedframe$())|
+---+-----------------------------------------------------------------+------+-----------------------------------------------------------------+
|5  |4                                                                |5     |5                                                                |
|2  |3                                                                |4     |9                                                                |
|1  |1                                                                |5     |14                                                               |
+---+-----------------------------------------------------------------+------+-----------------------------------------------------------------+

So, we have a window expression on top of a regular aggregate in sum(sum(c)).over(window2), right? This expression is handled by the existing part and is not touched by the new case clause.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, do you actually mean something like this?

    val df = Seq((1, 2), (1, 3), (2, 4), (5, 5)).toDF("a", "b")
    val window1 = Window.orderBy('a)
    df.groupBy('a).agg(max(sum(sum('b)).over(window1))).explain(true)
    df.groupBy('a).agg(max(sum(sum('b)).over(window1))).show(false)
== Analyzed Logical Plan ==
a: int, max(sum(sum(b)) OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$())): bigint
Aggregate [a#5], [a#5, max(_we0#22L) AS max(sum(sum(b)) OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#16L]
+- Project [a#5, b#6, _we0#22L, _we0#22L]
   +- Window [sum(sum(cast(b#6 as bigint))) windowspecdefinition(a#5 ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS _we0#22L], [a#5 ASC NULLS FIRST]
      +- Project [_1#2 AS a#5, _2#3 AS b#6]
         +- LocalRelation [_1#2, _2#3]

== Optimized Logical Plan ==
Aggregate [a#5], [a#5, max(_we0#22L) AS max(sum(sum(b)) OVER (ORDER BY a ASC NULLS FIRST unspecifiedframe$()))#16L]
+- Project [a#5, _we0#22L, _we0#22L]
   +- Window [sum(sum(cast(b#6 as bigint))) windowspecdefinition(a#5 ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS _we0#22L], [a#5 ASC NULLS FIRST]
      +- LocalRelation [a#5, b#6]

case aggExpr @ AggregateExpression(func, _, _, _) if hasWindowFunction(func.children) =>
val newFuncChildren = func.children.map { funcExpr =>
funcExpr.transform {
case we: WindowExpression =>
// Replace window expressions with aliases to them
val windowExprAlias = Alias(we, s"_we${windowExprAliases.length}")()
windowExprAliases += windowExprAlias
windowExprAlias.toAttribute
}
}
val newFunc = func.withNewChildren(newFuncChildren).asInstanceOf[AggregateFunction]
aggExpr.copy(aggregateFunction = newFunc)
}.asInstanceOf[NamedExpression]
}
val window = addWindow(windowExprAliases, child)
// TODO do we also need a projection here?
Aggregate(groupingExprs, newAggregateExprs, window)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, you don't need a Project.


case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

import org.apache.spark.TestUtils.assertSpilled
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, rank}
import org.apache.spark.sql.test.SharedSQLContext

case class WindowData(month: Int, area: String, product: Int)
Expand Down Expand Up @@ -486,4 +488,10 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {

spark.catalog.dropTempView("nums")
}

test("[SPARK-21896] nested window functions in aggregations") {
  // A global aggregate (no grouping keys) whose aggregate function wraps a window function:
  // rank() is evaluated over rows ordered by `a`, then max() reduces the ranks to one row.
  val input = Seq((1, 2), (1, 3), (2, 4), (5, 5)).toDF("a", "b")
  val orderedByA = Window.orderBy('a)
  val maxRank = input.groupBy().agg(max(rank().over(orderedByA)))
  checkAnswer(maxRank, Row(4) :: Nil)
}
}