diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 63602eaa8ccd8..deabf1ebf4674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,11 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -43,6 +39,7 @@ object DefaultOptimizer extends Optimizer { // Operator push down SetOperationPushDown, SamplePushDown, + TransitiveClosure, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -674,6 +671,85 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp } } +/** + * Tries to generate a transitive filter for an EqualTo join condition and a filter that on one side + * has the left/right side of the join condition and on the other a foldable expression. 
Also works
+ * for filters which are part of the join condition itself.
+ *
+ * For example, given the join condition `x.a = y.d` and the filter `x.a >= 1`, the predicate
+ * `y.d >= 1` is derived and added next to the original one.
+ */
+object TransitiveClosure extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Partitions `conditions` by which join inputs cover their references.
+   *
+   * @return a pair whose first element holds the conditions covered by exactly one of
+   *         `left` / `right`, and whose second element holds all remaining conditions
+   *         (covered by neither side, or by both, e.g. when a condition has no references).
+   */
+  def split(
+      conditions: Seq[Expression],
+      left: LogicalPlan,
+      right: LogicalPlan): (Seq[Expression], Seq[Expression]) = {
+    conditions.partition { cond =>
+      cond.references.subsetOf(left.outputSet) ^ cond.references.subsetOf(right.outputSet)
+    }
+  }
+
+  /**
+   * Rewrites `cond`, a predicate over `from`, into the same predicate over `to`.
+   *
+   * Supported shapes are a comparison between `from` and a foldable expression, an `In` whose
+   * value is `from` and whose list is entirely foldable, and `Not` / `Or` combinations of
+   * those. For `Or`, both branches must be rewritable.
+   *
+   * @param cond the predicate to translate
+   * @param from the expression `cond` is expected to constrain
+   * @param to the expression that should take the place of `from`
+   * @return the rewritten predicate, or `None` when `cond` has any other shape.
+   */
+  def toTransPredOption(
+      cond: Expression,
+      from: Expression,
+      to: Expression): Option[Expression] = cond match {
+    case Not(child) =>
+      toTransPredOption(child, from, to).map(Not)
+    case Or(left, right) =>
+      for {
+        l <- toTransPredOption(left, from, to)
+        r <- toTransPredOption(right, from, to)
+      } yield Or(l, r)
+    case e @ In(value, exprs) if value == from && exprs.forall(_.foldable) =>
+      Some(e.copy(value = to))
+    // Each guard names the operand that must be `from` and requires the other one to be
+    // foldable. Only the `from` operand is replaced and it stays on its original side, so an
+    // order-sensitive comparison such as `1 < from` becomes `1 < to` rather than `to < 1`.
+    case e @ BinaryComparison(l, r) if l == from && !l.foldable && r.foldable =>
+      Some(e.makeCopy(Array(to, r)))
+    case e @ BinaryComparison(l, r) if r == from && !r.foldable && l.foldable =>
+      Some(e.makeCopy(Array(l, to)))
+    case _ => None
+  }
+
+  /**
+   * Derives predicates from `filterCondition` through the equality `joinCondition`, in the
+   * directions allowed for `joinType`: both directions for `Inner`, left-to-right for
+   * `LeftOuter` and `LeftSemi`, right-to-left for `RightOuter`, and none for `FullOuter`.
+   *
+   * NOTE(review): when the caller is the Filter-over-Join case of `apply`, the derived
+   * predicate ends up in the Filter above the join. For `LeftOuter` / `RightOuter` it
+   * constrains the other side of the join, and for `LeftSemi` it references the right side.
+   * Confirm this cannot discard null-extended rows or reference an attribute that the join
+   * does not output.
+   *
+   * @return the derived predicates; empty when nothing can be derived.
+   */
+  def toTransPredByJoinType(
+      filterCondition: Expression,
+      joinCondition: EqualTo,
+      joinType: JoinType): Seq[Expression] = {
+    val fromTo = joinType match {
+      case Inner =>
+        (joinCondition.left, joinCondition.right) ::
+          (joinCondition.right, joinCondition.left) :: Nil
+      case LeftOuter | LeftSemi =>
+        (joinCondition.left, joinCondition.right) :: Nil
+      case RightOuter =>
+        (joinCondition.right, joinCondition.left) :: Nil
+      case FullOuter => Nil
+    }
+    fromTo.flatMap { case (from, to) => toTransPredOption(filterCondition, from, to) }
+  }
+
+  /**
+   * Combines every `EqualTo` in `joinConditions` with every predicate in `filterConditions`
+   * and collects the predicates that can be derived from each pair.
+   *
+   * @return the derived predicates that are not already present in `filterConditions`.
+   */
+  def findTransitivePredicates(
+      filterConditions: Seq[Expression],
+      joinConditions: Seq[Expression],
+      joinType: JoinType): Set[Expression] = {
+    val candidates = for {
+      joinCond <- joinConditions.collect { case e: EqualTo => e }
+      filterCond <- filterConditions
+      newPred <- toTransPredByJoinType(filterCond, joinCond, joinType)
+    } yield newPred
+
+    // Removing the predicates that already exist keeps the rule from re-adding them on the
+    // next iteration.
+    candidates.toSet -- filterConditions
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(filterCondition, j @ Join(left, right, joinType, joinCondition)) =>
+      val (filterConditions, commonConditions) =
+        split(splitConjunctivePredicates(filterCondition), left, right)
+
+      // Filter conjuncts that are not one-sided are treated like join conditions.
+      val joinConditions =
+        commonConditions ++ joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
+
+      val transitiveCond = findTransitivePredicates(filterConditions, joinConditions, joinType)
+
+      if (transitiveCond.nonEmpty) Filter((transitiveCond + filterCondition).reduce(And), j)
+      else f
+
+    // We can process only Inner join in this case because otherwise the original
+    // predicate does not get pushed down and keeps triggering generation of the same
+    // transitive predicate on each iteration.
+    case j @ Join(left, right, Inner, Some(joinCondition)) =>
+      val (leftRightCondition, commonCondition) =
+        split(splitConjunctivePredicates(joinCondition), left, right)
+
+      val transitiveCond = findTransitivePredicates(leftRightCondition, commonCondition, Inner)
+
+      if (transitiveCond.nonEmpty) {
+        j.copy(condition = Option((transitiveCond + joinCondition).reduce(And)))
+      } else {
+        j
+      }
+  }
+}
+
 /**
  * Pushes down [[Filter]] operators where the `condition` can be
  * evaluated using only the attributes of the left or right side of a join.
Other
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransitiveClosureSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransitiveClosureSuite.scala
new file mode 100644
index 0000000000000..0baf7b24ed007
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransitiveClosureSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Runs [[TransitiveClosure]] inside a small push-down batch and compares each optimized plan
+ * with a hand-written expected plan.
+ */
+class TransitiveClosureSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubQueries) ::
+      Batch("Filter Pushdown", FixedPoint(4),
+        SamplePushDown,
+        CombineFilters,
+        PushPredicateThroughProject,
+        BooleanSimplification,
+        TransitiveClosure,
+        PushPredicateThroughJoin,
+        PushPredicateThroughGenerate,
+        ColumnPruning,
+        ProjectCollapsing) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  val testRelation1 = LocalRelation('d.int)
+
+  // Aliased views of the two relations; every test joins `x` with `y`.
+  private val x = testRelation.subquery('x)
+  private val y = testRelation1.subquery('y)
+
+  /**
+   * Analyzes and optimizes `originalQuery`, then compares the result with the analyzed
+   * `correctAnswer`.
+   */
+  private def checkOptimized(originalQuery: LogicalPlan, correctAnswer: LogicalPlan): Unit = {
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("Binary comparison in where clause together with join condition") {
+    val originalQuery = x.join(y)
+      .where("x.a".attr === "y.d".attr && "x.a".attr >= 1)
+
+    val left = testRelation.where('a >= 1)
+    val right = testRelation1.where('d >= 1)
+    checkOptimized(originalQuery, left.join(right, condition = Some("a".attr === "d".attr)))
+  }
+
+  test("Not and mixed join condition with where clause filter") {
+    val originalQuery = x.join(y, condition = Option("x.a".attr === "y.d".attr))
+      .where("x.a".attr !== 1)
+
+    val left = testRelation.where('a !== 1)
+    val right = testRelation1.where('d !== 1)
+    checkOptimized(originalQuery, left.join(right, condition = Some("a".attr === "d".attr)))
+  }
+
+  test("In and all conditions are join conditions") {
+    val originalQuery =
+      x.join(y, condition = Option("x.a".attr === "y.d".attr && ("x.a".attr in (1, 2))))
+
+    val left = testRelation.where('a in (1, 2))
+    val right = testRelation1.where('d in (1, 2))
+    checkOptimized(originalQuery, left.join(right, condition = Some("a".attr === "d".attr)))
+  }
+
+  test("Or condition") {
+    val originalQuery = x.join(y, condition = Option("x.a".attr === "y.d".attr))
+      .where("x.a".attr === 1 || "x.a".attr === 2)
+
+    val left = testRelation.where('a === 1 || 'a === 2)
+    val right = testRelation1.where('d === 1 || 'd === 2)
+    checkOptimized(originalQuery, left.join(right, condition = Some("a".attr === "d".attr)))
+  }
+
+  test("Mixed join and where clause conditions") {
+    val originalQuery =
+      x.join(y, condition = Option("x.a".attr === "y.d".attr && "y.d".attr === 2))
+        .where("x.a".attr === 1)
+
+    val left = testRelation.where('a === 1 && 'a === 2)
+    val right = testRelation1.where('d === 1 && 'd === 2)
+    checkOptimized(originalQuery, left.join(right, condition = Some("a".attr === "d".attr)))
+  }
+
+  test("Left join works one way") {
+    val originalQuery = x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr))
+      .where("x.a".attr === 1 && "y.d".attr === 2)
+
+    val left = testRelation.where('a === 1)
+    val right = testRelation1
+    val correctAnswer = left.join(right, LeftOuter, Some("a".attr === "d".attr))
+      .where("d".attr === 1 && "d".attr === 2)
+    checkOptimized(originalQuery, correctAnswer)
+  }
+
+  test("Right join works one way") {
+    val originalQuery = x.join(y, RightOuter, Option("x.a".attr === "y.d".attr))
+      .where("y.d".attr === 2 && "x.a".attr === 1)
+
+    val left = testRelation
+    val right = testRelation1.where('d === 2)
+    val correctAnswer = left.join(right, RightOuter, Some("a".attr === "d".attr))
+      .where("a".attr === 2 && "a".attr === 1)
+    checkOptimized(originalQuery, correctAnswer)
+  }
+
+  test("Full join does not do anything") {
+    val originalQuery =
+      x.join(y, FullOuter, Option("x.a".attr === "y.d".attr && "y.d".attr === 2))
+        .where("x.a".attr === 1)
+
+    val left = testRelation
+    val right = testRelation1
+    val correctAnswer =
+      left.join(right, FullOuter, Some("a".attr === "d".attr && "d".attr === 2))
+        .where("a".attr === 1)
+    checkOptimized(originalQuery, correctAnswer)
+  }
+}