diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8d9ba54f6568d..fa3a55aa5ad94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -52,7 +52,41 @@ trait HashJoin extends BaseJoinExec { } } - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + override def outputPartitioning: Partitioning = buildSide match { + case BuildLeft => + joinType match { + case _: InnerLike | RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"HashJoin should not take $x as the JoinType with building left side") + } + case BuildRight => + joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + left.outputPartitioning + case x => + throw new IllegalArgumentException( + s"HashJoin should not take $x as the JoinType with building right side") + } + } + + override def outputOrdering: Seq[SortOrder] = buildSide match { + case BuildLeft => + joinType match { + case _: InnerLike | RightOuter => right.outputOrdering + case x => + throw new IllegalArgumentException( + s"HashJoin should not take $x as the JoinType with building left side") + } + case BuildRight => + joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + left.outputOrdering + case x => + throw new IllegalArgumentException( + s"HashJoin should not take $x as the JoinType with building right side") + } + } protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b4f626270cfc9..c42d4c6f74a93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -1104,4 +1104,47 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan }) } } + + test("SPARK-32383: Preserve hash join (BHJ and SHJ) stream side ordering") { + val df1 = spark.range(100).select($"id".as("k1")) + val df2 = spark.range(100).select($"id".as("k2")) + val df3 = spark.range(3).select($"id".as("k3")) + val df4 = spark.range(100).select($"id".as("k4")) + + // Test broadcast hash join + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50") { + Seq("inner", "left_outer").foreach(joinType => { + val plan = df1.join(df2, $"k1" === $"k2", joinType) + .join(df3, $"k1" === $"k3", joinType) + .join(df4, $"k1" === $"k4", joinType) + .queryExecution + .executedPlan + assert(plan.collect { case _: SortMergeJoinExec => true }.size === 2) + assert(plan.collect { case _: BroadcastHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join + assert(plan.collect { case _: SortExec => true }.size === 3) + }) + } + + // Test shuffled hash join + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + val df3 = spark.range(10).select($"id".as("k3")) + + Seq("inner", "left_outer").foreach(joinType => { + val plan = df1.join(df2, $"k1" === $"k2", joinType) + .join(df3, $"k1" === $"k3", joinType) + .join(df4, $"k1" === $"k4", joinType) + .queryExecution + .executedPlan + assert(plan.collect { case _: SortMergeJoinExec => true }.size === 2) + assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join + assert(plan.collect { case _: SortExec => true }.size === 3) + }) + } + } }