Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 20b65b7

Browse files
committed
Handle expressions containing multiple window expressions.
1 parent 9568b21 commit 20b65b7

File tree

1 file changed

+67
-24
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis

1 file changed

+67
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -644,14 +644,16 @@ class Analyzer(
644644
}
645645

646646
/**
647-
* From a Seq of [[NamedExpression]]s, extract window expressions and
648-
* other regular expressions.
647+
* From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
648+
* other regular expressions that do not contain any window expression.
649649
*/
650-
def extract(
650+
def extractRegularExpressions(
651651
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
652-
// First, we simple partition the input expressions to two part, one having
653-
// WindowExpressions and another one without WindowExpressions.
654-
val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
652+
// First, we partition the input expressions to two part. For the first part,
653+
// every expression in it contain at least one WindowExpression.
654+
// Expressions in the second part do not have any WindowExpression.
655+
val (expressionsWithWindowFunctions, regularExpressions) =
656+
expressions.partition(hasWindowFunction)
655657

656658
// Then, we need to extract those regular expressions used in the WindowExpression.
657659
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
@@ -678,8 +680,8 @@ class Analyzer(
678680
withName.toAttribute
679681
}
680682

681-
// Now, we extract expressions from windowExpressions by using extractExpr.
682-
val newWindowExpressions = windowExpressions.map {
683+
// Now, we extract regular expressions from expressionsWithWindowFunctions by using extractExpr.
684+
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
683685
_.transform {
684686
// Extracts children expressions of a WindowFunction (input parameters of
685687
// a WindowFunction).
@@ -705,36 +707,77 @@ class Analyzer(
705707
}.asInstanceOf[NamedExpression]
706708
}
707709

708-
(newWindowExpressions, regularExpressions ++ extractedExprBuffer)
710+
(newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
709711
}
710712

711713
/**
712714
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
713715
*/
714-
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
715-
// First, we group window expressions based on their Window Spec.
716-
val groupedWindowExpression = windowExpressions.groupBy { expr =>
717-
val windowSpec = expr.collectFirst {
716+
def addWindow(
717+
expressionsWithWindowFunctions: Seq[NamedExpression],
718+
child: LogicalPlan): LogicalPlan = {
719+
// First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
720+
// and put those extracted WindowExpressions to extractedWindowExprBuffer.
721+
// This step is needed because it is possible that an expression contains multiple
722+
// WindowExpressions with different Window Specs.
723+
// After extracting WindowExpressions, we need to construct a project list to generate
724+
// expressionsWithWindowFunctions based on extractedWindowExprBuffer.
725+
// For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
726+
// "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
727+
// "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
728+
// Then, the projectList will be [_we0/_we1].
729+
val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
730+
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
731+
// We need to use transformDown because we want to trigger
732+
// "case alias @ Alias(window: WindowExpression, _)" first.
733+
_.transformDown {
734+
case alias @ Alias(window: WindowExpression, _) =>
735+
// If a WindowExpression has an assigned alias, just use it.
736+
extractedWindowExprBuffer += alias
737+
alias.toAttribute
738+
case window: WindowExpression =>
739+
// If there is no alias assigned to the WindowExpressions. We create an
740+
// internal column.
741+
val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
742+
extractedWindowExprBuffer += withName
743+
withName.toAttribute
744+
}.asInstanceOf[NamedExpression]
745+
}
746+
747+
// Second, we group extractedWindowExprBuffer based on their Window Spec.
748+
val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
749+
val distinctWindowSpec = expr.collect {
718750
case window: WindowExpression => window.windowSpec
751+
}.distinct
752+
753+
// We do a final check and see if we only have a single Window Spec defined in an
754+
// expressions.
755+
if (distinctWindowSpec.length == 0 ) {
756+
failAnalysis(s"$expr does not have any WindowExpression.")
757+
} else if (distinctWindowSpec.length > 1) {
758+
failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." +
759+
s"Please file a bug report with this error message, stack trace, and the query.")
760+
} else {
761+
distinctWindowSpec.head
719762
}
720-
windowSpec.getOrElse(
721-
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
722763
}.toSeq
723764

724-
// For every Window Spec, we add a Window operator and set currentChild as the child of it.
765+
// Third, for every Window Spec, we add a Window operator and set currentChild as the
766+
// child of it.
725767
var currentChild = child
726768
var i = 0
727-
while (i < groupedWindowExpression.size) {
728-
val (windowSpec, windowExpressions) = groupedWindowExpression(i)
769+
while (i < groupedWindowExpressions.size) {
770+
val (windowSpec, windowExpressions) = groupedWindowExpressions(i)
729771
// Set currentChild to the newly created Window operator.
730772
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
731773

732-
// Move to next WindowExpression.
774+
// Move to next Window Spec.
733775
i += 1
734776
}
735777

736-
// We return the top operator.
737-
currentChild
778+
// Finally, we create a Project to output currentChild's output
779+
// newExpressionsWithWindowFunctions.
780+
Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
738781
}
739782

740783
// We have to use transformDown at here to make sure the rule of
@@ -746,7 +789,7 @@ class Analyzer(
746789
if child.resolved &&
747790
hasWindowFunction(aggregateExprs) &&
748791
a.expressions.forall(_.resolved) =>
749-
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
792+
val (windowExpressions, aggregateExpressions) = extractRegularExpressions(aggregateExprs)
750793
// Create an Aggregate operator to evaluate aggregation functions.
751794
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
752795
// Add a Filter operator for conditions in the Having clause.
@@ -763,7 +806,7 @@ class Analyzer(
763806
case a @ Aggregate(groupingExprs, aggregateExprs, child)
764807
if hasWindowFunction(aggregateExprs) &&
765808
a.expressions.forall(_.resolved) =>
766-
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
809+
val (windowExpressions, aggregateExpressions) = extractRegularExpressions(aggregateExprs)
767810
// Create an Aggregate operator to evaluate aggregation functions.
768811
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
769812
// Add Window operators.
@@ -777,7 +820,7 @@ class Analyzer(
777820
// have been resolved.
778821
case p @ Project(projectList, child)
779822
if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
780-
val (windowExpressions, regularExpressions) = extract(projectList)
823+
val (windowExpressions, regularExpressions) = extractRegularExpressions(projectList)
781824
// We add a project to get all needed expressions for window expressions from the child
782825
// of the original Project operator.
783826
val withProject = Project(regularExpressions, child)

0 commit comments

Comments
 (0)