Skip to content

Commit cf6c7a0

Browse files
committed
Implementation.
1 parent aa2b0ae commit cf6c7a0

File tree

14 files changed

+1496
-34
lines changed

14 files changed

+1496
-34
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import scala.collection.mutable.ArrayBuffer
21+
2022
import org.apache.spark.util.collection.OpenHashSet
2123
import org.apache.spark.sql.AnalysisException
2224
import org.apache.spark.sql.catalyst.expressions._
@@ -61,6 +63,7 @@ class Analyzer(
6163
ResolveGenerate ::
6264
ImplicitGenerate ::
6365
ResolveFunctions ::
66+
ResolveWindowFunction ::
6467
GlobalAggregates ::
6568
UnresolvedHavingClauseAttributes ::
6669
TrimGroupingAliases ::
@@ -529,6 +532,165 @@ class Analyzer(
529532
makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
530533
}
531534
}
535+
536+
object ResolveWindowFunction extends Rule[LogicalPlan] {
537+
def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
538+
projectList.exists(hasWindowFunction)
539+
540+
def hasWindowFunction(expr: NamedExpression): Boolean = {
541+
expr.find {
542+
case window: WindowExpression => true
543+
case _ => false
544+
}.isDefined
545+
}
546+
547+
/**
548+
* From a Seq of [[NamedExpression]]s, extract window expressions and
549+
* other regular expressions.
550+
*/
551+
def extract(
552+
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
553+
val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
554+
// Extract expressions which in windowExpressions but not in regularExpressions.
555+
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
556+
def extractExpr(expr: Expression): Expression = expr match {
557+
case ne: NamedExpression =>
558+
// If a named expression is not in regularExpressions, add iut
559+
val missingExpr =
560+
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
561+
if (missingExpr.nonEmpty) {
562+
extractedExprBuffer += ne
563+
}
564+
ne.toAttribute
565+
case e: Expression if e.foldable =>
566+
e // No need to create an attribute reference if it will be evaluated as a Literal.
567+
case e: Expression =>
568+
val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
569+
extractedExprBuffer += withName
570+
withName.toAttribute
571+
}
572+
573+
val newWindowExpressions = windowExpressions.map {
574+
_.transform {
575+
case wf : WindowFunction =>
576+
// Extracts children expressions of a WindowFunction.
577+
val newChildren = wf.children.map(extractExpr(_))
578+
wf.withNewChildren(newChildren)
579+
case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
580+
// Extracts expressions from the partition spec and order spec.
581+
val newPartitionSpec = partitionSpec.map(extractExpr(_))
582+
val newOrderSpec = orderSpec.map { so =>
583+
val newChild = extractExpr(so.child)
584+
so.copy(child = newChild)
585+
}
586+
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
587+
case agg: AggregateExpression =>
588+
// We also need to take care aggregate expressions.
589+
val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
590+
extractedExprBuffer += withName
591+
withName.toAttribute
592+
}.asInstanceOf[NamedExpression]
593+
}
594+
(newWindowExpressions, regularExpressions ++ extractedExprBuffer)
595+
}
596+
597+
/**
598+
* Add operators for Window Functions. Every Window operator handle a single Window Spec.
599+
*/
600+
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
601+
// First, we group window expressions based on their Window Spec.
602+
val groupedWindowExpression = windowExpressions.groupBy { expr =>
603+
val windowExpression = expr.find {
604+
case window: WindowExpression => true
605+
case other => false
606+
}.map(_.asInstanceOf[WindowExpression].windowSpec)
607+
windowExpression.getOrElse(
608+
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
609+
}.toSeq
610+
611+
// For every Window Spec, add a Window Operator.
612+
var currentChild = child
613+
var i = 0
614+
while (i < groupedWindowExpression.size) {
615+
val (windowSpec, windowExpressions) = groupedWindowExpression(i)
616+
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
617+
618+
i += 1
619+
}
620+
621+
currentChild
622+
}
623+
624+
/**
625+
* We have to use transformDown at here to make sure the rule of
626+
* "Aggregate with Having clause" will be triggered.
627+
*/
628+
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
629+
// Fill WindowSpecDefinitions. This one work with unresolved children.
630+
case WithWindowDefinition(windowDefinitions, child) =>
631+
child.transform {
632+
case plan => plan.transformExpressions {
633+
case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
634+
val errorMessage =
635+
s"Window specification $windowName is not defined in the WINDOW clause."
636+
val windowSpecDefinition =
637+
windowDefinitions
638+
.get(windowName)
639+
.getOrElse(failAnalysis(errorMessage))
640+
WindowExpression(c, windowSpecDefinition)
641+
}
642+
}
643+
644+
// Aggregate with Having clause
645+
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
646+
if child.resolved &&
647+
hasWindowFunction(aggregateExprs) &&
648+
!a.expressions.exists(!_.resolved) =>
649+
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
650+
// Create an Aggregate operator to evaluate aggregation functions.
651+
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
652+
// Add a Filter operator for conditions in the Having clause.
653+
val withFilter = Filter(condition, withAggregate)
654+
val withWindow = addWindow(windowExpressions, withFilter)
655+
656+
// Finally, generate output columns according to the original projectList.
657+
val finalProjectList = aggregateExprs.map (_.toAttribute)
658+
Project(finalProjectList, withWindow)
659+
660+
case p: LogicalPlan if !p.childrenResolved => p
661+
662+
// Aggregate without Having clause
663+
case a @ Aggregate(groupingExprs, aggregateExprs, child)
664+
if hasWindowFunction(aggregateExprs) &&
665+
!a.expressions.exists(!_.resolved) =>
666+
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
667+
668+
// Create an Aggregate operator to evaluate aggregation functions.
669+
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
670+
// Add Window operators.
671+
val withWindow = addWindow(windowExpressions, withAggregate)
672+
673+
// Finally, generate output columns according to the original projectList.
674+
val finalProjectList = aggregateExprs.map (_.toAttribute)
675+
Project(finalProjectList, withWindow)
676+
677+
// We only extract Window Expressions after all expressions of the Project
678+
// have been resolved.
679+
case p @ Project(projectList, child)
680+
if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
681+
val (windowExpressions, regularExpressions) = extract(projectList)
682+
683+
// We add a project to get all needed expressions of window expressions in the
684+
// original projectList.
685+
val withProject = Project(regularExpressions, child)
686+
// Add Window operators.
687+
val withWindow = addWindow(windowExpressions, withProject)
688+
689+
// Finally, generate output columns according to the original projectList.
690+
val finalProjectList = projectList.map (_.toAttribute)
691+
Project(finalProjectList, withWindow)
692+
}
693+
}
532694
}
533695

534696
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ trait CheckAnalysis {
7070
failAnalysis(
7171
s"invalid expression ${b.prettyString} " +
7272
s"between ${b.left.simpleString} and ${b.right.simpleString}")
73+
74+
case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty =>
75+
// The window spec is not valid.
76+
val reason = windowSpec.validate.get
77+
failAnalysis(s"Window specification $windowSpec is not valid because $reason")
7378
}
7479

7580
operator match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,97 @@ class JoinedRow5 extends Row {
548548
}
549549
}
550550
}
551+
552+
/**
553+
* JIT HACK: Replace with macros
554+
*/
555+
class JoinedRow6 extends Row {
556+
private[this] var row1: Row = _
557+
private[this] var row2: Row = _
558+
559+
def this(left: Row, right: Row) = {
560+
this()
561+
row1 = left
562+
row2 = right
563+
}
564+
565+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
566+
def apply(r1: Row, r2: Row): Row = {
567+
row1 = r1
568+
row2 = r2
569+
this
570+
}
571+
572+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
573+
def withLeft(newLeft: Row): Row = {
574+
row1 = newLeft
575+
this
576+
}
577+
578+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
579+
def withRight(newRight: Row): Row = {
580+
row2 = newRight
581+
this
582+
}
583+
584+
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
585+
586+
override def length: Int = row1.length + row2.length
587+
588+
override def apply(i: Int): Any =
589+
if (i < row1.length) row1(i) else row2(i - row1.length)
590+
591+
override def isNullAt(i: Int): Boolean =
592+
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
593+
594+
override def getInt(i: Int): Int =
595+
if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length)
596+
597+
override def getLong(i: Int): Long =
598+
if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length)
599+
600+
override def getDouble(i: Int): Double =
601+
if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length)
602+
603+
override def getBoolean(i: Int): Boolean =
604+
if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length)
605+
606+
override def getShort(i: Int): Short =
607+
if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length)
608+
609+
override def getByte(i: Int): Byte =
610+
if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length)
611+
612+
override def getFloat(i: Int): Float =
613+
if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length)
614+
615+
override def getString(i: Int): String =
616+
if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length)
617+
618+
override def getAs[T](i: Int): T =
619+
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
620+
621+
override def copy(): Row = {
622+
val totalSize = row1.length + row2.length
623+
val copiedValues = new Array[Any](totalSize)
624+
var i = 0
625+
while(i < totalSize) {
626+
copiedValues(i) = apply(i)
627+
i += 1
628+
}
629+
new GenericRow(copiedValues)
630+
}
631+
632+
override def toString: String = {
633+
// Make sure toString never throws NullPointerException.
634+
if ((row1 eq null) && (row2 eq null)) {
635+
"[ empty row ]"
636+
} else if (row1 eq null) {
637+
row2.mkString("[", ",", "]")
638+
} else if (row2 eq null) {
639+
row1.mkString("[", ",", "]")
640+
} else {
641+
mkString("[", ",", "]")
642+
}
643+
}
644+
}

0 commit comments

Comments
 (0)