diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2664fd638062d..c6f436eb75b3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -146,7 +146,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
         operatorOptimizationRuleSet: _*) ::
     Batch("Push extra predicate through join", fixedPoint,
       PushExtraPredicateThroughJoin,
-      PushDownPredicates) :: Nil
+      PushDownPredicates) ::
+    Batch("Pull out complex join condition", Once,
+      PullOutComplexJoinCondition) :: Nil
   }
 
   val batches = (
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinCondition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinCondition.scala
new file mode 100644
index 0000000000000..eb13e53f8b6d9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinCondition.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
+
+/**
+ * This rule ensures that the [[Join]] condition doesn't contain complex expressions
+ * during the optimization phase.
+ *
+ * Complex expressions are pulled out into a [[Project]] node below the [[Join]] and
+ * referenced from the join condition:
+ *
+ * {{{
+ *   SELECT * FROM t1 JOIN t2 ON t1.a + 10 = t2.x   ==>
+ *   Project [a#0, b#1, x#2, y#3]
+ *   +- Join Inner, ((spark_catalog.default.t1.a + 10)#8 = x#2)
+ *      :- Project [a#0, b#1, (a#0 + 10) AS (spark_catalog.default.t1.a + 10)#8]
+ *      :  +- Filter isnotnull((a#0 + 10))
+ *      :     +- Relation default.t1[a#0,b#1] parquet
+ *      +- Filter isnotnull(x#2)
+ *         +- Relation default.t2[x#2,y#3] parquet
+ * }}}
+ *
+ * This rule should be executed after ReplaceNullWithFalseInPredicate.
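+ *
+ * As a second, purely illustrative sketch (hand-written, not verbatim optimizer
+ * output; the exprIds and the truncated alias names are arbitrary), a condition with
+ * a complex expression on each side gets a [[Project]] below both join children,
+ * shown here before later rules such as null-filter inference run:
+ *
+ * {{{
+ *   SELECT * FROM t1 JOIN t2 ON upper(t1.b) = upper(t2.y)   ==>
+ *   Project [a#0, b#1, x#2, y#3]
+ *   +- Join Inner, (upper(t1.b)#9 = upper(t2.y)#10)
+ *      :- Project [a#0, b#1, upper(b#1) AS upper(t1.b)#9]
+ *      :  +- Relation default.t1[a#0,b#1] parquet
+ *      +- Project [x#2, y#3, upper(y#3) AS upper(t2.y)#10]
+ *         +- Relation default.t2[x#2,y#3] parquet
+ * }}}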
+ */
+object PullOutComplexJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(JOIN)) {
+    case j @ Join(left, right, _, Some(condition), _) if !j.isStreaming =>
+      // Collect the non-foldable, non-leaf children of every conjunct,
+      // e.g. `a + 10` in `a + 10 = x`.
+      val complexExps = splitConjunctivePredicates(condition)
+        .flatMap(_.children.filter(e => !e.foldable && e.children.nonEmpty))
+
+      // Key each complex expression by its canonicalized form and alias it on the
+      // join side that is able to evaluate it.
+      val leftComplexExpMap = complexExps.filter(canEvaluate(_, left))
+        .map(e => e.canonicalized -> Alias(e, e.sql.take(20))()).toMap
+      val rightComplexExpMap = complexExps.filter(canEvaluate(_, right))
+        .map(e => e.canonicalized -> Alias(e, e.sql.take(20))()).toMap
+      val allComplexExpMap = leftComplexExpMap ++ rightComplexExpMap
+
+      if (allComplexExpMap.nonEmpty) {
+        // Replace each complex expression in the condition with a reference to its
+        // alias, and project the aliased expressions below the join.
+        val newCondition = condition.transformDown {
+          case e: Expression if e.children.nonEmpty && allComplexExpMap.contains(e.canonicalized) =>
+            allComplexExpMap.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+        }
+        val newLeft = Project(left.output ++ leftComplexExpMap.values, left)
+        val newRight = Project(right.output ++ rightComplexExpMap.values, right)
+        Project(j.output, j.copy(left = newLeft, right = newRight, condition = Some(newCondition)))
+      } else {
+        j
+      }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinConditionSuite.scala
new file mode 100644
index 0000000000000..09e7b89ed0624
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinConditionSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Substring, Upper}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class PullOutComplexJoinConditionSuite extends PlanTest {
+
+  private object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Pull out complex join condition", Once,
+        PullOutComplexJoinCondition,
+        CollapseProject) :: Nil
+  }
+
+  private val testRelation = LocalRelation($"a".string, $"b".int, $"c".int)
+  private val testRelation1 = LocalRelation($"d".string, $"e".int)
+  private val x = testRelation.subquery("x")
+  private val y = testRelation1.subquery("y")
+
+  test("Pull out join keys evaluation (String expressions)") {
+    Seq(Upper("y.d".attr), Substring("y.d".attr, 1, 5)).foreach { udf =>
+      val originalQuery = x.join(y, condition = Option($"a" === udf)).select($"a", $"e")
+      val correctAnswer = x.select($"a", $"b", $"c")
+        .join(y.select($"d", $"e", Alias(udf, udf.sql)()),
+          condition = Option($"a" === s"`${udf.sql}`".attr)).select($"a", $"e")
+
+      comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+    }
+  }
+
+  test("Pull out join condition contains other predicates") {
+    val udf = Substring("y.d".attr, 1, 5)
+    val originalQuery =
+      x.join(y, condition = Option($"a" === udf && $"b" > $"e")).select($"a", $"e")
+    val correctAnswer = x.select($"a", $"b", $"c")
+      .join(y.select($"d", $"e", Alias(udf, udf.sql)()),
+        condition = Option($"a" === s"`${udf.sql}`".attr && $"b" > $"e")).select($"a", $"e")
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("Pull out EqualNullSafe join condition") {
+    val udf = "x.b".attr + 1
+
+    val originalQuery = x.join(y, condition = Option(udf <=> $"e")).select($"a", $"e")
+    val correctAnswer = x.select($"a", $"b", $"c", Alias(udf, udf.sql)())
+      .join(y.select($"d", $"e"), condition = Option(s"`${udf.sql}`".attr <=> $"e"))
+      .select($"a", $"e")
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("Pull out non-equality join conditions") {
+    val udf = "x.b".attr + 1
+    val originalQuery = x.join(y, condition = Option(udf > $"e")).select($"a", $"e")
+
+    val correctAnswer = x.select($"a", $"b", $"c", Alias(udf, udf.sql)())
+      .join(y.select($"d", $"e"), condition = Option(s"`${udf.sql}`".attr > $"e"))
+      .select($"a", $"e")
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("Negative case: all children are Attributes") {
+    val originalQuery = x.join(y, condition = Option($"a" === $"d"))
+
+    comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze)
+  }
+
+  test("Negative case: contains Literal") {
+    val originalQuery = x.join(y, condition = Option($"a" === "string"))
+
+    comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 6dd34d41cf6c1..c41651cd8f4d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -27,16 +27,16 @@ import org.mockito.Mockito._
 
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Cast, Coalesce, GenericRow, IsNotNull, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 
 class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -1057,7 +1057,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
     val pythonEvals = collect(joinNode.get) {
       case p: BatchEvalPythonExec => p
     }
-    assert(pythonEvals.size == 2)
+    assert(pythonEvals.size == 4)
 
     checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
   }
@@ -1455,4 +1455,29 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       checkAnswer(result1, result2)
     }
   }
+
+  test("SPARK-36290: Pull out join condition can infer more filter conditions") {
+    import org.apache.spark.sql.catalyst.dsl.expressions.DslString
+
+    withTable("t1", "t2") {
+      spark.sql("CREATE TABLE t1(a int, b int) using parquet")
+      spark.sql("CREATE TABLE t2(a string, b string, c string) using parquet")
+
+      spark.sql("SELECT t1.* FROM t1 RIGHT JOIN t2 ON coalesce(t1.a, t1.b) = t2.a")
+        .queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) match {
+        case Some(Filter(condition, _)) =>
+          assert(condition === IsNotNull(Coalesce(Seq("a".attr, "b".attr))))
+        case _ =>
+          fail("It should contain a Filter")
+      }
+
+      spark.sql("SELECT t1.* FROM t1 LEFT JOIN t2 ON t1.a = t2.a")
+        .queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) match {
+        case Some(Filter(condition, _)) =>
+          assert(condition === IsNotNull(Cast("a".attr, IntegerType)))
+        case _ =>
+          fail("It should contain a Filter")
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c3229647..2fd000c531a3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -154,9 +154,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     val schema = new StructType().add("k", IntegerType).add("v", StringType)
     val smallDF = spark.createDataFrame(rdd, schema)
     val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
-    assert(df.queryExecution.executedPlan.exists(p =>
-      p.isInstanceOf[WholeStageCodegenExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]))
+    val broadcastHashJoin = df.queryExecution.executedPlan.find {
+      case WholeStageCodegenExec(ProjectExec(_, _: BroadcastHashJoinExec)) => true
+      case _ => false
+    }
+    assert(broadcastHashJoin.isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
"2"))) } @@ -201,8 +202,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession // test one join with non-unique key from build side val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer") assert(joinNonUniqueDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true - case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + case WholeStageCodegenExec(ProjectExec(_, _: ShuffledHashJoinExec)) + if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) + if hint == "SHUFFLE_MERGE" => true }.size === 1) checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null))) @@ -211,8 +214,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val joinWithNonEquiDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer") assert(joinWithNonEquiDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true - case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + case WholeStageCodegenExec(ProjectExec(_, _: ShuffledHashJoinExec)) + if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) + if hint == "SHUFFLE_MERGE" => true }.size === 1) checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4), @@ -378,7 +383,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(df3, $"k1" <= $"k3", "left_outer") hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + ProjectExec(_, _: BroadcastNestedLoopJoinExec), _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) checkAnswer(twoJoinsDF,