diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f4dba67f13b54..098a5a8ee7154 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -197,6 +197,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { * - Aggregate * - Generate * - Project <- Join + * - Project <- Filter <- Join * - LeftSemiJoin */ object ColumnPruning extends Rule[LogicalPlan] { @@ -246,6 +247,16 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) + // Eliminate unneeded attributes from either side of a Join. + case Project(projectList, Filter(predicates, Join(left, right, joinType, condition))) => + val allReferences: AttributeSet = + AttributeSet( + projectList.flatMap(_.references.iterator)) ++ + predicates.references ++ + condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) + Project(projectList, Filter(predicates, Join( + prunedChild(left, allReferences), prunedChild(right, allReferences), joinType, condition))) + // Eliminate unneeded attributes from right side of a LeftSemiJoin. case Join(left, right, LeftSemi, condition) => // Collect the list of all references required to evaluate the condition.