diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index a4ed3d5683185..d4018f8ce3a95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -80,7 +80,7 @@ case class AdaptiveSparkPlanExec(
     // TODO add more optimization rules
     override protected def batches: Seq[Batch] = Seq(
       Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
-      Batch("Eliminate Null Aware Anti Join", Once, EliminateNullAwareAntiJoin)
+      Batch("Eliminate Join to Empty Relation", Once, EliminateJoinToEmptyRelation)
     )
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala
new file mode 100644
index 0000000000000..cfdd20ec7565d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys}
+
+/**
+ * This optimization rule detects and converts a Join to an empty [[LocalRelation]]:
+ * 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]]
+ *    is [[HashedRelationWithAllNullKeys]].
+ *
+ * 2. Join is inner or left semi join, and broadcasted [[HashedRelation]]
+ *    is [[EmptyHashedRelation]].
+ *    This applies to all joins (sort merge join, shuffled hash join, and broadcast hash join),
+ *    because AQE converts sort merge join and shuffled hash join to broadcast hash join
+ *    in the first place.
+ */
+object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] {
+
+  private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match {
+    case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
+      && stage.broadcast.relationFuture.get().value == relation => true
+    case _ => false
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+    case j @ ExtractSingleColumnNullAwareAntiJoin(_, _)
+      if canEliminate(j.right, HashedRelationWithAllNullKeys) =>
+      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
+
+    case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) ||
+      canEliminate(j.right, EmptyHashedRelation) =>
+      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
+
+    case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) =>
+      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala
deleted file mode 100644
index afccde09040a4..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
-
-/**
- * This optimization rule detects and convert a NAAJ to an Empty LocalRelation
- * when buildSide is HashedRelationWithAllNullKeys.
- */
-object EliminateNullAwareAntiJoin extends Rule[LogicalPlan] {
-
-  private def canEliminate(plan: LogicalPlan): Boolean = plan match {
-    case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
-      && stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys => true
-    case _ => false
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
-    case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if canEliminate(j.right) =>
-      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1a7554c905c6c..085cc29289ddd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -155,7 +155,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
     val joinRow = new JoinedRow
     val joinKeys = streamSideKeyGenerator()

-    if (hashedRelation.keyIsUnique) {
+    if (hashedRelation == EmptyHashedRelation) {
+      Iterator.empty
+    } else if (hashedRelation.keyIsUnique) {
       streamIter.flatMap { srow =>
         joinRow.withLeft(srow)
         val matched = hashedRelation.getValue(joinKeys(srow))
@@ -230,7 +232,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
     val joinKeys = streamSideKeyGenerator()
     val joinedRow = new JoinedRow

-    if (hashedRelation.keyIsUnique) {
+    if (hashedRelation == EmptyHashedRelation) {
+      Iterator.empty
+    } else if (hashedRelation.keyIsUnique) {
       streamIter.filter { current =>
         val key = joinKeys(current)
         lazy val matched = hashedRelation.getValue(key)
@@ -432,7 +436,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for Inner join.
    */
   protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")
@@ -442,7 +446,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
       case BuildRight => input ++ buildVars
     }

-    if (keyIsUnique) {
+    if (isEmptyHashedRelation) {
+      """
+        |// If HashedRelation is empty, hash inner join simply returns nothing.
+      """.stripMargin
+    } else if (keyIsUnique) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
@@ -559,12 +567,16 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for left semi join.
    */
   protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val (matched, checkCondition, _) = getJoinCondition(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")

-    if (keyIsUnique) {
+    if (isEmptyHashedRelation) {
+      """
+        |// If HashedRelation is empty, hash semi join simply returns nothing.
+ """.stripMargin + } else if (keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -612,10 +624,10 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { val numOutput = metricTerm(ctx, "numOutputRows") if (isEmptyHashedRelation) { return s""" - |// If the right side is empty, Anti Join simply returns the left side. + |// If HashedRelation is empty, hash anti join simply returns the stream side. |$numOutput.add(1); |${consume(ctx, input)} - |""".stripMargin + """.stripMargin } val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e7629a21f787a..93cd84713296b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins._ @@ -1254,4 +1254,56 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } } + + test("SPARK-32649: Optimize BHJ/SHJ inner/semi join with empty hashed relation") { + val inputDFs = Seq( + // Test empty build side for inner join + (spark.range(30).selectExpr("id as k1"), + spark.range(10).selectExpr("id as k2").filter("k2 < -1"), + "inner"), + // Test empty build side for semi join + (spark.range(30).selectExpr("id as k1"), + spark.range(10).selectExpr("id as k2").filter("k2 < -1"), + "semi") + ) + inputDFs.foreach { case (df1, df2, joinType) => + // Test broadcast hash join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") { + val bhjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType) + assert(bhjCodegenDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : BroadcastHashJoinExec) => true + case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastHashJoinExec)) => true + }.size === 1) + checkAnswer(bhjCodegenDF, Seq.empty) + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val bhjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType) + assert(bhjNonCodegenDF.queryExecution.executedPlan.collect { + case _: BroadcastHashJoinExec => true }.size === 1) + checkAnswer(bhjNonCodegenDF, Seq.empty) + } + } + + // Test shuffled hash join + withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + // Set broadcast join threshold and number of shuffle partitions, + // as shuffled hash join depends on these two configs. 
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + val shjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType) + assert(shjCodegenDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true + case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true + }.size === 1) + checkAnswer(shjCodegenDF, Seq.empty) + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val shjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType) + assert(shjNonCodegenDF.queryExecution.executedPlan.collect { + case _: ShuffledHashJoinExec => true }.size === 1) + checkAnswer(shjNonCodegenDF, Seq.empty) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 1dc239c0416f8..3bd079cf65433 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -226,7 +226,8 @@ class AdaptiveQueryExecSuite val df1 = spark.range(10).withColumn("a", 'id) val df2 = spark.range(10).withColumn("b", 'id) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count() + val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") + .groupBy('a).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) @@ -238,7 +239,8 @@ class AdaptiveQueryExecSuite } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count() + val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") + .groupBy('a).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -1181,4 +1183,26 @@ class AdaptiveQueryExecSuite checkNumLocalShuffleReaders(adaptivePlan) } } + + test("SPARK-32649: Eliminate inner and semi join to empty relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + Seq( + // inner join (small table at right side) + "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", + // inner join (small table at left side) + "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", + // left semi join + "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" + ).foreach(query => { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val join = findTopLevelBaseJoin(adaptivePlan) + assert(join.isEmpty) + checkNumLocalShuffleReaders(adaptivePlan) + }) + } + } }