
Commit e79909e

Fix parallelism in join operator unit tests.
1 parent: 899dce2

3 files changed: +89, -54 lines

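The change fixes a scoping bug in how these suites set spark.sql.shuffle.partitions. Previously, withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") wrapped the code that registers the tests, but ScalaTest's test(...) only records a body to run later, so the override had already been restored by the time the bodies executed and the joins ran with whatever shuffle-partition setting was in effect then (typically the default). Moving withSQLConf inside each test body keeps the single-partition override active while the assertions actually run. The snippet below is a minimal, self-contained sketch of that registration-versus-execution pitfall; ConfigScopeDemo, register, and withConf are illustrative stand-ins for ScalaTest's test and SQLTestUtils.withSQLConf, not Spark APIs.

import scala.collection.mutable

// Illustrative sketch only: register(...) records a body to run later, like ScalaTest's
// test(...); withConf(...) sets a key for the duration of its block, like withSQLConf.
object ConfigScopeDemo {
  private val conf = mutable.Map("spark.sql.shuffle.partitions" -> "200")
  private val registered = mutable.Buffer.empty[(String, () => Unit)]

  private def register(name: String)(body: => Unit): Unit =
    registered += (name -> (() => body))

  private def withConf(key: String, value: String)(f: => Unit): Unit = {
    val previous = conf(key)
    conf(key) = value
    try f finally conf(key) = previous
  }

  def main(args: Array[String]): Unit = {
    // Broken: the override only covers registration, not execution.
    withConf("spark.sql.shuffle.partitions", "1") {
      register("override around registration") {
        println(conf("spark.sql.shuffle.partitions")) // prints 200
      }
    }
    // Fixed: the override wraps the body, so it is active when the body runs.
    register("override inside the body") {
      withConf("spark.sql.shuffle.partitions", "1") {
        println(conf("spark.sql.shuffle.partitions")) // prints 1
      }
    }
    registered.foreach { case (_, body) => body() }
  }
}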

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala

Lines changed: 40 additions & 32 deletions
@@ -34,67 +34,75 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
       rightRows: DataFrame,
       condition: Expression,
       expectedAnswer: Seq[Product]): Unit = {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-      ExtractEquiJoinKeys.unapply(join).foreach {
-        case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-
-          def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
-            val broadcastHashJoin =
-              execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
-            boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
-          }
-
-          def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
-            val shuffledHashJoin =
-              execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
-            val filteredJoin =
-              boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
-            EnsureRequirements(sqlContext).apply(filteredJoin)
-          }
-
-          def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
-            val sortMergeJoin =
-              execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
-            val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-            EnsureRequirements(sqlContext).apply(filteredJoin)
-          }
-
-          test(s"$testName using BroadcastHashJoin (build=left)") {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+
+        def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+          val broadcastHashJoin =
+            execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
+          boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+        }
+
+        def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
+          val shuffledHashJoin =
+            execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
+          val filteredJoin =
+            boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+          EnsureRequirements(sqlContext).apply(filteredJoin)
+        }
+
+        def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
+          val sortMergeJoin =
+            execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
+          val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
+          EnsureRequirements(sqlContext).apply(filteredJoin)
+        }
+
+        test(s"$testName using BroadcastHashJoin (build=left)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               makeBroadcastHashJoin(left, right, joins.BuildLeft),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          test(s"$testName using BroadcastHashJoin (build=right)") {
+        test(s"$testName using BroadcastHashJoin (build=right)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               makeBroadcastHashJoin(left, right, joins.BuildRight),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          test(s"$testName using ShuffledHashJoin (build=left)") {
+        test(s"$testName using ShuffledHashJoin (build=left)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               makeShuffledHashJoin(left, right, joins.BuildLeft),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          test(s"$testName using ShuffledHashJoin (build=right)") {
+        test(s"$testName using ShuffledHashJoin (build=right)") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               makeShuffledHashJoin(left, right, joins.BuildRight),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          test(s"$testName using SortMergeJoin") {
+        test(s"$testName using SortMergeJoin") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               makeSortMergeJoin(left, right),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
-      }
+        }
     }
   }
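For reference, a hypothetical call site for the helper above: the hunk does not show the helper's name or its first parameters, so testInnerJoin, testName, leftDF, and rightDF are assumed names, and the data and expected answer are made up for illustration. Each such call registers five tests, one per physical operator, for a single logical equi-join.

// Hypothetical usage sketch, not taken from the suite. Column.expr converts the
// DataFrame condition into the Catalyst Expression the helper expects.
testInnerJoin(
  "inner join with a single matching key",      // testName
  leftDF,                                       // leftRows: DataFrame with columns a, b
  rightDF,                                      // rightRows: DataFrame with columns c, d
  (leftDF.col("a") === rightDF.col("c")).expr,  // condition: Expression
  Seq((2, 1.0, 2, 3.0))                         // expectedAnswer: Seq[Product]
)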

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala

Lines changed: 37 additions & 14 deletions
@@ -35,43 +35,52 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
       joinType: JoinType,
       condition: Expression,
       expectedAnswer: Seq[Product]): Unit = {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-      ExtractEquiJoinKeys.unapply(join).foreach {
-        case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-          test(s"$testName using ShuffledHashOuterJoin") {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+        test(s"$testName using ShuffledHashOuterJoin") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               EnsureRequirements(sqlContext).apply(
                 ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          if (joinType != FullOuter) {
-            test(s"$testName using BroadcastHashOuterJoin") {
+        if (joinType != FullOuter) {
+          test(s"$testName using BroadcastHashOuterJoin") {
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
               checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
                 BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
                 expectedAnswer.map(Row.fromTuple),
                 sortAnswers = true)
             }
+          }

-            test(s"$testName using SortMergeOuterJoin") {
+          test(s"$testName using SortMergeOuterJoin") {
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
               checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-                SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+                EnsureRequirements(sqlContext).apply(
+                  SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
                 expectedAnswer.map(Row.fromTuple),
                 sortAnswers = false)
             }
           }
-      }
+        }
+    }

-      test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
+    test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
+    }

-      test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
+    test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
           expectedAnswer.map(Row.fromTuple),
@@ -85,14 +94,19 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
     Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
     Row(2, 1.0),
     Row(3, 3.0),
+    Row(5, 1.0),
+    Row(6, 6.0),
     Row(null, null)
   )), new StructType().add("a", IntegerType).add("b", DoubleType))

   val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+    Row(0, 0.0),
     Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
     Row(2, 3.0),
     Row(3, 2.0),
     Row(4, 1.0),
+    Row(5, 3.0),
+    Row(7, 7.0),
     Row(null, null)
   )), new StructType().add("c", IntegerType).add("d", DoubleType))
@@ -117,7 +131,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
-      (3, 3.0, null, null)
+      (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null)
     )
   )
@@ -129,12 +145,15 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
     condition,
     Seq(
       (null, null, null, null),
+      (null, null, 0, 0.0),
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
       (null, null, 3, 2.0),
-      (null, null, 4, 1.0)
+      (null, null, 4, 1.0),
+      (5, 1.0, 5, 3.0),
+      (null, null, 7, 7.0)
     )
   )
@@ -151,8 +170,12 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
       (2, 1.0, 2, 3.0),
       (2, 1.0, 2, 3.0),
       (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null),
+      (null, null, 0, 0.0),
       (null, null, 3, 2.0),
       (null, null, 4, 1.0),
+      (null, null, 7, 7.0),
       (null, null, null, null),
       (null, null, null, null)
     )
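Beyond moving withSQLConf, the suite's fixture data grows: key 5 now matches on both sides, key 6 appears only on the left, and keys 0 and 7 appear only on the right, so every outer-join variant has to emit both matched rows and NULL-padded rows for the new data, as the updated expected answers show. A Spark-free sketch of the full outer semantics for just those keys, with None standing in for the NULL padding:

// Plain Scala illustration (not Spark code) of the outer-join cases the new rows cover.
object OuterJoinKeysDemo {
  def main(args: Array[String]): Unit = {
    val left = Map(5 -> 1.0, 6 -> 6.0)            // new left rows: (5, 1.0), (6, 6.0)
    val right = Map(0 -> 0.0, 5 -> 3.0, 7 -> 7.0) // new right rows: (0, 0.0), (5, 3.0), (7, 7.0)
    val keys = (left.keySet ++ right.keySet).toSeq.sorted
    keys.foreach { k =>
      // Full outer join on the key: absent sides come back as None (NULL in the suite).
      println((k, left.get(k), right.get(k)))
    }
    // (0,None,Some(0.0))      right-only row, left side padded
    // (5,Some(1.0),Some(3.0)) matched row
    // (6,Some(6.0),None)      left-only row, right side padded
    // (7,None,Some(7.0))      right-only row, left side padded
  }
}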

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala

Lines changed: 12 additions & 8 deletions
@@ -34,27 +34,31 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
       rightRows: DataFrame,
       condition: Expression,
       expectedAnswer: Seq[Product]): Unit = {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-      ExtractEquiJoinKeys.unapply(join).foreach {
-        case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-          test(s"$testName using LeftSemiJoinHash") {
+    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+    ExtractEquiJoinKeys.unapply(join).foreach {
+      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
+        test(s"$testName using LeftSemiJoinHash") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               EnsureRequirements(left.sqlContext).apply(
                 LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
+        }

-          test(s"$testName using BroadcastLeftSemiJoinHash") {
+        test(s"$testName using BroadcastLeftSemiJoinHash") {
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
-      }
+        }
+    }

-      test(s"$testName using LeftSemiJoinBNL") {
+    test(s"$testName using LeftSemiJoinBNL") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           LeftSemiJoinBNL(left, right, Some(condition)),
           expectedAnswer.map(Row.fromTuple),
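SemiJoinSuite gets the same treatment: the shuffle-partition override moves from around the test registrations into each body for LeftSemiJoinHash, BroadcastLeftSemiJoinHash, and LeftSemiJoinBNL. As a reminder of what these operators compute, here is a small Spark-free sketch of left semi join semantics: each left row is returned at most once when its key has any match on the right.

// Plain Scala illustration (not Spark code) of left semi join semantics.
object LeftSemiJoinDemo {
  def main(args: Array[String]): Unit = {
    val left = Seq(1 -> "a", 2 -> "b", 3 -> "c")
    val right = Seq(2 -> "x", 2 -> "y", 4 -> "z") // key 2 matches twice on the right
    val rightKeys = right.map(_._1).toSet
    val semi = left.filter { case (k, _) => rightKeys.contains(k) }
    println(semi) // List((2,b)): one output row for key 2 despite two right-side matches
  }
}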
