From ae5f6ee5d4ed6b195865c539b8c93aeeedd53363 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 8 Feb 2019 16:22:48 -0800 Subject: [PATCH 01/12] [SPARK-19712] Pushing down Left Semi and Left Anti joins --- .../sql/catalyst/expressions/subquery.scala | 11 + .../sql/catalyst/optimizer/Optimizer.scala | 187 ++++++++++++++- .../spark/sql/catalyst/plans/joinTypes.scala | 7 + .../LeftSemiAntiJoinPushDownSuite.scala | 219 ++++++++++++++++++ .../execution/WholeStageCodegenSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 4 +- 6 files changed, 426 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index fc1caed84e27..97424597ccd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -267,6 +267,17 @@ object ScalarSubquery { case _ => false }.isDefined } + + def hasScalarSubquery(e: Expression): Boolean = { + e.find { + case s: ScalarSubquery => true + case _ => false + }.isDefined + } + + def hasScalarSubquery(e: Seq[Expression]): Boolean = { + e.find(hasScalarSubquery(_)).isDefined + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 38a051c15476..5bab2525a827 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,6 +95,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, + PushDownLeftSemiAntiJoin, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, @@ -1131,7 +1132,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } - private def canPushThrough(p: UnaryNode): Boolean = p match { + def canPushThrough(p: UnaryNode): Boolean = p match { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). case _: AppendColumns => true @@ -1188,6 +1189,190 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } +object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Similar to the above Filter over Project + // LeftSemi/LeftAnti over Project + case join @ Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if pList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(pList) && + canPushThroughCondition(Seq(gChild), joinCond, rightOp) => + if (joinCond.isEmpty) { + // No join condition, just push down the Join below Project + Project(pList, Join(gChild, rightOp, joinType, joinCond, hint)) + } else { + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). 
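+            // e.g., with that map, a join condition 'c === 'd is rewritten by replaceAlias below to
+            // 'a + 'b === 'd, so it no longer depends on the Project's aliases once the
+            // LeftSemi/LeftAnti join is pushed down to sit directly on top of gChild.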
+ val aliasMap = AttributeMap(pList.collect { + case a: Alias => (a.toAttribute, a.child) + }) + val newJoinCond = if (aliasMap.nonEmpty) { + Option(replaceAlias(joinCond.get, aliasMap)) + } else { + joinCond + } + Project(pList, Join(gChild, rightOp, joinType, newJoinCond, hint)) + } + + // Similar to the above Filter over Aggregate + // LeftSemi/LeftAnti over Aggregate + case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Aggregate + aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond, hint)) + } else { + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) + + // For each join condition, expand the alias and + // check if the condition can be evaluated using + // attributes produced by the aggregate operator's child operator. + + val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) + } + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = + Join(aggregate.child, rightOp, joinType, Option(replaced), hint)) + // If there is no more filter to stay up, just return the Aggregate over Join. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + } else { + // The join condition is not a subset of the Aggregate's GROUP BY columns, + // no push down. + join + } + } + + // Similar to the above Filter over Window + // LeftSemi/LeftAnti over Window + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Window + w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint)) + } else { + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ + rightOp.outputSet + + val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newPlan = + w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate), hint)) + if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + } else { + // The join condition is not a subset of the Window's PARTITION BY clause, + // no push down. 
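+          // e.g., with a window partitioned by 'a', a condition 'a === 'd can be pushed down, but
+          // with 'a === 'd && 'b === 'd the predicate 'b === 'd would have to stay above the Window
+          // as a filter, where the right side's 'd is no longer available, so nothing is pushed.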
+          join
+        }
+      }
+
+    // Similar to the above Filter over Union
+    // LeftSemi/LeftAnti over Union
+    case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if canPushThroughCondition(union.children, joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // Push down the Join below Union
+        val newGrandChildren = union.children.map { grandchild =>
+          Join(grandchild, rightOp, joinType, joinCond, hint)
+        }
+        union.withNewChildren(newGrandChildren)
+      } else {
+        val pushDown = splitConjunctivePredicates(joinCond.get)
+
+        if (pushDown.nonEmpty) {
+          val pushDownCond = pushDown.reduceLeft(And)
+          val output = union.output
+          val newGrandChildren = union.children.map { grandchild =>
+            val newCond = pushDownCond transform {
+              case e if output.exists(_.semanticEquals(e)) =>
+                grandchild.output(output.indexWhere(_.semanticEquals(e)))
+            }
+            assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
+            Join(grandchild, rightOp, joinType, Option(newCond), hint)
+          }
+          union.withNewChildren(newGrandChildren)
+        } else {
+          // Nothing to push down
+          join
+        }
+      }
+
+    // Similar to the above Filter over UnaryNode
+    // LeftSemi/LeftAnti over UnaryNode
+    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if PushDownPredicate.canPushThrough(u) =>
+      pushDownJoin(join, u.child) { joinCond =>
+        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint)))
+      }
+  }
+
+  /**
+   * Check if we can safely push a join through a project or union by making sure that predicate
+   * subqueries in the condition do not contain the same attributes as the plan they are moved
+   * into. This can happen when the plan and predicate subquery have the same source.
+   */
+  private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression],
+      rightOp: LogicalPlan): Boolean = {
+    val attributes = AttributeSet(plans.flatMap (_.output))
+    if (condition.isDefined) {
+      val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
+      matched.isEmpty
+    } else true
+  }
+
+
+  private def pushDownJoin(
+      join: Join,
+      grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+    // Only push down the join when the join condition is deterministic and all the referenced
+    // attributes come from the children of the left and right legs of the join.
+    val (candidates, containingNonDeterministic) = if (join.condition.isDefined) {
+      splitConjunctivePredicates(join.condition.get).partition(_.deterministic)
+    } else {
+      (Nil, Nil)
+    }
+
+    val (pushDown, rest) = candidates.partition { cond =>
+      cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)
+    }
+
+    val stayUp = rest ++ containingNonDeterministic
+
+    if (pushDown.nonEmpty) {
+      val newChild = insertFilter(pushDown.reduceLeft(And))
+      if (stayUp.nonEmpty) {
+        Filter(stayUp.reduceLeft(And), newChild)
+      } else {
+        newChild
+      }
+    } else {
+      join
+    }
+  }
+
+}
+
 /**
 * Pushes down [[Filter]] operators where the `condition` can be
 * evaluated using only the attributes of the left or right side of a join.
Other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index c77849035a97..86cdc261329c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -114,3 +114,10 @@ object LeftExistence { case _ => None } } + +object LeftSemiOrAnti { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala new file mode 100644 index 000000000000..d57d1aeca96b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.unsafe.types.CalendarInterval + +class LeftSemiPushdownSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown", FixedPoint(10), + CombineFilters, + PushDownPredicate, + PushDownLeftSemiAntiJoin, + BooleanSimplification, + PushPredicateThroughJoin, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val testRelation1 = LocalRelation('d.int) + + test("Project: LeftSemiAnti join pushdown") { + val originalQuery = testRelation + .select(star()) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .select('a, 'b, 'c) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = testRelation + .select(Rand('a), 'b, 'c) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Rand('a), 'b, 'c) + .join(y, joinType = LeftSemi, condition = Some('b === 'd)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Aggregate: LeftSemiAnti join pushdown") { + val originalQuery = testRelation + .groupBy('b)('b, sum('c)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .groupBy('b)('b, sum('c)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") { + val originalQuery = testRelation + .groupBy('b)('b, Rand(10).as('c)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .groupBy('b)('b, Rand(10).as('c)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Aggregate: LeftSemiAnti join partial pushdown") { + val originalQuery = testRelation + // .select('b.as('alias1)) + .groupBy('b)('b, sum('c).as('sum)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .groupBy('b)('b, sum('c).as('sum)) + .where('sum === 10) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("LeftSemiAnti join over aggregate - no pushdown") { + val originalQuery = 
testRelation + // .select('b.as('alias1)) + .groupBy('b)('b, sum('c).as('sum)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + // .select('b.as('alias1)) + .groupBy('b)('b, sum('c).as('sum)) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("LeftSemiAnti join over Window") { + val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + + val originalQuery = testRelation + .select('a, 'b, 'c, winExpr.as('window)) + .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + .select('a, 'b, 'c) + .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) + .select('a, 'b, 'c, 'window).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Window: LeftSemiAnti partial pushdown") { + // Attributes from join condition which does not refer to the window partition spec + // are kept up in the plan as a Filter operator above Window. + val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + + val originalQuery = testRelation + .select('a, 'b, 'c, winExpr.as('window)) + .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd && 'b > 5)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + .select('a, 'b, 'c) + .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) + .where('b > 5) + .select('a, 'b, 'c, 'window).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Union: LeftSemiAnti join pushdown") { + val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)), + testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Union: LeftSemiAnti join no pushdown in self join scenario") { + val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .join(testRelation2, joinType = LeftSemi, condition = Some('a === 'x)) + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, originalQuery.analyze) + } + + test("Unary: LeftSemiAnti join pushdown") { + val originalQuery = testRelation + .select(star()) + .repartition(1) + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .select('a, 'b, 'c) + .repartition(1) + .analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 
3c9a0908147a..ceb5fdd5fce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -169,7 +169,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(!plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && + p.isInstanceOf[WholeStageCodegenExec] && p.isInstanceOf[SortMergeJoinExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) .isInstanceOf[SortMergeJoinExec]).isDefined) assert(df.collect() === Array(Row(1), Row(2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 98a8ad5eeb2b..b77048aba779 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -345,10 +345,10 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastHashJoin(nodeId = 0) + // ... -> BroadcastHashJoin(nodeId = 1) val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastHashJoin", Map( + 1L -> (("BroadcastHashJoin", Map( "number of output rows" -> 2L)))) ) } From 488eda8d46b2649d199c6cc7487bc0a5fe9eaa90 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 15 Feb 2019 19:39:59 -0800 Subject: [PATCH 02/12] Code review --- .../sql/catalyst/expressions/subquery.scala | 6 +- .../sql/catalyst/optimizer/Optimizer.scala | 217 ++---------------- .../optimizer/PushDownLeftSemiAntiJoin.scala | 197 ++++++++++++++++ 3 files changed, 218 insertions(+), 202 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 97424597ccd7..e647df983db6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -270,14 +270,10 @@ object ScalarSubquery { def hasScalarSubquery(e: Expression): Boolean = { e.find { - case s: ScalarSubquery => true + case _: ScalarSubquery => true case _ => false }.isDefined } - - def hasScalarSubquery(e: Seq[Expression]): Boolean = { - e.find(hasScalarSubquery(_)).isDefined - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5bab2525a827..ae45f4afaadd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1017,24 +1017,13 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // This also applies to Aggregate. 
case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => - - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). - val aliasMap = AttributeMap(fields.collect { - case a: Alias => (a.toAttribute, a.child) - }) - + val aliasMap = getAliasMap(project) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) if aggregate.aggregateExpressions.forall(_.deterministic) && aggregate.groupingExpressions.nonEmpty => - // Find all the aliased expressions in the aggregate list that don't include any actual - // AggregateExpression, and create a map from the alias to the expression - val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { - case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => - (a.toAttribute, a.child) - }) + val aliasMap = getAliasMap(aggregate) // For each filter, expand the alias and check if the filter can be evaluated using // attributes produced by the aggregate operator's child operator. @@ -1132,6 +1121,24 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } + + def getAliasMap(plan: LogicalPlan): AttributeMap[Expression] = { + val aliasMap = plan match { + case p: Project => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + p.projectList.collect { case a: Alias => (a.toAttribute, a.child) } + case a: Aggregate => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + a.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + } + } + AttributeMap(aliasMap) + } + def canPushThrough(p: UnaryNode): Boolean = p match { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). @@ -1189,190 +1196,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } -object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Similar to the above Filter over Project - // LeftSemi/LeftAnti over Project - case join @ Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if pList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(pList) && - canPushThroughCondition(Seq(gChild), joinCond, rightOp) => - if (joinCond.isEmpty) { - // No join condition, just push down the Join below Project - Project(pList, Join(gChild, rightOp, joinType, joinCond, hint)) - } else { - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). 
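-            // e.g., with that map, a join condition 'c === 'd is rewritten by replaceAlias below to
-            // 'a + 'b === 'd, so it no longer depends on the Project's aliases once the
-            // LeftSemi/LeftAnti join is pushed down to sit directly on top of gChild.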
- val aliasMap = AttributeMap(pList.collect { - case a: Alias => (a.toAttribute, a.child) - }) - val newJoinCond = if (aliasMap.nonEmpty) { - Option(replaceAlias(joinCond.get, aliasMap)) - } else { - joinCond - } - Project(pList, Join(gChild, rightOp, joinType, newJoinCond, hint)) - } - - // Similar to the above Filter over Aggregate - // LeftSemi/LeftAnti over Aggregate - case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if aggregate.aggregateExpressions.forall(_.deterministic) - && aggregate.groupingExpressions.nonEmpty => - if (joinCond.isEmpty) { - // No join condition, just push down Join below Aggregate - aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond, hint)) - } else { - // Find all the aliased expressions in the aggregate list that don't include any actual - // AggregateExpression, and create a map from the alias to the expression - val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { - case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => - (a.toAttribute, a.child) - }) - - // For each join condition, expand the alias and - // check if the condition can be evaluated using - // attributes produced by the aggregate operator's child operator. - - val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => - val replaced = replaceAlias(cond, aliasMap) - cond.references.nonEmpty && - replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) - } - - // Check if the remaining predicates do not contain columns from subquery - val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - - if (pushDown.nonEmpty && rightOpColumns.isEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val replaced = replaceAlias(pushDownPredicate, aliasMap) - val newAggregate = aggregate.copy(child = - Join(aggregate.child, rightOp, joinType, Option(replaced), hint)) - // If there is no more filter to stay up, just return the Aggregate over Join. - // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". - if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) - } else { - // The join condition is not a subset of the Aggregate's GROUP BY columns, - // no push down. - join - } - } - - // Similar to the above Filter over Window - // LeftSemi/LeftAnti over Window - case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - if (joinCond.isEmpty) { - // No join condition, just push down Join below Window - w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint)) - } else { - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ - rightOp.outputSet - - val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - // Check if the remaining predicates do not contain columns from subquery - val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - - if (pushDown.nonEmpty && rightOpColumns.isEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newPlan = - w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate), hint)) - if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) - } else { - // The join condition is not a subset of the Window's PARTITION BY clause, - // no push down. 
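-          // e.g., with a window partitioned by 'a', a condition 'a === 'd can be pushed down, but
-          // with 'a === 'd && 'b === 'd the predicate 'b === 'd would have to stay above the Window
-          // as a filter, where the right side's 'd is no longer available, so nothing is pushed.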
-          join
-        }
-      }
-
-    // Similar to the above Filter over Union
-    // LeftSemi/LeftAnti over Union
-    case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-      if canPushThroughCondition(union.children, joinCond, rightOp) =>
-      if (joinCond.isEmpty) {
-        // Push down the Join below Union
-        val newGrandChildren = union.children.map { grandchild =>
-          Join(grandchild, rightOp, joinType, joinCond, hint)
-        }
-        union.withNewChildren(newGrandChildren)
-      } else {
-        val pushDown = splitConjunctivePredicates(joinCond.get)
-
-        if (pushDown.nonEmpty) {
-          val pushDownCond = pushDown.reduceLeft(And)
-          val output = union.output
-          val newGrandChildren = union.children.map { grandchild =>
-            val newCond = pushDownCond transform {
-              case e if output.exists(_.semanticEquals(e)) =>
-                grandchild.output(output.indexWhere(_.semanticEquals(e)))
-            }
-            assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
-            Join(grandchild, rightOp, joinType, Option(newCond), hint)
-          }
-          union.withNewChildren(newGrandChildren)
-        } else {
-          // Nothing to push down
-          join
-        }
-      }
-
-    // Similar to the above Filter over UnaryNode
-    // LeftSemi/LeftAnti over UnaryNode
-    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-      if PushDownPredicate.canPushThrough(u) =>
-      pushDownJoin(join, u.child) { joinCond =>
-        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint)))
-      }
-  }
-
-  /**
-   * Check if we can safely push a join through a project or union by making sure that predicate
-   * subqueries in the condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when the plan and predicate subquery have the same source.
-   */
-  private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression],
-      rightOp: LogicalPlan): Boolean = {
-    val attributes = AttributeSet(plans.flatMap (_.output))
-    if (condition.isDefined) {
-      val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
-      matched.isEmpty
-    } else true
-  }
-
-
-  private def pushDownJoin(
-      join: Join,
-      grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
-    // Only push down the join when the join condition is deterministic and all the referenced
-    // attributes come from the children of the left and right legs of the join.
-    val (candidates, containingNonDeterministic) = if (join.condition.isDefined) {
-      splitConjunctivePredicates(join.condition.get).partition(_.deterministic)
-    } else {
-      (Nil, Nil)
-    }
-
-    val (pushDown, rest) = candidates.partition { cond =>
-      cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)
-    }
-
-    val stayUp = rest ++ containingNonDeterministic
-
-    if (pushDown.nonEmpty) {
-      val newChild = insertFilter(pushDown.reduceLeft(And))
-      if (stayUp.nonEmpty) {
-        Filter(stayUp.reduceLeft(And), newChild)
-      } else {
-        newChild
-      }
-    } else {
-      join
-    }
-  }
-
-}
-
 /**
 * Pushes down [[Filter]] operators where the `condition` can be
 * evaluated using only the attributes of the left or right side of a join.
Other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
new file mode 100644
index 000000000000..2ba88f38ded8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Pushes left semi and left anti joins below the following operators.
+ * 1) Project
+ * 2) Window
+ * 3) Union
+ * 4) Aggregate
+ * 5) Other permissible unary operators. Please see [[PushDownPredicate.canPushThrough]].
+ */
+
+object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // Similar to the above Filter over Project
+    // LeftSemi/LeftAnti over Project
+    case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if pList.forall(_.deterministic) &&
+        !pList.find(ScalarSubquery.hasScalarSubquery(_)).isDefined &&
+        canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down the Join below Project
+        p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(p)
+        val newJoinCond = if (aliasMap.nonEmpty) {
+          Option(replaceAlias(joinCond.get, aliasMap))
+        } else {
+          joinCond
+        }
+        p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
+      }
+
+    // Similar to the above Filter over Aggregate
+    // LeftSemi/LeftAnti over Aggregate
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if agg.aggregateExpressions.forall(_.deterministic)
+        && agg.groupingExpressions.nonEmpty =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down Join below Aggregate
+        agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(agg)
+
+        // For each join condition, expand the alias and
+        // check if the condition can be evaluated using
+        // attributes produced by the aggregate operator's child operator.
+ val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet) + } + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint)) + // If there is no more filter to stay up, just return the Aggregate over Join. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". + if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg) + } else { + // The join condition is not a subset of the Aggregate's GROUP BY columns, + // no push down. + join + } + } + + // Similar to the above Filter over Window + // LeftSemi/LeftAnti over Window + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Window + w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint)) + } else { + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ + rightOp.outputSet + + val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val predicate = pushDown.reduce(And) + val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint)) + if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + } else { + // The join condition is not a subset of the Window's PARTITION BY clause, + // no push down. 
+ join + } + } + + // Similar to the above Filter over Union + // LeftSemi/LeftAnti over Union + case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if canPushThroughCondition(union.children, joinCond, rightOp) => + if (joinCond.isEmpty) { + // Push down the Join below Union + val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) } + union.withNewChildren(newGrandChildren) + } else { + val pushDown = splitConjunctivePredicates(joinCond.get) + + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = union.output + val newGrandChildren = union.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) + Join(grandchild, rightOp, joinType, Option(newCond), hint) + } + union.withNewChildren(newGrandChildren) + } else { + // Nothing to push down + join + } + } + + // Similar to the above Filter over UnaryNode + // LeftSemi/LeftAnti over UnaryNode + case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if PushDownPredicate.canPushThrough(u) => + pushDownJoin(join, u.child) { joinCond => + u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint))) + } + } + + /** + * Check if we can safely push a join through a project or union by making sure that predicate + * subqueries in the condition do not contain the same attributes as the plan they are moved + * into. This can happen when the plan and predicate subquery have the same source. + */ + private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression], + rightOp: LogicalPlan): Boolean = { + val attributes = AttributeSet(plans.flatMap (_.output)) + if (condition.isDefined) { + val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes) + matched.isEmpty + } else true + } + + + private def pushDownJoin( + join: Join, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + val (pushDown, stayUp) = if (join.condition.isDefined) { + splitConjunctivePredicates(join.condition.get) + .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)} + } else { + (Nil, Nil) + } + + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + join + } + } +} + From ea76e29a520f36dc6917f713b33b9b11d91f30e6 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Feb 2019 13:21:03 -0800 Subject: [PATCH 03/12] Code review --- .../sql/catalyst/optimizer/Optimizer.scala | 24 +++++++++---------- .../optimizer/PushDownLeftSemiAntiJoin.scala | 21 +++++----------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ae45f4afaadd..bf74c8a51d27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1121,20 +1121,18 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } + def getAliasMap(plan: Project): AttributeMap[Expression] = { + // Create a map of 
Aliases to their values from the child projection.
+    // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
+    AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
+  }
-  def getAliasMap(plan: LogicalPlan): AttributeMap[Expression] = {
-    val aliasMap = plan match {
-      case p: Project =>
-        // Create a map of Aliases to their values from the child projection.
-        // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
-        p.projectList.collect { case a: Alias => (a.toAttribute, a.child) }
-      case a: Aggregate =>
-        // Find all the aliased expressions in the aggregate list that don't include any actual
-        // AggregateExpression, and create a map from the alias to the expression
-        a.aggregateExpressions.collect {
-          case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
-            (a.toAttribute, a.child)
-        }
+  def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
+    // Find all the aliased expressions in the aggregate list that don't include any actual
+    // AggregateExpression, and create a map from the alias to the expression
+    val aliasMap = plan.aggregateExpressions.collect {
+      case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+        (a.toAttribute, a.child)
     }
     AttributeMap(aliasMap)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 2ba88f38ded8..bef297a6691a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
-import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
@@ -31,14 +30,12 @@ import org.apache.spark.sql.catalyst.rules.Rule
 * 4) Aggregate
 * 5) Other permissible unary operators. Please see [[PushDownPredicate.canPushThrough]].
*/ - object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Similar to the above Filter over Project // LeftSemi/LeftAnti over Project case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint) if pList.forall(_.deterministic) && - !pList.find(ScalarSubquery.hasScalarSubquery(_)).isDefined && + !pList.exists(ScalarSubquery.hasScalarSubquery)&& canPushThroughCondition(Seq(gChild), joinCond, rightOp) => if (joinCond.isEmpty) { // No join condition, just push down the Join below Project @@ -53,20 +50,17 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint)) } - // Similar to the above Filter over Aggregate // LeftSemi/LeftAnti over Aggregate case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if agg.aggregateExpressions.forall(_.deterministic) - && agg.groupingExpressions.nonEmpty => + if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty => if (joinCond.isEmpty) { // No join condition, just push down Join below Aggregate agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint)) } else { val aliasMap = PushDownPredicate.getAliasMap(agg) - // For each join condition, expand the alias and - // check if the condition can be evaluated using - // attributes produced by the aggregate operator's child operator. + // For each join condition, expand the alias and check if the condition can be evaluated + // using attributes produced by the aggregate operator's child operator. val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => val replaced = replaceAlias(cond, aliasMap) cond.references.nonEmpty && @@ -90,7 +84,6 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } } - // Similar to the above Filter over Window // LeftSemi/LeftAnti over Window case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => @@ -119,7 +112,6 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } } - // Similar to the above Filter over Union // LeftSemi/LeftAnti over Union case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) if canPushThroughCondition(union.children, joinCond, rightOp) => @@ -148,10 +140,9 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } } - // Similar to the above Filter over UnaryNode // LeftSemi/LeftAnti over UnaryNode case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if PushDownPredicate.canPushThrough(u) => + if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownJoin(join, u.child) { joinCond => u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint))) } From 43e9eef7e2aa9777957c40864e05f0e5701723df Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 19 Feb 2019 14:38:16 -0800 Subject: [PATCH 04/12] Code review --- .../optimizer/PushDownLeftSemiAntiJoin.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index bef297a6691a..c7444ee11dff 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Pushes left semi and left anti joins below the following operators.
+ * This rule is a variant of [[PushDownPredicate]] which can handle
+ * pushing down left semi and left anti joins below the following operators.
 * 1) Project
 * 2) Window
 * 3) Union
@@ -149,17 +150,19 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 
   /**
-   * Check if we can safely push a join through a project or union by making sure that predicate
-   * subqueries in the condition do not contain the same attributes as the plan they are moved
+   * Check if we can safely push a join through a project or union by making sure that attributes
+   * referred to in the join condition do not overlap with the attributes of the plan they are moved
   * into. This can happen when the plan and predicate subquery have the same source.
   */
   private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression],
       rightOp: LogicalPlan): Boolean = {
-    val attributes = AttributeSet(plans.flatMap (_.output))
+    val attributes = AttributeSet(plans.flatMap(_.output))
     if (condition.isDefined) {
       val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
       matched.isEmpty
-    } else true
+    } else {
+      true
+    }
   }
 
From 744355c6d7eb62f6eff9a78ace4da950966aaa9b Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Wed, 20 Feb 2019 13:38:15 -0800
Subject: [PATCH 05/12] Code review

---
 .../optimizer/PushDownLeftSemiAntiJoin.scala | 26 ++++++++++++++-----
 .../LeftSemiAntiJoinPushDownSuite.scala | 20 +++++++++++---
 .../execution/WholeStageCodegenSuite.scala | 2 +-
 3 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index c7444ee11dff..d6d64dfb4641 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -68,7 +68,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
           replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
         }
 
-        // Check if the remaining predicates do not contain columns from subquery
+        // Check if the remaining predicates do not contain columns from the right
+        // hand side of the join. Since the remaining predicates will be kept
+        // as a filter over aggregate, this check is necessary after the left semi
+        // or left anti join is moved below aggregate. The reason is, for this kind
+        // of join, we only output rows from the left leg of the join.
         val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
 
         if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
@@ -99,7 +103,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
           cond.references.subsetOf(partitionAttrs)
         }
 
-        // Check if the remaining predicates do not contain columns from subquery
+        // Check if the remaining predicates do not contain columns from the right
+        // hand side of the join.
Since the remaining predicates will be kept
+        // as a filter over window, this check is necessary after the left semi
+        // or left anti join is moved below window. The reason is, for this kind
+        // of join, we only output rows from the left leg of the join.
         val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
 
         if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
@@ -152,10 +160,14 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
 
   /**
   * Check if we can safely push a join through a project or union by making sure that attributes
   * referred to in the join condition do not overlap with the attributes of the plan they are moved
-   * into. This can happen when the plan and predicate subquery have the same source.
+   * into. This can happen when both sides of the join refer to the same source (self join). This
+   * function makes sure that the join condition does not refer to ambiguous attributes (i.e.
+   * attributes present in both legs of the join); otherwise the resultant plan will be invalid.
   */
-  private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression],
-      rightOp: LogicalPlan): Boolean = {
+  private def canPushThroughCondition(
+      plans: Seq[LogicalPlan],
+      condition: Option[Expression],
+      rightOp: LogicalPlan): Boolean = {
     val attributes = AttributeSet(plans.flatMap(_.output))
     if (condition.isDefined) {
       val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
@@ -167,8 +179,8 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
 
   private def pushDownJoin(
-    join: Join,
-    grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+      join: Join,
+      grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
     val (pushDown, stayUp) = if (join.condition.isDefined) {
       splitConjunctivePredicates(join.condition.get)
         .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index d57d1aeca96b..bf315488ab78 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -60,9 +60,6 @@ class LeftSemiPushdownSuite extends PlanTest {
   }
 
   test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
-    val x = testRelation.subquery('x)
-    val y = testRelation1.subquery('y)
-
     val originalQuery = testRelation
       .select(Rand('a), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -70,7 +67,22 @@ class LeftSemiPushdownSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
       .select(Rand('a), 'b, 'c)
-      .join(y, joinType = LeftSemi, condition = Some('b === 'd))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown because scalar subq proj exprs") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")))
+    val originalQuery = testRelation
+      .select(subq.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized =
Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(subq.as("sum")) + .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ceb5fdd5fce1..3c9a0908147a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -169,7 +169,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(!plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && p.isInstanceOf[SortMergeJoinExec] && + p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) .isInstanceOf[SortMergeJoinExec]).isDefined) assert(df.collect() === Array(Row(1), Row(2))) From 0fa7950038d082e8e5ab47df3bc94750c22cd687 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 20 Feb 2019 13:51:01 -0800 Subject: [PATCH 06/12] minor --- .../sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index bf315488ab78..2bce5ed58bc9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -134,13 +134,11 @@ class LeftSemiPushdownSuite extends PlanTest { test("LeftSemiAnti join over aggregate - no pushdown") { val originalQuery = testRelation - // .select('b.as('alias1)) .groupBy('b)('b, sum('c).as('sum)) .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - // .select('b.as('alias1)) .groupBy('b)('b, sum('c).as('sum)) .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd)) .analyze From 4a11ad4978df337d48e772bcc3129e37d0aaf52c Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 20 Feb 2019 15:31:02 -0800 Subject: [PATCH 07/12] Don't pushdown in presence of scalar subqueries inside aggregate expressions --- .../optimizer/PushDownLeftSemiAntiJoin.scala | 6 +++--- .../optimizer/LeftSemiAntiJoinPushDownSuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index d6d64dfb4641..4b06401df1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -35,8 +35,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // LeftSemi/LeftAnti over Project case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if pList.forall(_.deterministic) && - 
From fa41b442cd5f6e39cdc3890e88ed97753810e9a6 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Thu, 28 Feb 2019 17:29:01 -0800
Subject: [PATCH 08/12] Code review

---
 .../optimizer/PushDownLeftSemiAntiJoin.scala   |  5 ++--
 .../LeftSemiAntiJoinPushDownSuite.scala        | 28 +++++++++++++++----
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 4b06401df1c9..79a5f6dcc9d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -35,7 +35,8 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // LeftSemi/LeftAnti over Project
     case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
-        if pList.forall(_.deterministic) && !pList.exists(ScalarSubquery.hasScalarSubquery) &&
+        if pList.forall(_.deterministic) &&
+          !pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
          canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
       if (joinCond.isEmpty) {
         // No join condition, just push down the Join below Project
@@ -53,7 +54,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
 
     // LeftSemi/LeftAnti over Aggregate
     case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
       if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
-        !agg.aggregateExpressions.exists(ScalarSubquery.hasScalarSubquery) =>
+        !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
       if (joinCond.isEmpty) {
         // No join condition, just push down Join below Aggregate
         agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index a8a5f0b43a63..f0acc9113589 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -38,7 +38,6 @@ class LeftSemiPushdownSuite extends PlanTest {
         PushDownPredicate,
         PushDownLeftSemiAntiJoin,
         BooleanSimplification,
-        PushPredicateThroughJoin,
         CollapseProject) :: Nil
   }
 
@@ -73,15 +72,32 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown because scalar subq proj exprs") {
-    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")))
+  test("Project: LeftSemiAnti join non correlated scalar subq") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
       .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd))
       .select(subq.as("sum"))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+    val testRelation2 = LocalRelation('e.int, 'f.int)
+    val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
+    val subqExpr = ScalarSubquery(subqPlan)
+    val originalQuery = testRelation
+      .select(subqExpr.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(subqExpr.as("sum"))
       .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
       .analyze
 
@@ -146,16 +162,16 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Aggregate: LeftSemiAnti join no pushdown because scalar subq aggr exprs") {
-    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")))
+  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
       .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd && 'a === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd && 'a === 'd))
       .groupBy('a) ('a, subq.as("sum"))
-      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd && 'a === 'd))
       .analyze
 
     comparePlans(optimized, correctAnswer)
From 76e7203391ed7dd9544cbcc43822597a09195aba Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Fri, 1 Mar 2019 11:07:08 -0800
Subject: [PATCH 09/12] review

---
 .../sql/catalyst/expressions/subquery.scala    |  7 ---
 .../optimizer/PushDownLeftSemiAntiJoin.scala   | 48 ++++++++-----------
 .../LeftSemiAntiJoinPushDownSuite.scala        | 15 ++++++
 3 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index e647df983db6..fc1caed84e27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -267,13 +267,6 @@ object ScalarSubquery {
       case _ => false
     }.isDefined
   }
-
-  def hasScalarSubquery(e: Expression): Boolean = {
-    e.find {
-      case _: ScalarSubquery => true
-      case _ => false
-    }.isDefined
-  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 79a5f6dcc9d3..7b846bc18fbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -130,24 +130,16 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
         val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
         union.withNewChildren(newGrandChildren)
       } else {
-        val pushDown = splitConjunctivePredicates(joinCond.get)
-
-        if (pushDown.nonEmpty) {
-          val pushDownCond = pushDown.reduceLeft(And)
-          val output = union.output
-          val newGrandChildren = union.children.map { grandchild =>
-            val newCond = pushDownCond transform {
-              case e if output.exists(_.semanticEquals(e)) =>
-                grandchild.output(output.indexWhere(_.semanticEquals(e)))
-            }
-            assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
-            Join(grandchild, rightOp, joinType, Option(newCond), hint)
+        val output = union.output
+        val newGrandChildren = union.children.map { grandchild =>
+          val newCond = joinCond.get transform {
+            case e if output.exists(_.semanticEquals(e)) =>
+              grandchild.output(output.indexWhere(_.semanticEquals(e)))
           }
-          union.withNewChildren(newGrandChildren)
-        } else {
-          // Nothing to push down
-          join
+          assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
+          Join(grandchild, rightOp, joinType, Option(newCond), hint)
         }
+        union.withNewChildren(newGrandChildren)
       }
 
     // LeftSemi/LeftAnti over UnaryNode
@@ -182,22 +174,22 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   private def pushDownJoin(
       join: Join,
       grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
-    val (pushDown, stayUp) = if (join.condition.isDefined) {
-      splitConjunctivePredicates(join.condition.get)
-        .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
+    if (join.condition.isEmpty) {
+      insertFilter(null)
     } else {
-      (Nil, Nil)
-    }
+      val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
+        .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
 
-    if (pushDown.nonEmpty) {
-      val newChild = insertFilter(pushDown.reduceLeft(And))
-      if (stayUp.nonEmpty) {
-        Filter(stayUp.reduceLeft(And), newChild)
+      if (pushDown.nonEmpty) {
+        val newChild = insertFilter(pushDown.reduceLeft(And))
+        if (stayUp.nonEmpty) {
+          Filter(stayUp.reduceLeft(And), newChild)
+        } else {
+          newChild
+        }
       } else {
-        newChild
+        join
      }
-    } else {
-      join
    }
  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index f0acc9113589..2a379425ed0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -257,4 +257,19 @@ class LeftSemiPushdownSuite extends PlanTest {
       .analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+    val originalQuery = testRelation
+      .select(star())
+      .repartition(1)
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+      .select('a, 'b, 'c)
+      .repartition(1)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
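The partition call above is the heart of every partial pushdown in this rule. A simplified standalone model of the split (plain Scala with illustrative names; not Catalyst code):

// Each conjunct of the join condition either references only columns still
// available below the operator (grandchild.output ++ join.right.output), in
// which case it travels down into the LeftSemi/LeftAnti join, or it stays
// above the operator as a Filter.
final case class Pred(references: Set[String])

def split(conjuncts: Seq[Pred], available: Set[String]): (Seq[Pred], Seq[Pred]) =
  conjuncts.partition(_.references.subsetOf(available))

// E.g. below a Generate(explode('c_arr)): 'b and 'd exist below it, 'out_col
// only exists above, so 'b === 'd is pushed while 'b === 'out_col stays up.
val (pushDown, stayUp) = split(
  Seq(Pred(Set("b", "d")), Pred(Set("b", "out_col"))),
  available = Set("a", "b", "c", "c_arr", "d"))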
From 79579c82cfe0d04defb30a1b76095d3547dc70e4 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Sat, 2 Mar 2019 01:24:39 -0800
Subject: [PATCH 10/12] fix1

---
 .../optimizer/PushDownLeftSemiAntiJoin.scala   | 11 ++--
 .../LeftSemiAntiJoinPushDownSuite.scala        | 52 ++++++++++++++-----
 2 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 7b846bc18fbb..bc868df3dbb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -146,7 +146,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
     case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
         if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
       pushDownJoin(join, u.child) { joinCond =>
-        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond), hint)))
+        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
       }
   }
 
@@ -173,15 +173,16 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
   private def pushDownJoin(
       join: Join,
-      grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+      grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
     if (join.condition.isEmpty) {
-      insertFilter(null)
+      insertJoin(None)
     } else {
       val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
         .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
 
-      if (pushDown.nonEmpty) {
-        val newChild = insertFilter(pushDown.reduceLeft(And))
+      val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
+      if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+        val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
         if (stayUp.nonEmpty) {
           Filter(stayUp.reduceLeft(And), newChild)
         } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index 2a379425ed0f..b35f4e5baff9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -134,7 +134,6 @@ class LeftSemiPushdownSuite extends PlanTest {
 
   test("Aggregate: LeftSemiAnti join partial pushdown") {
     val originalQuery = testRelation
-      // .select('b.as('alias1))
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10))
 
@@ -154,10 +153,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .groupBy('b)('b, sum('c).as('sum))
-      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
-      .analyze
+    val correctAnswer = originalQuery.analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -217,19 +213,19 @@ class LeftSemiPushdownSuite extends PlanTest {
   }
 
   test("Union: LeftSemiAnti join pushdown") {
-   val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
+    val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
 
-   val originalQuery = Union(Seq(testRelation, testRelation2))
-     .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+    val originalQuery = Union(Seq(testRelation, testRelation2))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
 
-   val optimized = Optimize.execute(originalQuery.analyze)
+    val optimized = Optimize.execute(originalQuery.analyze)
 
-   val correctAnswer = Union(Seq(
-     testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)),
-     testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd))))
-     .analyze
+    val correctAnswer = Union(Seq(
+      testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)),
+      testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd))))
+      .analyze
 
-   comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, correctAnswer)
   }
 
   test("Union: LeftSemiAnti join no pushdown in self join scenario") {
@@ -272,4 +268,32 @@ class LeftSemiPushdownSuite extends PlanTest {
       .analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("Unary: LeftSemiAnti join pushdown - partial pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelationWithArrayType
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .where('b === 'out_col)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'd === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
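One note on the new rightOpColumns check: predicates that stay above the rewritten plan can no longer see the right leg, because LeftSemi and LeftAnti joins output only left-side rows. A minimal sketch of the invariant (simplified types and illustrative names, not part of the patch):

// If any stay-up conjunct references right-side columns, the Filter placed on
// top of "UnaryNode <- LeftSemi Join" would reference columns that no longer
// exist in its input, so the rule now skips the rewrite entirely in that case.
def stayUpIsValid(stayUpRefs: Set[String], rightOutput: Set[String]): Boolean =
  stayUpRefs.intersect(rightOutput).isEmpty

assert(stayUpIsValid(Set("b", "out_col"), rightOutput = Set("d")))  // partial pushdown is fine
assert(!stayUpIsValid(Set("d", "out_col"), rightOutput = Set("d"))) // back off, no pushdown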
From 5ea6a4a0717613c17c52039b64d871ec24a268e0 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Sun, 3 Mar 2019 19:36:17 -0800
Subject: [PATCH 11/12] Code review

---
 .../LeftSemiAntiJoinPushDownSuite.scala | 26 +++----------------
 1 file changed, 4 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index b35f4e5baff9..b957ceb69775 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -64,12 +64,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .select(Rand('a), 'b, 'c)
-      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery.analyze)
   }
 
   test("Project: LeftSemiAnti join non correlated scalar subq") {
@@ -96,12 +91,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .select(subqExpr.as("sum"))
-      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery.analyze)
   }
 
   test("Aggregate: LeftSemiAnti join pushdown") {
@@ -124,12 +114,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = testRelation
-      .groupBy('b)('b, Rand(10).as('c))
-      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
-      .analyze
-
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery.analyze)
  }
 
   test("Aggregate: LeftSemiAnti join partial pushdown") {
@@ -235,7 +220,6 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation2, joinType = LeftSemi, condition = Some('a === 'x))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-
     comparePlans(optimized, originalQuery.analyze)
   }
 
@@ -292,8 +276,6 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'd === 'out_col))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = originalQuery.analyze
-
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery.analyze)
   }
 }
From 68e726878f0b6254a44aa3e0e81e2e0d82a2be86 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Sun, 3 Mar 2019 19:45:39 -0800
Subject: [PATCH 12/12] code review

---
 .../catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index b957ceb69775..1a0231ed2d99 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -138,9 +138,7 @@ class LeftSemiPushdownSuite extends PlanTest {
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
 
     val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer = originalQuery.analyze
-
-    comparePlans(optimized, correctAnswer)
+    comparePlans(optimized, originalQuery.analyze)
   }
 
   test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
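Taken together, an end-to-end way to observe the pushdown (a hypothetical spark-shell sketch against a build that includes this rule; the table and column names are assumed, not from the patch):

import org.apache.spark.sql.SparkSession

// The IN predicate in HAVING becomes a LeftSemi join whose condition only
// needs the grouping key b, so the optimized plan should show that join below
// the Aggregate -- the aggregate then runs over pre-filtered rows.
val spark = SparkSession.builder().master("local[*]").appName("semi-pushdown").getOrCreate()
import spark.implicits._

Seq((1, 10), (1, 20), (2, 30)).toDF("b", "c").createOrReplaceTempView("t1")
Seq(1, 3).toDF("d").createOrReplaceTempView("t2")

spark.sql("""
  SELECT b, sum(c) AS s
  FROM t1
  GROUP BY b
  HAVING b IN (SELECT d FROM t2)
""").explain(true)

spark.stop()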