diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 482c3a3091f86..fa1a57a8ae3a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -396,7 +396,8 @@ case class BroadcastNestedLoopJoinExec( } override def supportCodegen: Boolean = (joinType, buildSide) match { - case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true + case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | + (LeftSemi | LeftAnti, BuildRight) => true case _ => false } @@ -413,6 +414,7 @@ case class BroadcastNestedLoopJoinExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { (joinType, buildSide) match { case (_: InnerLike, _) => codegenInner(ctx, input) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input) case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true) case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false) case _ => @@ -458,6 +460,49 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx) + val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) + val buildVars = genBuildSideVars(ctx, buildRow, broadcast) + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + val arrayIndex = ctx.freshName("arrayIndex") + val shouldOutputRow = ctx.freshName("shouldOutputRow") + val foundMatch = ctx.freshName("foundMatch") + val numOutput = metricTerm(ctx, "numOutputRows") + + if (buildRowArray.isEmpty) { + s""" + |UnsafeRow $buildRow = null; + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + } else { + s""" + |boolean $foundMatch = false; + |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { + | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; + | boolean $shouldOutputRow = false; + | $checkCondition { + | $shouldOutputRow = true; + | $foundMatch = true; + | } + | if ($arrayIndex == $buildRowArrayTerm.length - 1 && !$foundMatch) { + | $buildRow = null; + | $shouldOutputRow = true; + | } + | if ($shouldOutputRow) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin + } + } + private def codegenLeftExistence( ctx: CodegenContext, input: Seq[ExprCode], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 8246bca1893a9..b66308c4f880f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -211,6 +211,53 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } + test("Left/Right outer BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { + val df1 = spark.range(4).select($"id".as("k1")) + val df2 = spark.range(3).select($"id".as("k2")) + val df3 = spark.range(2).select($"id".as("k3")) + val df4 = spark.range(0).select($"id".as("k4")) + + Seq(true, false).foreach { codegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) { + // test left outer join + val leftOuterJoinDF = df1.join(df2, $"k1" > $"k2", "left_outer") + var hasJoinInCodegen = leftOuterJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(leftOuterJoinDF, + Seq(Row(0, null), Row(1, 0), Row(2, 0), Row(2, 1), Row(3, 0), Row(3, 1), Row(3, 2))) + + // test right outer join + val rightOuterJoinDF = df1.join(df2, $"k1" < $"k2", "right_outer") + hasJoinInCodegen = rightOuterJoinDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(rightOuterJoinDF, Seq(Row(null, 0), Row(0, 1), Row(0, 2), Row(1, 2))) + + // test a combination of left outer and right outer joins + val twoJoinsDF = df1.join(df2, $"k1" > $"k2" + 1, "right_outer") + .join(df3, $"k1" <= $"k3", "left_outer") + hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(twoJoinsDF, + Seq(Row(2, 0, null), Row(3, 0, null), Row(3, 1, null), Row(null, 2, null))) + + // test build side is empty + val buildSideIsEmptyDF = df3.join(df4, $"k3" > $"k4", "left_outer") + hasJoinInCodegen = buildSideIsEmptyDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_: BroadcastNestedLoopJoinExec) => true + }.size === 1 + assert(hasJoinInCodegen == codegenEnabled) + checkAnswer(buildSideIsEmptyDF, Seq(Row(0, null), Row(1, null))) + } + } + } + test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") { val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 150d40d0301fc..810eeea5b9a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -149,7 +149,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using BroadcastNestedLoopJoin build left") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build left") { _ => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)), @@ -158,7 +158,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using BroadcastNestedLoopJoin build right") { + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index dd99368e3a87b..50f980643d2d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -452,11 +452,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" - Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true)) - .foreach { case (query, enableWholeStage) => + Seq((leftQuery, 0L, false), (rightQuery, 0L, false), (leftQuery, 1L, true), + (rightQuery, 1L, true)).foreach { case (query, nodeId, enableWholeStage) => val df = spark.sql(query) testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastNestedLoopJoin", Map( + nodeId -> (("BroadcastNestedLoopJoin", Map( "number of output rows" -> 12L)))), enableWholeStage )