@@ -34,67 +34,75 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
3434 rightRows : DataFrame ,
3535 condition : Expression ,
3636 expectedAnswer : Seq [Product ]): Unit = {
37- withSQLConf( SQLConf . SHUFFLE_PARTITIONS .key -> " 1 " ) {
38- val join = Join (leftRows.logicalPlan, rightRows.logicalPlan, Inner , Some (condition))
39- ExtractEquiJoinKeys .unapply(join).foreach {
40- case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
41-
42- def makeBroadcastHashJoin ( left : SparkPlan , right : SparkPlan , side : BuildSide ) = {
43- val broadcastHashJoin =
44- execution.joins. BroadcastHashJoin (leftKeys, rightKeys, side, left, right )
45- boundCondition.map( Filter (_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
46- }
47-
48- def makeShuffledHashJoin ( left : SparkPlan , right : SparkPlan , side : BuildSide ) = {
49- val shuffledHashJoin =
50- execution.joins. ShuffledHashJoin (leftKeys, rightKeys, side, left, right)
51- val filteredJoin =
52- boundCondition.map( Filter (_, shuffledHashJoin)).getOrElse(shuffledHashJoin )
53- EnsureRequirements (sqlContext).apply(filteredJoin)
54- }
55-
56- def makeSortMergeJoin ( left : SparkPlan , right : SparkPlan ) = {
57- val sortMergeJoin =
58- execution.joins. SortMergeJoin (leftKeys, rightKeys, left, right )
59- val filteredJoin = boundCondition.map( Filter (_, sortMergeJoin)).getOrElse(sortMergeJoin )
60- EnsureRequirements (sqlContext).apply(filteredJoin)
61- }
62-
63- test( s " $testName using BroadcastHashJoin (build=left) " ) {
37+ val join = Join (leftRows.logicalPlan, rightRows.logicalPlan, Inner , Some (condition))
38+ ExtractEquiJoinKeys .unapply(join).foreach {
39+ case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
40+
41+ def makeBroadcastHashJoin ( left : SparkPlan , right : SparkPlan , side : BuildSide ) = {
42+ val broadcastHashJoin =
43+ execution.joins. BroadcastHashJoin (leftKeys, rightKeys, side, left, right)
44+ boundCondition.map( Filter (_, broadcastHashJoin)).getOrElse(broadcastHashJoin )
45+ }
46+
47+ def makeShuffledHashJoin ( left : SparkPlan , right : SparkPlan , side : BuildSide ) = {
48+ val shuffledHashJoin =
49+ execution.joins. ShuffledHashJoin (leftKeys, rightKeys, side, left, right)
50+ val filteredJoin =
51+ boundCondition.map( Filter (_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
52+ EnsureRequirements (sqlContext).apply(filteredJoin )
53+ }
54+
55+ def makeSortMergeJoin ( left : SparkPlan , right : SparkPlan ) = {
56+ val sortMergeJoin =
57+ execution.joins. SortMergeJoin (leftKeys, rightKeys, left, right)
58+ val filteredJoin = boundCondition.map( Filter (_, sortMergeJoin)).getOrElse(sortMergeJoin )
59+ EnsureRequirements (sqlContext).apply(filteredJoin )
60+ }
61+
62+ test( s " $testName using BroadcastHashJoin (build=left) " ) {
63+ withSQLConf( SQLConf . SHUFFLE_PARTITIONS .key -> " 1 " ) {
6464 checkAnswer2(leftRows, rightRows, (left : SparkPlan , right : SparkPlan ) =>
6565 makeBroadcastHashJoin(left, right, joins.BuildLeft ),
6666 expectedAnswer.map(Row .fromTuple),
6767 sortAnswers = true )
6868 }
69+ }
6970
70- test(s " $testName using BroadcastHashJoin (build=right) " ) {
71+ test(s " $testName using BroadcastHashJoin (build=right) " ) {
72+ withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 1" ) {
7173 checkAnswer2(leftRows, rightRows, (left : SparkPlan , right : SparkPlan ) =>
7274 makeBroadcastHashJoin(left, right, joins.BuildRight ),
7375 expectedAnswer.map(Row .fromTuple),
7476 sortAnswers = true )
7577 }
78+ }
7679
77- test(s " $testName using ShuffledHashJoin (build=left) " ) {
80+ test(s " $testName using ShuffledHashJoin (build=left) " ) {
81+ withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 1" ) {
7882 checkAnswer2(leftRows, rightRows, (left : SparkPlan , right : SparkPlan ) =>
7983 makeShuffledHashJoin(left, right, joins.BuildLeft ),
8084 expectedAnswer.map(Row .fromTuple),
8185 sortAnswers = true )
8286 }
87+ }
8388
84- test(s " $testName using ShuffledHashJoin (build=right) " ) {
89+ test(s " $testName using ShuffledHashJoin (build=right) " ) {
90+ withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 1" ) {
8591 checkAnswer2(leftRows, rightRows, (left : SparkPlan , right : SparkPlan ) =>
8692 makeShuffledHashJoin(left, right, joins.BuildRight ),
8793 expectedAnswer.map(Row .fromTuple),
8894 sortAnswers = true )
8995 }
96+ }
9097
91- test(s " $testName using SortMergeJoin " ) {
98+ test(s " $testName using SortMergeJoin " ) {
99+ withSQLConf(SQLConf .SHUFFLE_PARTITIONS .key -> " 1" ) {
92100 checkAnswer2(leftRows, rightRows, (left : SparkPlan , right : SparkPlan ) =>
93101 makeSortMergeJoin(left, right),
94102 expectedAnswer.map(Row .fromTuple),
95103 sortAnswers = true )
96104 }
97- }
105+ }
98106 }
99107 }
100108
0 commit comments