Skip to content

Commit 5784c8d

Browse files
yhuaimarmbrus
authored andcommitted
[SPARK-1442] [SQL] [FOLLOW-UP] Address minor comments in Window Function PR (#5604).
Address marmbrus and scwf's comments in #5604. Author: Yin Huai <[email protected]> Closes #5945 from yhuai/windowFollowup and squashes the following commits: 0ef879d [Yin Huai] Add collectFirst to TreeNode. 2373968 [Yin Huai] wip 4a16df9 [Yin Huai] Address minor comments for [SPARK-1442].
1 parent 1712a7c commit 5784c8d

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -638,11 +638,10 @@ class Analyzer(
638638
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
639639
// First, we group window expressions based on their Window Spec.
640640
val groupedWindowExpression = windowExpressions.groupBy { expr =>
641-
val windowExpression = expr.find {
642-
case window: WindowExpression => true
643-
case other => false
644-
}.map(_.asInstanceOf[WindowExpression].windowSpec)
645-
windowExpression.getOrElse(
641+
val windowSpec = expr.collectFirst {
642+
case window: WindowExpression => window.windowSpec
643+
}
644+
windowSpec.getOrElse(
646645
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
647646
}.toSeq
648647

@@ -685,7 +684,7 @@ class Analyzer(
685684
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
686685
if child.resolved &&
687686
hasWindowFunction(aggregateExprs) &&
688-
!a.expressions.exists(!_.resolved) =>
687+
a.expressions.forall(_.resolved) =>
689688
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
690689
// Create an Aggregate operator to evaluate aggregation functions.
691690
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
@@ -702,7 +701,7 @@ class Analyzer(
702701
// Aggregate without Having clause.
703702
case a @ Aggregate(groupingExprs, aggregateExprs, child)
704703
if hasWindowFunction(aggregateExprs) &&
705-
!a.expressions.exists(!_.resolved) =>
704+
a.expressions.forall(_.resolved) =>
706705
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
707706
// Create an Aggregate operator to evaluate aggregation functions.
708707
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
130130
ret
131131
}
132132

133+
/**
134+
* Finds and returns the first [[TreeNode]] of the tree for which the given partial function
135+
* is defined (pre-order), and applies the partial function to it.
136+
*/
137+
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
138+
val lifted = pf.lift
139+
lifted(this).orElse {
140+
children.foldLeft(None: Option[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
141+
}
142+
}
143+
133144
/**
134145
* Returns a copy of this node where `f` has been applied to all the nodes children.
135146
*/
@@ -160,7 +171,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
160171
val remainingNewChildren = newChildren.toBuffer
161172
val remainingOldChildren = children.toBuffer
162173
val newArgs = productIterator.map {
163-
// This rule is used to handle children is a input argument.
174+
// Handle Seq[TreeNode] in TreeNode parameters.
164175
case s: Seq[_] => s.map {
165176
case arg: TreeNode[_] if children contains arg =>
166177
val newChild = remainingNewChildren.remove(0)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,54 @@ class TreeNodeSuite extends FunSuite {
172172
expected = None
173173
assert(expected === actual)
174174
}
175+
176+
test("collectFirst") {
177+
val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4))))
178+
179+
// Collect the top node.
180+
{
181+
val actual = expression.collectFirst {
182+
case add: Add => add
183+
}
184+
val expected =
185+
Some(Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))))
186+
assert(expected === actual)
187+
}
188+
189+
// Collect the first children.
190+
{
191+
val actual = expression.collectFirst {
192+
case l @ Literal(1, IntegerType) => l
193+
}
194+
val expected = Some(Literal(1))
195+
assert(expected === actual)
196+
}
197+
198+
// Collect an internal node (Subtract).
199+
{
200+
val actual = expression.collectFirst {
201+
case sub: Subtract => sub
202+
}
203+
val expected = Some(Subtract(Literal(3), Literal(4)))
204+
assert(expected === actual)
205+
}
206+
207+
// Collect a leaf node.
208+
{
209+
val actual = expression.collectFirst {
210+
case l @ Literal(3, IntegerType) => l
211+
}
212+
val expected = Some(Literal(3))
213+
assert(expected === actual)
214+
}
215+
216+
// Collect nothing.
217+
{
218+
val actual = expression.collectFirst {
219+
case l @ Literal(100, IntegerType) => l
220+
}
221+
val expected = None
222+
assert(expected === actual)
223+
}
224+
}
175225
}

0 commit comments

Comments
 (0)