@@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer {
5959 ConstantFolding ,
6060 LikeSimplification ,
6161 BooleanSimplification ,
62- RemoveDispensable ,
62+ RemoveDispensableExpressions ,
6363 SimplifyFilters ,
6464 SimplifyCasts ,
6565 SimplifyCaseConversionExpressions ) ::
@@ -660,49 +660,49 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
660660 case filter @ Filter (condition, g : Generate ) =>
661661 // Predicates that reference attributes produced by the `Generate` operator cannot
662662 // be pushed below the operator.
663- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
664- conjunct => conjunct .references subsetOf g.child.outputSet
663+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
664+ cond .references subsetOf g.child.outputSet
665665 }
666666 if (pushDown.nonEmpty) {
667667 val pushDownPredicate = pushDown.reduce(And )
668- val withPushdown = Generate (g.generator, join = g.join, outer = g.outer,
668+ val newGenerate = Generate (g.generator, join = g.join, outer = g.outer,
669669 g.qualifier, g.generatorOutput, Filter (pushDownPredicate, g.child))
670- stayUp.reduceOption( And ).map( Filter (_, withPushdown)).getOrElse(withPushdown )
670+ if ( stayUp.isEmpty) newGenerate else Filter (stayUp.reduce( And ), newGenerate )
671671 } else {
672672 filter
673673 }
674674 }
675675}
676676
677677/**
678- * Push [[Filter ]] operators through [[Aggregate ]] operators. Parts of the predicate that reference
679- * attributes which are subset of group by attribute set of [[Aggregate ]] will be pushed beneath,
680- * and the rest should remain above.
678+ * Push [[Filter ]] operators through [[Aggregate ]] operators, iff the filters reference only
679+ * non-aggregate attributes (typically literals or grouping expressions).
681680 */
682681object PushPredicateThroughAggregate extends Rule [LogicalPlan ] with PredicateHelper {
683682
684683 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
685- case filter @ Filter (condition,
686- aggregate @ Aggregate (groupingExpressions, aggregateExpressions, grandChild)) =>
687-
688- def hasAggregate (expression : Expression ): Boolean = expression match {
689- case agg : AggregateExpression => true
690- case other => expression.children.exists(hasAggregate)
691- }
692- // Create a map of Alias for expressions that does not have AggregateExpression
693- val aliasMap = AttributeMap (aggregateExpressions.collect {
694- case a : Alias if ! hasAggregate(a.child) => (a.toAttribute, a.child)
684+ case filter @ Filter (condition, aggregate : Aggregate ) =>
685+ // Find all the aliased expressions in the aggregate list that don't include any actual
686+ // AggregateExpression, and create a map from the alias to the expression
687+ val aliasMap = AttributeMap (aggregate.aggregateExpressions.collect {
688+ case a : Alias if a.child.find(_.isInstanceOf [AggregateExpression ]).isEmpty =>
689+ (a.toAttribute, a.child)
695690 })
696691
697- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
698- val replaced = replaceAlias(conjunct, aliasMap)
699- replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
692+ // For each filter, expand the alias and check if the filter can be evaluated using
693+ // attributes produced by the aggregate operator's child operator.
694+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
695+ val replaced = replaceAlias(cond, aliasMap)
696+ replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic
700697 }
698+
701699 if (pushDown.nonEmpty) {
702700 val pushDownPredicate = pushDown.reduce(And )
703701 val replaced = replaceAlias(pushDownPredicate, aliasMap)
704- val withPushdown = aggregate.copy(child = Filter (replaced, grandChild))
705- stayUp.reduceOption(And ).map(Filter (_, withPushdown)).getOrElse(withPushdown)
702+ val newAggregate = aggregate.copy(child = Filter (replaced, aggregate.child))
703+ // If there is no more filter to stay up, just eliminate the filter.
704+ // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
705+ if (stayUp.isEmpty) newAggregate else Filter (stayUp.reduce(And ), newAggregate)
706706 } else {
707707 filter
708708 }
@@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
714714 * evaluated using only the attributes of the left or right side of a join. Other
715715 * [[Filter ]] conditions are moved into the `condition` of the [[Join ]].
716716 *
717- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
717+ * And also pushes down the join filter, where the `condition` can be evaluated using only the
718718 * attributes of the left or right side of sub query when applicable.
719719 *
720720 * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
@@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] {
821821/**
822822 * Removes nodes that are not necessary.
823823 */
824- object RemoveDispensable extends Rule [LogicalPlan ] {
824+ object RemoveDispensableExpressions extends Rule [LogicalPlan ] {
825825 def apply (plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
826826 case UnaryPositive (child) => child
827827 case PromotePrecision (child) => child
0 commit comments