Skip to content

Commit b4787f7

Browse files
committed
improve FoldablePropagation
1 parent a6fc300 commit b4787f7

File tree

2 files changed

+57
-23
lines changed

2 files changed

+57
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] {
506506

507507

508508
/**
509-
* Propagate foldable expressions:
510509
* Replace attributes with aliases of the original foldable expressions if possible.
511-
* Other optimizations will take advantage of the propagated foldable expressions.
512-
*
510+
* Other optimizations will take advantage of the propagated foldable expressions. For example,
511+
* this rule can optimize
513512
* {{{
514513
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
515-
* ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
516514
* }}}
515+
* to
516+
* {{{
517+
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
518+
* }}}
519+
* and other rules can further optimize it and remove the ORDER BY operator.
517520
*/
518521
object FoldablePropagation extends Rule[LogicalPlan] {
519522
def apply(plan: LogicalPlan): LogicalPlan = {
520-
val foldableMap = AttributeMap(plan.flatMap {
523+
var foldableMap = AttributeMap(plan.flatMap {
521524
case Project(projectList, _) => projectList.collect {
522525
case a: Alias if a.child.foldable => (a.toAttribute, a)
523526
}
@@ -530,38 +533,50 @@ object FoldablePropagation extends Rule[LogicalPlan] {
530533
if (foldableMap.isEmpty) {
531534
plan
532535
} else {
533-
var stop = false
534536
CleanupAliases(plan.transformUp {
535-
// A leaf node should not stop the folding process (note that we are traversing up the
536-
// tree, starting at the leaf nodes); so we are allowing it.
537-
case l: LeafNode =>
538-
l
539-
540537
// We can only propagate foldables for a subset of unary nodes.
541-
case u: UnaryNode if !stop && canPropagateFoldables(u) =>
538+
case u: UnaryNode if canPropagateFoldables(u) =>
542539
u.transformExpressions(replaceFoldable)
543540

544-
// Allow inner joins. We do not allow outer join, although its output attributes are
545-
// derived from its children, they are actually different attributes: the output of outer
546-
// join is not always picked from its children, but can also be null.
541+
// Join derives the output attributes from its children while they are actually not the
542+
// same attributes. For example, the output of outer join is not always picked from its
543+
// children, but can also be null. We should exclude these mis-derived attributes when
544+
// propagating the foldable expressions.
547545
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
548546
// of outer join.
549-
case j @ Join(_, _, Inner, _) if !stop =>
550-
j.transformExpressions(replaceFoldable)
547+
case j @ Join(left, right, joinType, _) =>
548+
val newJoin = j.transformExpressions(replaceFoldable)
549+
val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
550+
case _: InnerLike | LeftExistence(_) => Nil
551+
case LeftOuter => right.output
552+
case RightOuter => left.output
553+
case FullOuter => left.output ++ right.output
554+
})
555+
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
556+
case (attr, _) => missDerivedAttrsSet.contains(attr)
557+
}.toSeq)
558+
newJoin
551559

552-
// We can fold the projections an expand holds. However expand changes the output columns
553-
// and often reuses the underlying attributes; so we cannot assume that a column is still
554-
// foldable after the expand has been applied.
560+
// Similar to Join, Expand also mis-derives output attributes from child attributes; we
561+
// should exclude them when propagating.
555562
// TODO(hvanhovell): Expand should use new attributes as the output attributes.
556-
case expand: Expand if !stop =>
563+
case expand: Expand =>
557564
val newExpand = expand.copy(projections = expand.projections.map { projection =>
558565
projection.map(_.transform(replaceFoldable))
559566
})
560-
stop = true
567+
val missDerivedAttrsSet = expand.child.outputSet
568+
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
569+
case (attr, _) => missDerivedAttrsSet.contains(attr)
570+
}.toSeq)
561571
newExpand
562572

573+
// It is not safe to apply foldable propagation to other plans, and they should not
574+
// propagate foldable expressions from children.
563575
case other =>
564-
stop = true
576+
val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
577+
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
578+
case (attr, _) => childrenOutputSet.contains(attr)
579+
}.toSeq)
565580
other
566581
})
567582
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest {
161161
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
162162
comparePlans(optimized, correctAnswer)
163163
}
164+
165+
test("Propagate above outer join") {
166+
val left = LocalRelation('a.int).select('a, Literal(1).as('b))
167+
val right = LocalRelation('c.int).select('c, Literal(1).as("d"))
168+
169+
val join = left.join(
170+
right,
171+
joinType = LeftOuter,
172+
condition = Some('a === 'c && 'b === 'd))
173+
val query = join.select(('b + 3).as('res)).analyze
174+
val optimized = Optimize.execute(query)
175+
176+
val correctAnswer = left.join(
177+
right,
178+
joinType = LeftOuter,
179+
condition = Some('a === 'c && Literal(1) === Literal(1)))
180+
.select((Literal(1) + 3).as('res)).analyze
181+
comparePlans(optimized, correctAnswer)
182+
}
164183
}

0 commit comments

Comments
 (0)