From 93947ab052af802aa861aa7d0af7bbaa9efd6f2c Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 29 May 2020 20:22:24 -0700 Subject: [PATCH 01/14] initial checkin --- .../joins/BroadcastHashJoinExec.scala | 17 +++- .../adaptive/AdaptiveQueryExecSuite.scala | 13 ++- .../execution/joins/BroadcastJoinSuite.scala | 92 ++++++++++++++++++- 3 files changed, 115 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 08128d8f69dab..bb460a56593ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{BooleanType, LongType} @@ -59,6 +59,21 @@ case class BroadcastHashJoinExec( } } + override def outputPartitioning: Partitioning = { + def buildKeys: Seq[Expression] = buildSide match { + case BuildLeft => leftKeys + case BuildRight => rightKeys + } + + streamedPlan.outputPartitioning match { + case h: HashPartitioning => + PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) + case c: PartitioningCollection if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => + PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) + case other => other + } + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a80fc410f5033..e4566937d6993 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -539,17 +539,20 @@ class AdaptiveQueryExecSuite } test("Avoid plan change if cost is greater") { + val testData2 = spark.table("testData2") + val newTestData2 = testData2.withColumn("c", testData2("a")) + newTestData2.createTempView("newTestData2") val origPlan = sql("SELECT * FROM testData " + - "join testData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.a = t3.a where t2.b = 1").queryExecution.executedPlan + "join newTestData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.c = t3.a where t2.b = 1").queryExecution.executedPlan withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData " + - "join testData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.a = t3.a where t2.b = 1") + 
"join newTestData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.c = t3.a where t2.b = 1") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 2) val smj2 = findTopLevelSortMergeJoin(adaptivePlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1be9308c06d8c..3029326d71500 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -23,10 +23,11 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.catalyst.plans.logical.BROADCAST +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec -import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -414,6 +415,95 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.") } } + + test("broadcast join where streamed side's output partitioning is PartitioningCollection") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") + val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2") + val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3") + val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4") + + // join1 is a sort merge join (shuffle on the both sides). + val join1 = t1.join(t2, t1("i1") === t2("i2")) + val plan1 = join1.queryExecution.executedPlan + assert(plan1.collect { case s: SortMergeJoinExec => s }.size == 1) + assert(plan1.collect { case e: ShuffleExchangeExec => e }.size == 2) + + // join2 is a broadcast join where t3 is broadcasted. Note that output partitioning on the + // streamed side (join1) is PartitioningCollection (sort merge join) + val join2 = join1.join(t3, join1("i1") === t3("i3")) + val plan2 = join2.queryExecution.executedPlan + assert(plan2.collect { case s: SortMergeJoinExec => s }.size == 1) + assert(plan2.collect { case e: ShuffleExchangeExec => e }.size == 2) + val broadcastJoins = plan2.collect { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.size == 1) + broadcastJoins(0).outputPartitioning match { + case p: PartitioningCollection + if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) => + // two partitionings from sort merge join and one from build side. + assert(p.partitionings.size == 3) + case _ => fail() + } + + // Join on the column from the broadcasted side (i3) and make sure output partitioning + // is maintained by checking no shuffle exchange is introduced. Note that one extra + // ShuffleExchangeExec is from t4, not from join2. 
+ val join3 = join2.join(t4, join2("i3") === t4("i4")) + val plan3 = join3.queryExecution.executedPlan + assert(plan3.collect { case s: SortMergeJoinExec => s }.size == 2) + assert(plan3.collect { case b: BroadcastHashJoinExec => b }.size == 1) + assert(plan3.collect { case e: ShuffleExchangeExec => e }.size == 3) + + // Validate the data with boradcast join off. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = join2.join(t4, join2("i3") === t4("i4")) + QueryTest.sameRows(join3.collect().toSeq, df.collect().toSeq) + } + } + } + + test("broadcast join where streamed side's output partitioning is HashPartitioning") { + withTable("t1", "t3") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") + val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2") + val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") + df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1") + df3.write.format("parquet").bucketBy(8, "i3").saveAsTable("t3") + val t1 = spark.table("t1") + val t3 = spark.table("t3") + + // join2 is a broadcast join where df2 is broadcasted. Note that output partitioning on the + // streamed side (t1) is HashPartitioning (bucketed files). + val join1 = t1.join(df2, t1("i1") === df2("i2")) + val plan1 = join1.queryExecution.executedPlan + assert(plan1.collect { case e: ShuffleExchangeExec => e }.isEmpty) + val broadcastJoins = plan1.collect { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.size == 1) + broadcastJoins(0).outputPartitioning match { + case p: PartitioningCollection + if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) => + // one partitioning from streamed side and one from build side. + assert(p.partitionings.size == 2) + case _ => fail() + } + + // Join on the column from the broadcasted side (i2) and make sure output partitioning + // is maintained by checking no shuffle exchange is introduced. + val join2 = join1.join(t3, join1("i2") === t3("i3")) + val plan2 = join2.queryExecution.executedPlan + assert(plan2.collect { case s: SortMergeJoinExec => s }.size == 1) + assert(plan2.collect { case b: BroadcastHashJoinExec => b }.size == 1) + assert(plan2.collect { case e: ShuffleExchangeExec => e }.isEmpty) + + // Validate the data with boradcast join off. 
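+      // (Setting the threshold to -1 disables broadcast joins, so the same query is re-planned
+      // with sort merge joins and serves as a reference result.)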
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = join1.join(t3, join1("i2") === t3("i3")) + QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) + } + } + } + } } class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite From 985834bc612ec41c36360b6ca02d1e32a364fbd3 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 29 May 2020 21:03:49 -0700 Subject: [PATCH 02/14] update comment --- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a1ed7c9463d39..72177d1ba3061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -474,7 +474,7 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils val t1 = spark.table("t1") val t3 = spark.table("t3") - // join2 is a broadcast join where df2 is broadcasted. Note that output partitioning on the + // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the // streamed side (t1) is HashPartitioning (bucketed files). val join1 = t1.join(df2, t1("i1") === df2("i2")) val plan1 = join1.queryExecution.executedPlan From 683a70528e4cfdf79e4ef9596e728c768c86afa2 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Sat, 30 May 2020 22:12:17 -0700 Subject: [PATCH 03/14] fix BroadcastJoinSuite tests, address comment on join type. --- .../joins/BroadcastHashJoinExec.scala | 17 +++++++----- .../adaptive/AdaptiveQueryExecSuite.scala | 15 +++++------ .../execution/joins/BroadcastJoinSuite.scala | 26 +++++++++---------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 7079f012f324c..90ef7547f78a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -66,12 +66,17 @@ case class BroadcastHashJoinExec( case BuildRight => rightKeys } - streamedPlan.outputPartitioning match { - case h: HashPartitioning => - PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) - case c: PartitioningCollection if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => - PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) - case other => other + joinType match { + case _: InnerLike => + streamedPlan.outputPartitioning match { + case h: HashPartitioning => + PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) + case c: PartitioningCollection + if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => + PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) + case other => other + } + case _ => streamedPlan.outputPartitioning } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 2a7e94b9f10e4..1449f33304946 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -540,24 +540,21 @@ class AdaptiveQueryExecSuite } test("Avoid plan change if cost is greater") { - val testData2 = spark.table("testData2") - val newTestData2 = testData2.withColumn("c", testData2("a")) - newTestData2.createTempView("newTestData2") val origPlan = sql("SELECT * FROM testData " + - "join newTestData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.c = t3.a where t2.b = 1").queryExecution.executedPlan + "join testData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.a = t3.a where t2.b = 1").queryExecution.executedPlan withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData " + - "join newTestData2 t2 ON key = t2.a " + - "join testData2 t3 on t2.c = t3.a where t2.b = 1") + "join testData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.a = t3.a where t2.b = 1") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 2) val smj2 = findTopLevelSortMergeJoin(adaptivePlan) - assert(smj2.size == 2, origPlan.toString) + assert(smj2.size == 1, origPlan.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 72177d1ba3061..27cf245a97308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -427,16 +427,16 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils // join1 is a sort merge join (shuffle on the both sides). val join1 = t1.join(t2, t1("i1") === t2("i2")) val plan1 = join1.queryExecution.executedPlan - assert(plan1.collect { case s: SortMergeJoinExec => s }.size == 1) - assert(plan1.collect { case e: ShuffleExchangeExec => e }.size == 2) + assert(collect(plan1) { case s: SortMergeJoinExec => s }.size == 1) + assert(collect(plan1) { case e: ShuffleExchangeExec => e }.size == 2) // join2 is a broadcast join where t3 is broadcasted. Note that output partitioning on the // streamed side (join1) is PartitioningCollection (sort merge join) val join2 = join1.join(t3, join1("i1") === t3("i3")) val plan2 = join2.queryExecution.executedPlan - assert(plan2.collect { case s: SortMergeJoinExec => s }.size == 1) - assert(plan2.collect { case e: ShuffleExchangeExec => e }.size == 2) - val broadcastJoins = plan2.collect { case b: BroadcastHashJoinExec => b } + assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) + assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2) + val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b } assert(broadcastJoins.size == 1) broadcastJoins(0).outputPartitioning match { case p: PartitioningCollection @@ -451,9 +451,9 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils // ShuffleExchangeExec is from t4, not from join2. 
val join3 = join2.join(t4, join2("i3") === t4("i4")) val plan3 = join3.queryExecution.executedPlan - assert(plan3.collect { case s: SortMergeJoinExec => s }.size == 2) - assert(plan3.collect { case b: BroadcastHashJoinExec => b }.size == 1) - assert(plan3.collect { case e: ShuffleExchangeExec => e }.size == 3) + assert(collect(plan3) { case s: SortMergeJoinExec => s }.size == 2) + assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1) + assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3) // Validate the data with boradcast join off. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { @@ -478,8 +478,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils // streamed side (t1) is HashPartitioning (bucketed files). val join1 = t1.join(df2, t1("i1") === df2("i2")) val plan1 = join1.queryExecution.executedPlan - assert(plan1.collect { case e: ShuffleExchangeExec => e }.isEmpty) - val broadcastJoins = plan1.collect { case b: BroadcastHashJoinExec => b } + assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) + val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } assert(broadcastJoins.size == 1) broadcastJoins(0).outputPartitioning match { case p: PartitioningCollection @@ -493,9 +493,9 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils // is maintained by checking no shuffle exchange is introduced. val join2 = join1.join(t3, join1("i2") === t3("i3")) val plan2 = join2.queryExecution.executedPlan - assert(plan2.collect { case s: SortMergeJoinExec => s }.size == 1) - assert(plan2.collect { case b: BroadcastHashJoinExec => b }.size == 1) - assert(plan2.collect { case e: ShuffleExchangeExec => e }.isEmpty) + assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) + assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1) + assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty) // Validate the data with boradcast join off. 
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { From 488e051e1a7c21c57b646d9f68df8c48e4717126 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 25 Jun 2020 20:44:52 -0700 Subject: [PATCH 04/14] Add more checks --- .../joins/BroadcastHashJoinExec.scala | 82 +++++++++++++++++-- .../execution/joins/BroadcastJoinSuite.scala | 14 ++-- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 90ef7547f78a6..5646efab59b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -61,25 +63,91 @@ case class BroadcastHashJoinExec( } override def outputPartitioning: Partitioning = { - def buildKeys: Seq[Expression] = buildSide match { - case BuildLeft => leftKeys - case BuildRight => rightKeys + val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) } joinType match { case _: InnerLike => streamedPlan.outputPartitioning match { case h: HashPartitioning => - PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) - case c: PartitioningCollection - if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => - PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) + getBuildSidePartitioning(h, streamedKeys, buildKeys) match { + case Some(p) => PartitioningCollection(Seq(h, p)) + case None => h + } + case c: PartitioningCollection => + c.partitionings.foreach { + case h: HashPartitioning => + getBuildSidePartitioning(h, streamedKeys, buildKeys) match { + case Some(p) => return PartitioningCollection(c.partitionings :+ p) + case None => () + } + case _ => () + } + c case other => other } case _ => streamedPlan.outputPartitioning } } + /** + * Returns a partitioning for the build side if the following conditions are met: + * - The streamed side's output partitioning expressions consist of all the keys + * from the streamed side, we can add a partitioning for the build side. + * - There is a one-to-one mapping from streamed keys to build keys. + * + * The build side partitioning will have expressions in the same order as the expressions + * in the streamed side partitioning. For example, for the following setup: + * - streamed partitioning expressions: Seq(s1, s2) + * - streamed keys: Seq(c1, c2) + * - build keys: Seq(b1, b2) + * the expressions in the build side partitioning will be Seq(b1, b2), not Seq(b2, b1). 
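+   * If either condition is not met, None is returned and the caller keeps the streamed
+   * side's own output partitioning unchanged.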
+ */ + private def getBuildSidePartitioning( + streamedPartitioning: HashPartitioning, + streamedKeys: Seq[Expression], + buildKeys: Seq[Expression]): Option[HashPartitioning] = { + if (!satisfiesPartitioning(streamedKeys, streamedPartitioning)) { + return None + } + + val streamedKeyToBuildKeyMap = mutable.Map.empty[Expression, Expression] + streamedKeys.zip(buildKeys).foreach { + case (streamedKey, buildKey) => + val inserted = streamedKeyToBuildKeyMap.getOrElseUpdate( + streamedKey.canonicalized, + buildKey) + + if (!inserted.semanticEquals(buildKey)) { + // One-to-many mapping from streamed keys to build keys found. + return None + } + } + + // Ensure the one-to-one mapping from streamed keys to build keys. + if (streamedKeyToBuildKeyMap.size != streamedKeyToBuildKeyMap.values.toSet.size) { + return None + } + + // The final expressions are built by mapping stream partitioning expressions -> + // streamed keys -> build keys. + val buildPartitioningExpressions = streamedPartitioning.expressions.map { e => + streamedKeyToBuildKeyMap(e.canonicalized) + } + + Some(HashPartitioning(buildPartitioningExpressions, streamedPartitioning.numPartitions)) + } + + // Returns true if `keys` consist of all the expressions in `partitioning`. + private def satisfiesPartitioning( + keys: Seq[Expression], + partitioning: HashPartitioning): Boolean = { + partitioning.expressions.length == keys.length && + partitioning.expressions.forall(e => keys.exists(_.semanticEquals(e))) + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 27cf245a97308..ea20adbe8ebc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -469,14 +469,14 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2") val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") - df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1") - df3.write.format("parquet").bucketBy(8, "i3").saveAsTable("t3") + df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1") + df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3") val t1 = spark.table("t1") val t3 = spark.table("t3") // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the // streamed side (t1) is HashPartitioning (bucketed files). - val join1 = t1.join(df2, t1("i1") === df2("i2")) + val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2")) val plan1 = join1.queryExecution.executedPlan assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } @@ -489,17 +489,17 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils case _ => fail() } - // Join on the column from the broadcasted side (i2) and make sure output partitioning + // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning // is maintained by checking no shuffle exchange is introduced. 
- val join2 = join1.join(t3, join1("i2") === t3("i3")) + val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) val plan2 = join2.queryExecution.executedPlan assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1) assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty) - // Validate the data with boradcast join off. + // Validate the data with broadcast join off. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = join1.join(t3, join1("i2") === t3("i3")) + val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) } } From cac382976a33d7fb8b43790b2a554ab2ed0daf63 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 29 Jun 2020 22:12:52 -0700 Subject: [PATCH 05/14] Revert back to previous impl. --- .../joins/BroadcastHashJoinExec.scala | 80 ++----------------- 1 file changed, 7 insertions(+), 73 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 5646efab59b40..685321bf8e62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -63,91 +63,25 @@ case class BroadcastHashJoinExec( } override def outputPartitioning: Partitioning = { - val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + def buildKeys: Seq[Expression] = buildSide match { + case BuildLeft => leftKeys + case BuildRight => rightKeys } joinType match { case _: InnerLike => streamedPlan.outputPartitioning match { case h: HashPartitioning => - getBuildSidePartitioning(h, streamedKeys, buildKeys) match { - case Some(p) => PartitioningCollection(Seq(h, p)) - case None => h - } - case c: PartitioningCollection => - c.partitionings.foreach { - case h: HashPartitioning => - getBuildSidePartitioning(h, streamedKeys, buildKeys) match { - case Some(p) => return PartitioningCollection(c.partitionings :+ p) - case None => () - } - case _ => () - } - c + PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) + case c: PartitioningCollection + if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => + PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) case other => other } case _ => streamedPlan.outputPartitioning } } - /** - * Returns a partitioning for the build side if the following conditions are met: - * - The streamed side's output partitioning expressions consist of all the keys - * from the streamed side, we can add a partitioning for the build side. - * - There is a one-to-one mapping from streamed keys to build keys. - * - * The build side partitioning will have expressions in the same order as the expressions - * in the streamed side partitioning. For example, for the following setup: - * - streamed partitioning expressions: Seq(s1, s2) - * - streamed keys: Seq(c1, c2) - * - build keys: Seq(b1, b2) - * the expressions in the build side partitioning will be Seq(b1, b2), not Seq(b2, b1). 
- */ - private def getBuildSidePartitioning( - streamedPartitioning: HashPartitioning, - streamedKeys: Seq[Expression], - buildKeys: Seq[Expression]): Option[HashPartitioning] = { - if (!satisfiesPartitioning(streamedKeys, streamedPartitioning)) { - return None - } - - val streamedKeyToBuildKeyMap = mutable.Map.empty[Expression, Expression] - streamedKeys.zip(buildKeys).foreach { - case (streamedKey, buildKey) => - val inserted = streamedKeyToBuildKeyMap.getOrElseUpdate( - streamedKey.canonicalized, - buildKey) - - if (!inserted.semanticEquals(buildKey)) { - // One-to-many mapping from streamed keys to build keys found. - return None - } - } - - // Ensure the one-to-one mapping from streamed keys to build keys. - if (streamedKeyToBuildKeyMap.size != streamedKeyToBuildKeyMap.values.toSet.size) { - return None - } - - // The final expressions are built by mapping stream partitioning expressions -> - // streamed keys -> build keys. - val buildPartitioningExpressions = streamedPartitioning.expressions.map { e => - streamedKeyToBuildKeyMap(e.canonicalized) - } - - Some(HashPartitioning(buildPartitioningExpressions, streamedPartitioning.numPartitions)) - } - - // Returns true if `keys` consist of all the expressions in `partitioning`. - private def satisfiesPartitioning( - keys: Seq[Expression], - partitioning: HashPartitioning): Boolean = { - partitioning.expressions.length == keys.length && - partitioning.expressions.forall(e => keys.exists(_.semanticEquals(e))) - } - protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") From febc4025564b9f36ba397a3921b3a216ce6221bd Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 29 Jun 2020 22:13:48 -0700 Subject: [PATCH 06/14] remove import --- .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 685321bf8e62a..90ef7547f78a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.joins -import scala.collection.mutable - import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD From 1ea931b5c585144512d354d0a1cbebdee95c66ed Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 3 Jul 2020 19:29:10 -0700 Subject: [PATCH 07/14] Address PR comments --- .../joins/BroadcastHashJoinExec.scala | 71 +++++-- .../spark/sql/execution/joins/HashJoin.scala | 21 +- .../joins/ShuffledHashJoinExec.scala | 3 +- .../execution/joins/BroadcastJoinSuite.scala | 188 +++++++++++++----- 4 files changed, 217 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 90ef7547f78a6..c39730772b490 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable + import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import 
org.apache.spark.rdd.RDD @@ -51,7 +53,7 @@ case class BroadcastHashJoinExec( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode(buildKeys) + val mode = HashedRelationBroadcastMode(buildBoundKeys) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -61,25 +63,66 @@ case class BroadcastHashJoinExec( } override def outputPartitioning: Partitioning = { - def buildKeys: Seq[Expression] = buildSide match { - case BuildLeft => leftKeys - case BuildRight => rightKeys - } - joinType match { case _: InnerLike => streamedPlan.outputPartitioning match { - case h: HashPartitioning => - PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions))) - case c: PartitioningCollection - if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) => - PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions)) + case h: HashPartitioning => PartitioningCollection(expandOutputPartitioning(h)) + case c: PartitioningCollection => + def expand(partitioning: PartitioningCollection): Partitioning = { + PartitioningCollection(partitioning.partitionings.flatMap { + case h: HashPartitioning => expandOutputPartitioning(h) + case c: PartitioningCollection => Seq(expand(c)) + case other => Seq(other) + }) + } + expand(c) case other => other } case _ => streamedPlan.outputPartitioning } } + // An one-to-many mapping from a streamed key to build keys. + private lazy val streamedKeyToBuildKeyMapping = { + val mapping = mutable.Map.empty[Expression, Seq[Expression]] + streamedKeys.zip(buildKeys).foreach { + case (streamedKey, buildKey) => + val key = streamedKey.canonicalized + mapping.get(key) match { + case Some(v) => mapping.put(key, v :+ buildKey) + case None => mapping.put(key, Seq(buildKey)) + } + } + mapping.toMap + } + + // Expands the given partitioning by substituting streamed keys with build keys. + // For example, if the expressions for the given partitioning are Seq("a", "b", "c") + // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), + // the expanded partitioning will have the following expressions: + // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). 
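+  // Each expression that maps to build keys multiplies the number of generated combinations
+  // by (1 + the number of build keys it maps to), so the expansion can grow quickly when
+  // many join keys appear in the partitioning.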
+ private def expandOutputPartitioning(partitioning: HashPartitioning): Seq[HashPartitioning] = { + def generateExprCombinations( + current: Seq[Expression], + accumulated: Seq[Expression], + all: mutable.ListBuffer[Seq[Expression]]): Unit = { + if (current.isEmpty) { + all += accumulated + } else { + generateExprCombinations(current.tail, accumulated :+ current.head, all) + val mapped = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) + if (mapped.isDefined) { + mapped.get.foreach(m => + generateExprCombinations(current.tail, accumulated :+ m, all)) + } + } + } + + val all = mutable.ListBuffer[Seq[Expression]]() + generateExprCombinations(partitioning.expressions, Nil, all) + all.map(HashPartitioning(_, partitioning.numPartitions)) + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -155,13 +198,13 @@ case class BroadcastHashJoinExec( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) { // generate the join key as Long - val ev = streamedKeys.head.genCode(ctx) + val ev = streamedBoundKeys.head.genCode(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index c7c3e1672f034..4626a41f19ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -62,21 +62,30 @@ trait HashJoin extends BaseJoinExec { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output) - val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output) buildSide match { - case BuildLeft => (lkeys, rkeys) - case BuildRight => (rkeys, lkeys) + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) } } + private lazy val (buildOutput, streamedOutput) = { + buildSide match { + case BuildLeft => (left.output, right.output) + case BuildRight => (right.output, left.output) + } + } + + protected lazy val buildBoundKeys = + bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput) + protected lazy val streamedBoundKeys = + bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput) protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(buildKeys) + UnsafeProjection.create(buildBoundKeys) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(streamedKeys) + UnsafeProjection.create(streamedBoundKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 2b7cd65e7d96f..1120850fdddaf 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -55,7 +55,8 @@ case class ShuffledHashJoinExec( val buildTime = longMetric("buildTime") val start = System.nanoTime() val context = TaskContext.get() - val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + val relation = HashedRelation( + iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager()) buildTime += NANOSECONDS.toMillis(System.nanoTime() - start) buildDataSize += relation.estimatedSize // This relation is usually used until the end of task. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index ea20adbe8ebc9..9fbb83043d59d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -21,11 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, Cast, Expression, Literal, ShiftLeft} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} -import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{DummySparkPlan, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} @@ -417,6 +418,57 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils } } + test("broadcast join where streamed side's output partitioning is HashPartitioning") { + withTable("t1", "t3") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") + val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2") + val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") + df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1") + df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3") + val t1 = spark.table("t1") + val t3 = spark.table("t3") + + // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the + // streamed side (t1) is HashPartitioning (bucketed files). 
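+        // (A table bucketed by (i1, j1) into 8 buckets is scanned with a HashPartitioning over
+        // those columns, which is what the broadcast join below expands with the build-side keys.)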
+ val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2")) + val plan1 = join1.queryExecution.executedPlan + assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) + val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.size == 1) + broadcastJoins(0).outputPartitioning match { + case p: PartitioningCollection => + assert(p.partitionings.size == 4) + // Verify all the combinations of output partitioning. + Seq(Seq(t1("i1"), t1("j1")), + Seq(t1("i1"), df2("j2")), + Seq(df2("i2"), t1("j1")), + Seq(df2("i2"), df2("j2"))).foreach { expected => + val expectedExpressions = expected.map(_.expr) + assert(p.partitionings.exists { + case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) + }) + } + case _ => fail() + } + + // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning + // is maintained by checking no shuffle exchange is introduced. + val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) + val plan2 = join2.queryExecution.executedPlan + assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) + assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1) + assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty) + + // Validate the data with broadcast join off. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) + QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) + } + } + } + } + test("broadcast join where streamed side's output partitioning is PartitioningCollection") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") @@ -439,10 +491,15 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b } assert(broadcastJoins.size == 1) broadcastJoins(0).outputPartitioning match { - case p: PartitioningCollection - if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) => - // two partitionings from sort merge join and one from build side. + case p: PartitioningCollection => assert(p.partitionings.size == 3) + // Verify all the combinations of output partitioning. + Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected => + val expectedExpressions = expected.map(_.expr) + assert(p.partitionings.exists { + case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) + }) + } case _ => fail() } @@ -463,47 +520,88 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils } } - test("broadcast join where streamed side's output partitioning is HashPartitioning") { - withTable("t1", "t3") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { - val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") - val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2") - val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") - df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1") - df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3") - val t1 = spark.table("t1") - val t3 = spark.table("t3") - - // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the - // streamed side (t1) is HashPartitioning (bucketed files). 
- val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2")) - val plan1 = join1.queryExecution.executedPlan - assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) - val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } - assert(broadcastJoins.size == 1) - broadcastJoins(0).outputPartitioning match { - case p: PartitioningCollection - if p.partitionings.forall(_.isInstanceOf[HashPartitioning]) => - // one partitioning from streamed side and one from build side. - assert(p.partitionings.size == 2) - case _ => fail() - } - - // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning - // is maintained by checking no shuffle exchange is introduced. - val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) - val plan2 = join2.queryExecution.executedPlan - assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1) - assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1) - assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty) + test("BroadcastHashJoinExec output partitioning scenarios for inner join") { + val l1 = AttributeReference("l1", LongType)() + val l2 = AttributeReference("l2", LongType)() + val l3 = AttributeReference("l3", LongType)() + val r1 = AttributeReference("r1", LongType)() + val r2 = AttributeReference("r2", LongType)() + val r3 = AttributeReference("r3", LongType)() + + // Streamed side has a HashPartitioning. + var bhj = BroadcastHashJoinExec( + leftKeys = Seq(l2, l3), + rightKeys = Seq(r1, r2), + Inner, + BuildRight, + None, + left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2, l3), 1)), + right = DummySparkPlan()) + var expected = PartitioningCollection(Seq( + HashPartitioning(Seq(l1, l2, l3), 1), + HashPartitioning(Seq(l1, l2, r2), 1), + HashPartitioning(Seq(l1, r1, l3), 1), + HashPartitioning(Seq(l1, r1, r2), 1))) + assert(bhj.outputPartitioning === expected) + + // Streamed side has a PartitioningCollection. + bhj = BroadcastHashJoinExec( + leftKeys = Seq(l1, l2, l3), + rightKeys = Seq(r1, r2, r3), + Inner, + BuildRight, + None, + left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq( + HashPartitioning(Seq(l1, l2), 1), HashPartitioning(Seq(l3), 1)))), + right = DummySparkPlan()) + expected = PartitioningCollection(Seq( + HashPartitioning(Seq(l1, l2), 1), + HashPartitioning(Seq(l1, r2), 1), + HashPartitioning(Seq(r1, l2), 1), + HashPartitioning(Seq(r1, r2), 1), + HashPartitioning(Seq(l3), 1), + HashPartitioning(Seq(r3), 1))) + assert(bhj.outputPartitioning === expected) + + // Streamed side has a nested PartitioningCollection. 
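+    // (Nested collections are expanded recursively, so the outer structure is preserved while
+    // each inner HashPartitioning picks up its build-side variants.)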
+ bhj = BroadcastHashJoinExec( + leftKeys = Seq(l1, l2, l3), + rightKeys = Seq(r1, r2, r3), + Inner, + BuildRight, + None, + left = DummySparkPlan(outputPartitioning = PartitioningCollection(Seq( + PartitioningCollection(Seq(HashPartitioning(Seq(l1), 1), HashPartitioning(Seq(l2), 1))), + HashPartitioning(Seq(l3), 1)))), + right = DummySparkPlan()) + expected = PartitioningCollection(Seq( + PartitioningCollection(Seq( + HashPartitioning(Seq(l1), 1), + HashPartitioning(Seq(r1), 1), + HashPartitioning(Seq(l2), 1), + HashPartitioning(Seq(r2), 1))), + HashPartitioning(Seq(l3), 1), + HashPartitioning(Seq(r3), 1))) + assert(bhj.outputPartitioning === expected) + + // One-to-mapping case ("l1" = "r1" AND "l1" = "r2") + bhj = BroadcastHashJoinExec( + leftKeys = Seq(l1, l1), + rightKeys = Seq(r1, r2), + Inner, + BuildRight, + None, + left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)), + right = DummySparkPlan()) + expected = PartitioningCollection(Seq( + HashPartitioning(Seq(l1, l2), 1), + HashPartitioning(Seq(r1, l2), 1), + HashPartitioning(Seq(r2, l2), 1))) + assert(bhj.outputPartitioning === expected) + } - // Validate the data with broadcast join off. - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) - QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) - } - } - } + private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = { + l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) } } } From c5f48032907dc5b550d410984254015e4e3ae235 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 3 Jul 2020 19:36:07 -0700 Subject: [PATCH 08/14] formatting --- .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index c39730772b490..fb0c04a530435 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -112,8 +112,7 @@ case class BroadcastHashJoinExec( generateExprCombinations(current.tail, accumulated :+ current.head, all) val mapped = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) if (mapped.isDefined) { - mapped.get.foreach(m => - generateExprCombinations(current.tail, accumulated :+ m, all)) + mapped.get.foreach(m => generateExprCombinations(current.tail, accumulated :+ m, all)) } } } From 126ee53705fc72f2be8c93a086ae8cf814549184 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Tue, 7 Jul 2020 20:20:56 -0700 Subject: [PATCH 09/14] address PR comments --- .../joins/BroadcastHashJoinExec.scala | 49 ++++++++++--------- .../execution/joins/BroadcastJoinSuite.scala | 46 ++++++++--------- 2 files changed, 46 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fb0c04a530435..3b69d193d6bdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -62,20 +62,12 @@ case class BroadcastHashJoinExec( } } - override def 
outputPartitioning: Partitioning = { + override lazy val outputPartitioning: Partitioning = { joinType match { case _: InnerLike => streamedPlan.outputPartitioning match { - case h: HashPartitioning => PartitioningCollection(expandOutputPartitioning(h)) - case c: PartitioningCollection => - def expand(partitioning: PartitioningCollection): Partitioning = { - PartitioningCollection(partitioning.partitionings.flatMap { - case h: HashPartitioning => expandOutputPartitioning(h) - case c: PartitioningCollection => Seq(expand(c)) - case other => Seq(other) - }) - } - expand(c) + case h: HashPartitioning => expandOutputPartitioning(h) + case c: PartitioningCollection => expandOutputPartitioning(c) case other => other } case _ => streamedPlan.outputPartitioning @@ -96,30 +88,39 @@ case class BroadcastHashJoinExec( mapping.toMap } - // Expands the given partitioning by substituting streamed keys with build keys. + // Expands the given partitioning collection recursively. + private def expandOutputPartitioning( + partitioning: PartitioningCollection): PartitioningCollection = { + PartitioningCollection(partitioning.partitionings.flatMap { + case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) + case other => Seq(other) + }) + } + + // Expands the given hash partitioning by substituting streamed keys with build keys. // For example, if the expressions for the given partitioning are Seq("a", "b", "c") // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), // the expanded partitioning will have the following expressions: // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). - private def expandOutputPartitioning(partitioning: HashPartitioning): Seq[HashPartitioning] = { + // The expanded expressions are returned as PartitioningCollection. 
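+  // For instance, with expressions Seq("a", "b") and "b" mapped to Seq("y"),
+  // generateExprCombinations(Seq("a", "b"), Nil) yields Seq(Seq("a", "b"), Seq("a", "y")):
+  // each expression contributes itself and then every build key it maps to.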
+ private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { def generateExprCombinations( current: Seq[Expression], - accumulated: Seq[Expression], - all: mutable.ListBuffer[Seq[Expression]]): Unit = { + accumulated: Seq[Expression]): Seq[Seq[Expression]] = { if (current.isEmpty) { - all += accumulated + Seq(accumulated) } else { - generateExprCombinations(current.tail, accumulated :+ current.head, all) - val mapped = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) - if (mapped.isDefined) { - mapped.get.foreach(m => generateExprCombinations(current.tail, accumulated :+ m, all)) - } + val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) + generateExprCombinations(current.tail, accumulated :+ current.head) ++ + buildKeys.map { _.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)) + }.getOrElse(Nil) } } - val all = mutable.ListBuffer[Seq[Expression]]() - generateExprCombinations(partitioning.expressions, Nil, all) - all.map(HashPartitioning(_, partitioning.numPartitions)) + PartitioningCollection( + generateExprCombinations(partitioning.expressions, Nil).map( + HashPartitioning(_, partitioning.numPartitions))) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 9fbb83043d59d..43a173104d892 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -436,20 +436,18 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty) val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b } assert(broadcastJoins.size == 1) - broadcastJoins(0).outputPartitioning match { - case p: PartitioningCollection => - assert(p.partitionings.size == 4) - // Verify all the combinations of output partitioning. - Seq(Seq(t1("i1"), t1("j1")), - Seq(t1("i1"), df2("j2")), - Seq(df2("i2"), t1("j1")), - Seq(df2("i2"), df2("j2"))).foreach { expected => - val expectedExpressions = expected.map(_.expr) - assert(p.partitionings.exists { - case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) - }) - } - case _ => fail() + assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection]) + val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection] + assert(p.partitionings.size == 4) + // Verify all the combinations of output partitioning. 
+ Seq(Seq(t1("i1"), t1("j1")), + Seq(t1("i1"), df2("j2")), + Seq(df2("i2"), t1("j1")), + Seq(df2("i2"), df2("j2"))).foreach { expected => + val expectedExpressions = expected.map(_.expr) + assert(p.partitionings.exists { + case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) + }) } // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning @@ -490,17 +488,15 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(collect(plan2) { case e: ShuffleExchangeExec => e }.size == 2) val broadcastJoins = collect(plan2) { case b: BroadcastHashJoinExec => b } assert(broadcastJoins.size == 1) - broadcastJoins(0).outputPartitioning match { - case p: PartitioningCollection => - assert(p.partitionings.size == 3) - // Verify all the combinations of output partitioning. - Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected => - val expectedExpressions = expected.map(_.expr) - assert(p.partitionings.exists { - case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) - }) - } - case _ => fail() + assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection]) + val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection] + assert(p.partitionings.size == 3) + // Verify all the combinations of output partitioning. + Seq(Seq(t1("i1")), Seq(t2("i2")), Seq(t3("i3"))).foreach { expected => + val expectedExpressions = expected.map(_.expr) + assert(p.partitionings.exists { + case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions) + }) } // Join on the column from the broadcasted side (i3) and make sure output partitioning From afa5acac186c7340d086e199883dfa5238f82200 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Wed, 15 Jul 2020 21:38:06 -0700 Subject: [PATCH 10/14] Address PR comments --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++++ .../joins/BroadcastHashJoinExec.scala | 27 ++++++++++++++---- .../execution/joins/BroadcastJoinSuite.scala | 28 +++++++++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9be0497e46603..e37e47ea7ed40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2665,6 +2665,15 @@ object SQLConf { .checkValue(_ > 0, "The difference must be positive.") .createWithDefault(4) + val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT = + buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit") + .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " + + "This configuration is applicable only for inner joins.") + .version("3.1.0") + .intConf + .checkValue(_ > 0, "The value must be positive.") + .createWithDefault(8) + /** * Holds information about keys that have been deprecated. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 3b69d193d6bdc..f57bb0128c667 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, LongType} /** @@ -64,7 +65,7 @@ case class BroadcastHashJoinExec( override lazy val outputPartitioning: Partitioning = { joinType match { - case _: InnerLike => + case Inner => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) @@ -105,22 +106,36 @@ case class BroadcastHashJoinExec( // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { + val maxNumCombinations = sqlContext.conf.getConf( + SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) + var currentNumCombinations = 0 + def generateExprCombinations( current: Seq[Expression], accumulated: Seq[Expression]): Seq[Seq[Expression]] = { - if (current.isEmpty) { + if (currentNumCombinations > maxNumCombinations) { + Nil + } else if (current.isEmpty) { + currentNumCombinations += 1 Seq(accumulated) } else { val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) generateExprCombinations(current.tail, accumulated :+ current.head) ++ - buildKeys.map { _.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)) - }.getOrElse(Nil) + buildKeys.map { bKeys => + bKeys.flatMap { bKey => + if (currentNumCombinations < maxNumCombinations) { + generateExprCombinations(current.tail, accumulated :+ bKey) + } else { + Nil + } + } + }.getOrElse(Nil) } } PartitioningCollection( - generateExprCombinations(partitioning.expressions, Nil).map( - HashPartitioning(_, partitioning.numPartitions))) + generateExprCombinations(partitioning.expressions, Nil) + .map(HashPartitioning(_, partitioning.numPartitions))) } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 43a173104d892..6c85032183d33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -596,6 +596,34 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(bhj.outputPartitioning === expected) } + test("BroadcastHashJoinExec output partitioning size should be limited with a config") { + val l1 = AttributeReference("l1", LongType)() + val l2 = AttributeReference("l2", LongType)() + val r1 = AttributeReference("r1", LongType)() + val r2 = 
AttributeReference("r2", LongType)() + + val expected = Seq( + HashPartitioning(Seq(l1, l2), 1), + HashPartitioning(Seq(l1, r2), 1), + HashPartitioning(Seq(r1, l2), 1), + HashPartitioning(Seq(r1, r2), 1)) + + Seq(1, 2, 3, 4).foreach { limit => + withSQLConf( + SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> s"$limit") { + val bhj = BroadcastHashJoinExec( + leftKeys = Seq(l1, l2), + rightKeys = Seq(r1, r2), + Inner, + BuildRight, + None, + left = DummySparkPlan(outputPartitioning = HashPartitioning(Seq(l1, l2), 1)), + right = DummySparkPlan()) + assert(bhj.outputPartitioning === PartitioningCollection(expected.take(limit))) + } + } + } + private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = { l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) } } From 51187dc08d0712fbeb54c2c8537ab77aa22f6afd Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 16 Jul 2020 14:38:10 -0700 Subject: [PATCH 11/14] Address PR comments --- .../apache/spark/sql/internal/SQLConf.scala | 8 ++++++-- .../joins/BroadcastHashJoinExec.scala | 20 ++++++------------- .../spark/sql/execution/joins/HashJoin.scala | 6 +++--- .../adaptive/AdaptiveQueryExecSuite.scala | 5 +++-- .../execution/joins/BroadcastJoinSuite.scala | 6 +++--- 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e37e47ea7ed40..5175b2b5f7366 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2668,10 +2668,11 @@ object SQLConf { val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT = buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit") .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " + - "This configuration is applicable only for inner joins.") + "This configuration is applicable only for inner joins and can be set to '0' to disable " + + "this feature.") .version("3.1.0") .intConf - .checkValue(_ > 0, "The value must be positive.") + .checkValue(_ >= 0, "The value must be non-negative.") .createWithDefault(8) /** @@ -2984,6 +2985,9 @@ class SQLConf extends Serializable with Logging { LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) } + def broadcastHashJoinOutputPartitioningExpandLimit: Int = + getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index f57bb0128c667..857e760e840ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -65,7 +65,7 @@ case class BroadcastHashJoinExec( override lazy val outputPartitioning: Partitioning = { joinType match { - case Inner => + case Inner if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) @@ -106,30 +106,22 @@ case class BroadcastHashJoinExec( // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { - val maxNumCombinations = sqlContext.conf.getConf( - SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) + val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit var currentNumCombinations = 0 def generateExprCombinations( current: Seq[Expression], accumulated: Seq[Expression]): Seq[Seq[Expression]] = { - if (currentNumCombinations > maxNumCombinations) { + if (currentNumCombinations >= maxNumCombinations) { Nil } else if (current.isEmpty) { currentNumCombinations += 1 Seq(accumulated) } else { - val buildKeys = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) + val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) generateExprCombinations(current.tail, accumulated :+ current.head) ++ - buildKeys.map { bKeys => - bKeys.flatMap { bKey => - if (currentNumCombinations < maxNumCombinations) { - generateExprCombinations(current.tail, accumulated :+ bKey) - } else { - Nil - } - } - }.getOrElse(Nil) + buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))) + .getOrElse(Nil) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4626a41f19ea8..7c3c53b0fa54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -68,17 +68,17 @@ trait HashJoin extends BaseJoinExec { } } - private lazy val (buildOutput, streamedOutput) = { + @transient private lazy val (buildOutput, streamedOutput) = { buildSide match { case BuildLeft => (left.output, right.output) case BuildRight => (right.output, left.output) } } - protected lazy val buildBoundKeys = + @transient protected lazy val buildBoundKeys = bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput) - protected lazy val streamedBoundKeys = + @transient protected lazy val streamedBoundKeys = bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput) protected def buildSideKeyGenerator(): Projection = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 16f2834361ffd..511e0cf0b3817 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -557,7 +557,8 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData " + "join testData2 t2 ON key = t2.a " + @@ -565,7 +566,7 @@ class AdaptiveQueryExecSuite val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 2) val smj2 = findTopLevelSortMergeJoin(adaptivePlan) - assert(smj2.size == 1, origPlan.toString) + assert(smj2.size == 2, origPlan.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6c85032183d33..7ff945f5cbfb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -461,7 +461,7 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils // Validate the data with broadcast join off. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3")) - QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq) + checkAnswer(join2, df) } } } @@ -508,10 +508,10 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(collect(plan3) { case b: BroadcastHashJoinExec => b }.size == 1) assert(collect(plan3) { case e: ShuffleExchangeExec => e }.size == 3) - // Validate the data with boradcast join off. + // Validate the data with broadcast join off. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df = join2.join(t4, join2("i3") === t4("i4")) - QueryTest.sameRows(join3.collect().toSeq, df.collect().toSeq) + checkAnswer(join3, df) } } } From 80df4dc7318af093d76272bb8c4c17395f7ec398 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 17 Jul 2020 10:34:44 -0700 Subject: [PATCH 12/14] Address PR comment --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5175b2b5f7366..afa19a81ef403 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2667,6 +2667,7 @@ object SQLConf { val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT = buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit") + .internal() .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " + "This configuration is applicable only for inner joins and can be set to '0' to disable " + "this feature.") From ba19acbfde7d80825e424476e71a6f1daa36c266 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 17 Jul 2020 10:49:05 -0700 Subject: [PATCH 13/14] clean up header + change it to support inner-like joins. 
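
Expanding the output partitioning is valid for any inner-like join (Inner or Cross), since every output row satisfies the equality between a streamed-side key and the build-side key it maps to, so hash-partitioning on either expression describes the same data distribution. A sketch of the resulting match in BroadcastHashJoinExec (the fall-through arm for non-inner joins is assumed to keep the pre-existing behavior of reporting the streamed plan's partitioning):

    override lazy val outputPartitioning: Partitioning = {
      joinType match {
        // Only expand when the feature is enabled (limit > 0).
        case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
          streamedPlan.outputPartitioning match {
            case h: HashPartitioning => expandOutputPartitioning(h)
            case c: PartitioningCollection => expandOutputPartitioning(c)
            case other => other
          }
        // Assumed default: other join types keep the streamed side's partitioning.
        case _ => streamedPlan.outputPartitioning
      }
    }
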
--- .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 857e760e840ba..71faad9829a42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, LongType} /** @@ -65,7 +64,7 @@ case class BroadcastHashJoinExec( override lazy val outputPartitioning: Partitioning = { joinType match { - case Inner if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) From 9caeecddaa07ef825b73835a3666502df468f881 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Fri, 17 Jul 2020 13:59:21 -0700 Subject: [PATCH 14/14] Address PR comments --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index afa19a81ef403..c1aa3932c6f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2669,8 +2669,8 @@ object SQLConf { buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit") .internal() .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " + - "This configuration is applicable only for inner joins and can be set to '0' to disable " + - "this feature.") + "This configuration is applicable only for BroadcastHashJoin inner joins and can be " + + "set to '0' to disable this feature.") .version("3.1.0") .intConf .checkValue(_ >= 0, "The value must be non-negative.")
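
For reference, a minimal way to exercise the new limit; withSQLConf, broadcast and the SQLConf key are the test helper, function and config used in the suites above, while the DataFrames (streamedDf, smallDf, otherDf) and their columns are purely illustrative:

    // '0' disables the expansion entirely; a positive value caps how many
    // HashPartitionings BroadcastHashJoinExec may report for its output.
    withSQLConf(
        SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "4") {
      // Streamed side: assumed to already be hash-partitioned on "k1" (e.g. the
      // output of a previous shuffle join), so its partitioning can be expanded.
      val joined = streamedDf.join(broadcast(smallDf), streamedDf("k1") === smallDf("k2"))
      // Joining again on the broadcast side's key ("k2") can then match the expanded
      // partitioning rather than reshuffling `joined`.
      joined.join(otherDf, joined("k2") === otherDf("k3")).collect()
    }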