
Commit de28e4d

[SPARK-11973][SQL] Improve optimizer code readability.
This is a followup for #9959. I added more documentation and rewrote some monadic code into simpler ifs.

Author: Reynold Xin <[email protected]>

Closes #9995 from rxin/SPARK-11973.
1 parent ad76562 commit de28e4d
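The heart of the readability change is swapping a reduceOption/map/getOrElse chain for a plain if. Below is a minimal, self-contained sketch of the two equivalent forms, using toy Expr/Plan case classes rather than Catalyst's real Expression and LogicalPlan hierarchy; only the shape of the rewrite matters here.

// Toy stand-ins for Catalyst classes, invented for illustration.
sealed trait Expr
case class Pred(sql: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr

sealed trait Plan
case class Scan(name: String) extends Plan
case class Filter(condition: Expr, child: Plan) extends Plan

object RewriteSketch {
  def main(args: Array[String]): Unit = {
    val stayUp: Seq[Expr] = Seq(Pred("c > 5"))
    val newChild: Plan = Scan("child with pushed-down predicate")

    // Before the patch: monadic chaining -- correct, but harder to scan.
    val before = stayUp.reduceOption(And).map(Filter(_, newChild)).getOrElse(newChild)

    // After the patch: a plain if expression.
    val after = if (stayUp.isEmpty) newChild else Filter(stayUp.reduce(And), newChild)

    assert(before == after)  // both build Filter(Pred("c > 5"), Scan(...))
  }
}

Both expressions wrap newChild in a Filter only when there are leftover conjuncts; the if form states that intent directly.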

2 files changed (+26, -26)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 25 additions & 25 deletions
@@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
-      RemoveDispensable,
+      RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions) ::
@@ -660,49 +660,49 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
     case filter @ Filter(condition, g: Generate) =>
       // Predicates that reference attributes produced by the `Generate` operator cannot
       // be pushed below the operator.
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
-        conjunct => conjunct.references subsetOf g.child.outputSet
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        cond.references subsetOf g.child.outputSet
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
+        val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
           g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
       } else {
         filter
       }
   }
 }
 
 /**
- * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference
- * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath,
- * and the rest should remain above.
+ * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
+ * non-aggregate attributes (typically literals or grouping expressions).
  */
 object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition,
-        aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
-
-      def hasAggregate(expression: Expression): Boolean = expression match {
-        case agg: AggregateExpression => true
-        case other => expression.children.exists(hasAggregate)
-      }
-      // Create a map of Alias for expressions that does not have AggregateExpression
-      val aliasMap = AttributeMap(aggregateExpressions.collect {
-        case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
+    case filter @ Filter(condition, aggregate: Aggregate) =>
+      // Find all the aliased expressions in the aggregate list that don't include any actual
+      // AggregateExpression, and create a map from the alias to the expression
+      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
+        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+          (a.toAttribute, a.child)
       })
 
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
-        val replaced = replaceAlias(conjunct, aliasMap)
-        replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
+      // For each filter, expand the alias and check if the filter can be evaluated using
+      // attributes produced by the aggregate operator's child operator.
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        val replaced = replaceAlias(cond, aliasMap)
+        replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic
       }
+
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val replaced = replaceAlias(pushDownPredicate, aliasMap)
-        val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
+        // If there is no more filter to stay up, just eliminate the filter.
+        // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
+        if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
       } else {
         filter
       }
@@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
  * evaluated using only the attributes of the left or right side of a join. Other
  * [[Filter]] conditions are moved into the `condition` of the [[Join]].
  *
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * And also pushes down the join filter, where the `condition` can be evaluated using only the
  * attributes of the left or right side of sub query when applicable.
  *
  * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
@@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] {
 /**
  * Removes nodes that are not necessary.
  */
-object RemoveDispensable extends Rule[LogicalPlan] {
+object RemoveDispensableExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case UnaryPositive(child) => child
     case PromotePrecision(child) => child
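For intuition on what PushPredicateThroughAggregate decides (the decision itself is unchanged by this patch), here is a hypothetical, self-contained sketch of the pushDown/stayUp split. The Conjunct class, its references field, and childOutput are simplifications invented for illustration, not Catalyst's real API:

// Toy model: a conjunct can be pushed below the Aggregate only if it is
// deterministic and references nothing but the aggregate's child output.
case class Conjunct(sql: String, references: Set[String], deterministic: Boolean = true)

object PushdownSplitSketch {
  def main(args: Array[String]): Unit = {
    // Roughly: SELECT a, count(b) AS c FROM t GROUP BY a HAVING a > 1 AND c > 5
    val conjuncts = Seq(
      Conjunct("a > 1", Set("a")),  // references only the child's column a
      Conjunct("c > 5", Set("c"))   // references the aggregate result c
    )
    val childOutput = Set("a", "b")

    val (pushDown, stayUp) = conjuncts.partition { cond =>
      cond.references.subsetOf(childOutput) && cond.deterministic
    }

    println(s"pushed below the Aggregate: ${pushDown.map(_.sql)}")  // List(a > 1)
    println(s"kept above the Aggregate:   ${stayUp.map(_.sql)}")    // List(c > 5)
  }
}

A nondeterministic conjunct (for example one built from Rand, as in the test below) fails the deterministic check and stays above the Aggregate, which is what FilterPushdownSuite verifies.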

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("aggregate: don't push down filters which is nondeterministic") {
+  test("aggregate: don't push down filters that are nondeterministic") {
     val originalQuery = testRelation
       .select('a, 'b)
       .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
