diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 73be7902b998e..7b82817da6e44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -157,6 +157,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
+      PullOutJoinCondition,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinCondition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinCondition.scala
new file mode 100644
index 0000000000000..2c0efb8bae3a4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinCondition.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, EqualTo, NamedExpression, PredicateHelper}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
+
+/**
+ * This rule ensures that [[Join]] keys do not contain complex expressions in the
+ * optimization phase.
+ *
+ * Complex expressions are pulled out to a [[Project]] node under [[Join]] and are
+ * referenced in the join condition.
+ *
+ * {{{
+ *   SELECT * FROM t1 JOIN t2 ON t1.a + 10 = t2.x   ==>
+ *   Project [a#0, b#1, x#2, y#3]
+ *   +- Join Inner, ((spark_catalog.default.t1.a + 10)#8 = x#2)
+ *      :- Project [a#0, b#1, (a#0 + 10) AS (spark_catalog.default.t1.a + 10)#8]
+ *      :  +- Filter isnotnull((a#0 + 10))
+ *      :     +- Relation default.t1[a#0,b#1] parquet
+ *      +- Filter isnotnull(x#2)
+ *         +- Relation default.t2[x#2,y#3] parquet
+ * }}}
+ */
+object PullOutJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(JOIN)) {
+    case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, otherPredicates, _, left, right, _)
+        if j.resolved =>
+      val complexLeftJoinKeys = new ArrayBuffer[NamedExpression]()
+      val complexRightJoinKeys = new ArrayBuffer[NamedExpression]()
+
+      val newLeftJoinKeys = leftKeys.map { expr =>
+        if (!expr.foldable && expr.children.nonEmpty) {
+          val ne = Alias(expr, expr.sql)()
+          complexLeftJoinKeys += ne
+          ne.toAttribute
+        } else {
+          expr
+        }
+      }
+
+      val newRightJoinKeys = rightKeys.map { expr =>
+        if (!expr.foldable && expr.children.nonEmpty) {
+          val ne = Alias(expr, expr.sql)()
+          complexRightJoinKeys += ne
+          ne.toAttribute
+        } else {
+          expr
+        }
+      }
+
+      if (complexLeftJoinKeys.nonEmpty || complexRightJoinKeys.nonEmpty) {
+        val newLeft = Project(left.output ++ complexLeftJoinKeys, left)
+        val newRight = Project(right.output ++ complexRightJoinKeys, right)
+        val newCond = (newLeftJoinKeys.zip(newRightJoinKeys)
+          .map { case (l, r) => EqualTo(l, r) } ++ otherPredicates)
+          .reduceLeftOption(And)
+        Project(j.output, j.copy(left = newLeft, right = newRight, condition = newCond))
+      } else {
+        j
+      }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinConditionSuite.scala
new file mode 100644
index 0000000000000..7f120362779c7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutJoinConditionSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, IsNull, Literal, Substring, Upper} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class PullOutJoinConditionSuite extends PlanTest { + + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Pull out join condition", Once, + PullOutJoinCondition, + CollapseProject) :: Nil + } + + private val testRelation = LocalRelation('a.string, 'b.int, 'c.int) + private val testRelation1 = LocalRelation('d.string, 'e.int) + private val x = testRelation.subquery('x) + private val y = testRelation1.subquery('y) + + test("Pull out join keys evaluation(String expressions)") { + Seq(Upper("y.d".attr), Substring("y.d".attr, 1, 5)).foreach { udf => + val originalQuery = x.join(y, condition = Option('a === udf)).select('a, 'e) + val correctAnswer = x.select('a, 'b, 'c) + .join(y.select('d, 'e, Alias(udf, udf.sql)()), + condition = Option('a === s"`${udf.sql}`".attr)).select('a, 'e) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("Pull out join condition contains other predicates") { + val udf = Substring("y.d".attr, 1, 5) + val originalQuery = x.join(y, condition = Option('a === udf && 'b > 'e)).select('a, 'e) + val correctAnswer = x.select('a, 'b, 'c) + .join(y.select('d, 'e, Alias(udf, udf.sql)()), + condition = Option('a === s"`${udf.sql}`".attr && 'b > 'e)).select('a, 'e) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Pull out EqualNullSafe join condition") { + val joinType = Inner + val udf = "x.b".attr + 1 + val coalesce1 = Coalesce(Seq(udf, Literal(0))) + val coalesce2 = Coalesce(Seq("y.e".attr, Literal(0))) + val isNull1 = IsNull(udf) + val isNull2 = IsNull("y.e".attr) + + val originalQuery = x.join(y, joinType, Option(udf <=> 'e)).select('a, 'e) + val correctAnswer = + x.select('a, 'b, 'c, Alias(coalesce1, coalesce1.sql)(), Alias(isNull1, isNull1.sql)()) + .join(y.select('d, 'e, Alias(coalesce2, coalesce2.sql)(), Alias(isNull2, isNull2.sql)()), + condition = Option(s"`${coalesce1.sql}`".attr === s"`${coalesce2.sql}`".attr && + s"`${isNull1.sql}`".attr === s"`${isNull2.sql}`".attr)).select('a, 'e) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Negative case: non-equality join keys") { + val originalQuery = x.join(y, condition = Option("x.b".attr + 1 > 'e)).select('a, 'e) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + + test("Negative case: all children are Attributes") { + val originalQuery = x.join(y, condition = Option('a === 'd)) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + + test("Negative case: contains Literal") { + val originalQuery = x.join(y, condition = Option('a === "string")) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index a803fa88ed313..2144976452fb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ 
     plan match {
       // SPARK-34178: we can't match the plan before the fix due to
      // the right side plan doesn't contains dataset id.
-      case Join(
-        LogicalPlanWithDatasetId(_, leftId),
-        LogicalPlanWithDatasetId(_, rightId), _, _, _) =>
+      case Project(_, Join(
+        Project(_, LogicalPlanWithDatasetId(_, leftId)),
+        Project(_, LogicalPlanWithDatasetId(_, rightId)), _, _, _)) =>
         assert(leftId === rightId)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 77493afe43145..10b49a6efc76d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1057,7 +1057,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       val pythonEvals = collect(joinNode.get) {
         case p: BatchEvalPythonExec => p
       }
-      assert(pythonEvals.size == 2)
+      assert(pythonEvals.size == 4)
 
       checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0b8e10c8916a9..b96602b60030a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,10 +28,10 @@ import org.apache.commons.io.FileUtils
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions.{Cast, Coalesce, GenericRow, IsNotNull}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalLimit, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.{CommandResultExec, UnionExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -4227,6 +4227,31 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       checkAnswer(df3, df4)
     }
   }
+
+  test("SPARK-36290: Pull out join condition can infer more filter conditions") {
+    withTable("t1", "t2") {
+      spark.sql("CREATE TABLE t1(a int, b int) using parquet")
+      spark.sql("CREATE TABLE t2(a string, b string, c string) using parquet")
+
+      spark.sql("SELECT t1.* FROM t1 RIGHT JOIN t2 ON coalesce(t1.a, t1.b) = t2.a")
+        .queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) match {
+        case Some(Filter(condition, _)) =>
+          assert(condition.find(_.isInstanceOf[IsNotNull]).isDefined &&
+            condition.find(_.isInstanceOf[Coalesce]).isDefined)
+        case _ =>
+          fail("The optimized plan should contain a Filter")
+      }
+
+      spark.sql("SELECT t1.* FROM t1 LEFT JOIN t2 ON t1.a = t2.a")
+        .queryExecution.optimizedPlan.find(_.isInstanceOf[Filter]) match {
+        case Some(Filter(condition, _)) =>
+          assert(condition.find(_.isInstanceOf[IsNotNull]).isDefined &&
+            condition.find(_.isInstanceOf[Cast]).isDefined)
+        case _ =>
+          fail("The optimized plan should contain a Filter")
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 7da813cfdab6f..31cf8e23934ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -143,9 +143,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     val schema = new StructType().add("k", IntegerType).add("v", StringType)
     val smallDF = spark.createDataFrame(rdd, schema)
     val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
-    assert(df.queryExecution.executedPlan.find(p =>
-      p.isInstanceOf[WholeStageCodegenExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+    val broadcastHashJoin = df.queryExecution.executedPlan.find {
+      case WholeStageCodegenExec(ProjectExec(_, _: BroadcastHashJoinExec)) => true
+      case _ => false
+    }
+    assert(broadcastHashJoin.isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
 
@@ -187,7 +189,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     // test one join with non-unique key from build side
     val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3, "full_outer")
     assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+      case WholeStageCodegenExec(ProjectExec(_, _: ShuffledHashJoinExec)) => true
     }.size === 1)
     checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
       Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
@@ -196,7 +198,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"),
       $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
    assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
-      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+      case WholeStageCodegenExec(ProjectExec(_, _: ShuffledHashJoinExec)) => true
     }.size === 1)
     checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1), Row(1, 7),
       Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
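
As a quick way to see the rewrite in action, the spark-shell sketch below reproduces the example from the PullOutJoinCondition scaladoc above. It is illustrative only: it assumes a Spark build with this patch applied, and the table and column names simply mirror the scaladoc example.

// A minimal sketch, assuming this patch is applied; t1(a, b) and t2(x, y)
// come from the scaladoc example in PullOutJoinCondition.
spark.sql("CREATE TABLE t1(a int, b int) USING parquet")
spark.sql("CREATE TABLE t2(x int, y int) USING parquet")
val df = spark.sql("SELECT * FROM t1 JOIN t2 ON t1.a + 10 = t2.x")
// With the rule in place, the optimized plan should show (a + 10) computed
// once in a Project under the Join and referenced by its alias in the join
// condition, rather than being evaluated inside the join itself.
println(df.queryExecution.optimizedPlan.treeString)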