AdaptiveSparkPlanExec.scala
@@ -80,7 +80,7 @@ case class AdaptiveSparkPlanExec(
// TODO add more optimization rules
override protected def batches: Seq[Batch] = Seq(
Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
Batch("Eliminate Null Aware Anti Join", Once, EliminateNullAwareAntiJoin)
Batch("Eliminate Join to Empty Relation", Once, EliminateJoinToEmptyRelation)
)
}

New file: EliminateJoinToEmptyRelation.scala
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys}

/**
* This optimization rule detects and converts a Join to an empty [[LocalRelation]]:
* 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]]
* is [[HashedRelationWithAllNullKeys]].
*
* 2. Join is inner or left semi join, and broadcasted [[HashedRelation]]
* is [[EmptyHashedRelation]].
* This applies to all joins (sort merge join, shuffled hash join, and broadcast hash join),
* because under AQE, sort merge join and shuffled hash join are converted into
* broadcast hash join first.
*/
object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] {

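  // The join side can be eliminated only when it is a broadcast query stage that has
  // already materialized, and the materialized HashedRelation is exactly the given
  // sentinel (EmptyHashedRelation or HashedRelationWithAllNullKeys).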
  private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match {
    case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
      && stage.broadcast.relationFuture.get().value == relation => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
    case j @ ExtractSingleColumnNullAwareAntiJoin(_, _)
        if canEliminate(j.right, HashedRelationWithAllNullKeys) =>
      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)

    case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) ||
        canEliminate(j.right, EmptyHashedRelation) =>
      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)

    case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) =>
      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
  }
}

This file was deleted. (Presumably EliminateNullAwareAntiJoin.scala, whose NAAJ handling is now case 1 of EliminateJoinToEmptyRelation above.)
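To see the intended effect of the new rule end to end, here is a minimal sketch (not part of the PR) of how it should behave from the user's side. It assumes a local build that includes this change, with AQE enabled; the session setup and DataFrame names are illustrative.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: with AQE on, a join whose build side turns out empty at runtime
// should be collapsed into an empty LocalRelation by EliminateJoinToEmptyRelation.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.adaptive.enabled", "true")
  .getOrCreate()
import spark.implicits._

val big   = spark.range(1000).selectExpr("id AS k1")
val empty = spark.range(100).selectExpr("id AS k2").filter("k2 < -1") // empty at runtime

val joined = big.join(empty, $"k1" === $"k2", "inner")
assert(joined.collect().isEmpty)
joined.explain() // final adaptive plan should show a LocalTableScan and no join node

// NAAJ case (case 1 in the scaladoc): NOT IN over a build side whose keys are
// all null also yields an empty result, via HashedRelationWithAllNullKeys.
big.createOrReplaceTempView("t1")
spark.sql("SELECT CAST(null AS BIGINT) AS k2").createOrReplaceTempView("t2")
assert(spark.sql("SELECT * FROM t1 WHERE k1 NOT IN (SELECT k2 FROM t2)").collect().isEmpty)
```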

HashJoin.scala
@@ -155,7 +155,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
val joinRow = new JoinedRow
val joinKeys = streamSideKeyGenerator()

-    if (hashedRelation.keyIsUnique) {
+    if (hashedRelation == EmptyHashedRelation) {
+      Iterator.empty
+    } else if (hashedRelation.keyIsUnique) {
streamIter.flatMap { srow =>
joinRow.withLeft(srow)
val matched = hashedRelation.getValue(joinKeys(srow))
@@ -230,7 +232,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow

-    if (hashedRelation.keyIsUnique) {
+    if (hashedRelation == EmptyHashedRelation) {
+      Iterator.empty
+    } else if (hashedRelation.keyIsUnique) {
streamIter.filter { current =>
val key = joinKeys(current)
lazy val matched = hashedRelation.getValue(key)
@@ -432,7 +436,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for Inner join.
*/
protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
@@ -442,7 +446,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
case BuildRight => input ++ buildVars
}

-    if (keyIsUnique) {
+    if (isEmptyHashedRelation) {
+      """
+        |// If HashedRelation is empty, hash inner join simply returns nothing.
+      """.stripMargin
+    } else if (keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -559,12 +567,16 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for left semi join.
*/
protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")

-    if (keyIsUnique) {
+    if (isEmptyHashedRelation) {
+      """
+        |// If HashedRelation is empty, hash semi join simply returns nothing.
+      """.stripMargin
+    } else if (keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -612,10 +624,10 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
val numOutput = metricTerm(ctx, "numOutputRows")
if (isEmptyHashedRelation) {
return s"""
-        |// If the right side is empty, Anti Join simply returns the left side.
+        |// If HashedRelation is empty, hash anti join simply returns the stream side.
|$numOutput.add(1);
|${consume(ctx, input)}
|""".stripMargin
""".stripMargin
}

val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
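As a mental model for the HashJoin changes above (a toy sketch, not Spark's actual probe code — the types and helpers here are invented for illustration): with an empty build side, inner and semi joins can return nothing without probing, while the anti join, by symmetry, passes every stream-side row through.

```scala
// Toy model only; the real HashJoin probes a HashedRelation via generated code.
sealed trait BuildSide
case object EmptyBuild extends BuildSide
final case class NonEmptyBuild(keys: Set[Long]) extends BuildSide

// Inner/semi probe: an empty build side can never match, so return no rows.
def innerOrSemiProbe(stream: Iterator[Long], build: BuildSide): Iterator[Long] =
  build match {
    case EmptyBuild          => Iterator.empty
    case NonEmptyBuild(keys) => stream.filter(keys.contains)
  }

// Anti probe: an empty build side can never match, so every stream row survives.
def antiProbe(stream: Iterator[Long], build: BuildSide): Iterator[Long] =
  build match {
    case EmptyBuild          => stream
    case NonEmptyBuild(keys) => stream.filterNot(keys.contains)
  }
```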
sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala (53 additions, 1 deletion)
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins._
@@ -1254,4 +1254,56 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper
}
}
}

test("SPARK-32649: Optimize BHJ/SHJ inner/semi join with empty hashed relation") {
val inputDFs = Seq(
// Test empty build side for inner join
(spark.range(30).selectExpr("id as k1"),
spark.range(10).selectExpr("id as k2").filter("k2 < -1"),
"inner"),
// Test empty build side for semi join
(spark.range(30).selectExpr("id as k1"),
spark.range(10).selectExpr("id as k2").filter("k2 < -1"),
"semi")
)
inputDFs.foreach { case (df1, df2, joinType) =>
// Test broadcast hash join
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") {
val bhjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
assert(bhjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : BroadcastHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastHashJoinExec)) => true
}.size === 1)
checkAnswer(bhjCodegenDF, Seq.empty)

withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val bhjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
assert(bhjNonCodegenDF.queryExecution.executedPlan.collect {
case _: BroadcastHashJoinExec => true }.size === 1)
checkAnswer(bhjNonCodegenDF, Seq.empty)
}
}

// Test shuffled hash join
withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
// Set broadcast join threshold and number of shuffle partitions,
// as shuffled hash join depends on these two configs.
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
val shjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
assert(shjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true
}.size === 1)
checkAnswer(shjCodegenDF, Seq.empty)

withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val shjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
assert(shjNonCodegenDF.queryExecution.executedPlan.collect {
case _: ShuffledHashJoinExec => true }.size === 1)
checkAnswer(shjNonCodegenDF, Seq.empty)
}
}
}
}
}
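One note on the shuffled-hash-join branch of this test: the three settings together are what steer the planner away from both broadcast and sort merge join. For reference, the same configuration as runtime conf calls (the string keys are the SQLConf entries the test references; the standalone form is illustrative):

```scala
// Illustrative runtime equivalent of the test's withSQLConf block.
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false") // allow ShuffledHashJoin to be picked
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50")  // too small to trigger BHJ on this data
spark.conf.set("spark.sql.shuffle.partitions", "2")           // keeps each build partition small
```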
AdaptiveQueryExecSuite.scala
@@ -226,7 +226,8 @@ class AdaptiveQueryExecSuite
val df1 = spark.range(10).withColumn("a", 'id)
val df2 = spark.range(10).withColumn("b", 'id)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
+      val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
+        .groupBy('a).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
@@ -238,7 +239,8 @@ class AdaptiveQueryExecSuite
}

withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
-      val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
+      val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
+        .groupBy('a).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
[Review comment — Contributor Author]
@cloud-fan - just FYI: the change in this unit test is needed because
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) is no longer true. This is an
inner join with an empty build side, so with the change in this PR the join operator is optimized
into an empty relation operator (the failure stack trace of the unit test without the change is
here).

Changed from inner join to left outer join to keep the unit test passing. I don't think switching
from inner join to left outer join compromises anything the original unit test covers. Let me know
if that's not the case, thanks.
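To make that reasoning concrete, a small illustration (not part of the PR; assumes spark.implicits._ is in scope): with an empty build side, an inner join produces no rows and can be eliminated outright, whereas a left outer join must still emit every left-side row null-extended, so a join operator has to survive in the plan.

```scala
val left  = Seq(1, 2, 3).toDF("id")
val empty = Seq.empty[Int].toDF("id")

left.join(empty, Seq("id"), "inner").count()      // 0 rows: safe to replace the join entirely
left.join(empty, Seq("id"), "left_outer").count() // 3 rows: the join node must remain
```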

@@ -1181,4 +1183,26 @@ class AdaptiveQueryExecSuite
checkNumLocalShuffleReaders(adaptivePlan)
}
}

test("SPARK-32649: Eliminate inner and semi join to empty relation") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
Seq(
// inner join (small table at right side)
"SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1",
// inner join (small table at left side)
"SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1",
// left semi join
"SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1"
).foreach(query => {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
val smj = findTopLevelSortMergeJoin(plan)
[Review comment — Member]
SortMergeJoin? I think this targets BHJ and SHJ?

[Review comment — Contributor Author]
@viirya - similar to the other test cases in this file, the input data's stats are very large, so
the planner picks SMJ by default. Per my comment above, SMJ/BHJ/SHJ will all turn into an empty
LocalRelation (SMJ/SHJ are first turned into BHJ).

        assert(smj.size == 1)
        val join = findTopLevelBaseJoin(adaptivePlan)
        assert(join.isEmpty)
        checkNumLocalShuffleReaders(adaptivePlan)
      })
    }
  }
}
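For anyone reproducing this outside the suite (runAdaptiveAndVerifyResult, findTopLevelSortMergeJoin, and findTopLevelBaseJoin are test-internal helpers), a hedged standalone check of the same property might look like the following; it assumes the testData/testData3 tables are registered as in the suite and simply inspects the final plan string.

```scala
// Sketch: run one of the queries above with AQE on and assert that no join
// operator survives in the final adaptive plan.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "80")

val df = spark.sql(
  "SELECT * FROM testData t1 JOIN testData3 t2 ON t1.key = t2.a WHERE t2.b = 1")
df.collect() // materialize so AQE can finalize the plan

val finalPlan = df.queryExecution.executedPlan.toString
assert(!finalPlan.contains("Join"), "expected the join to be eliminated")
```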