diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 952ba7d45c13e..3324393329d0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -96,13 +96,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

-      // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
-      // for now let's support inner join first, then add outer join
-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+      case ExtractEquiJoinKeys(
+          LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+        joins.BroadcastHashOuterJoin(
+          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+
+      case ExtractEquiJoinKeys(
+          RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+        joins.BroadcastHashOuterJoin(
+          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+
+      // If the sort merge join option is set, we prefer sort merge join over hash join.
+      // For outer joins, the join condition cannot be evaluated in a Filter outside the join.
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil

       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
@@ -115,16 +124,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

-      case ExtractEquiJoinKeys(
-          LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
-
-      case ExtractEquiJoinKeys(
-          RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
-
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
         joins.ShuffledHashOuterJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index eee8ad800f98e..1dcfd04bb8665 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -55,7 +55,6 @@ case class ShuffledHashOuterJoin(
   protected override def doExecute(): RDD[InternalRow] = {
     val joinedRow = new JoinedRow()
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      // TODO this probably can be replaced by external sort (sort merged join?)
       joinType match {
         case LeftOuter =>
           val hashed = HashedRelation(rightIter, buildKeyGenerator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 4ae23c186cf7b..ec001108d2717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -23,9 +23,10 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.collection.{BitSet, CompactBuffer}

 /**
  * :: DeveloperApi ::
@@ -35,26 +36,63 @@ import org.apache.spark.util.collection.CompactBuffer
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    joinType: JoinType,
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
+    right: SparkPlan,
+    condition: Option[Expression] = None) extends BinaryNode {

   override protected[sql] val trackNumOfRowsEnabled = true

-  override def output: Seq[Attribute] = left.output ++ right.output
+  val (streamedPlan, bufferedPlan, streamedKeys, bufferedKeys) = joinType match {
+    case RightOuter => (right, left, rightKeys, leftKeys)
+    case _ => (left, right, leftKeys, rightKeys)
+  }
+
+  override def output: Seq[Attribute] = joinType match {
+    case Inner =>
+      left.output ++ right.output
+    case LeftOuter =>
+      left.output ++ right.output.map(_.withNullability(true))
+    case RightOuter =>
+      left.output.map(_.withNullability(true)) ++ right.output
+    case FullOuter =>
+      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+    case x =>
+      throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType")
+  }

-  override def outputPartitioning: Partitioning =
-    PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+  override def outputPartitioning: Partitioning = joinType match {
+    case FullOuter =>
+      // A full outer join generates null rows for both sides, so the output is no longer
+      // partitioned by the join keys.
+      UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions)
+    case Inner =>
+      PartitioningCollection(Seq(streamedPlan.outputPartitioning, bufferedPlan.outputPartitioning))
+    case LeftOuter | RightOuter =>
+      streamedPlan.outputPartitioning
+    case x =>
+      throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType")
+  }

   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

-  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+  override def outputOrdering: Seq[SortOrder] = joinType match {
+    case FullOuter => Nil // Null rows generated for both sides break the output ordering.
+    case _ => requiredOrders(streamedKeys)
+  }

   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil

-  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
-  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+  @transient protected lazy val streamedKeyGenerator =
+    newProjection(streamedKeys, streamedPlan.output)
+  @transient protected lazy val bufferedKeyGenerator =
+    newProjection(bufferedKeys, bufferedPlan.output)
+
+  // Checks whether a joined row satisfies the join condition.
+  @transient private[this] lazy val boundCondition =
+    condition.map(newPredicate(_, streamedPlan.output ++ bufferedPlan.output)).getOrElse(
+      (row: InternalRow) => true)

   private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
     // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
@@ -62,111 +100,250 @@ case class SortMergeJoin(
   }

   protected override def doExecute(): RDD[InternalRow] = {
-    val leftResults = left.execute().map(_.copy())
-    val rightResults = right.execute().map(_.copy())
+    val streamResults = streamedPlan.execute().map(_.copy())
+    val bufferResults = bufferedPlan.execute().map(_.copy())

-    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+    streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => {
+      // Standard null rows used to pad the side that has no match.
+      val streamedNullRow = InternalRow.fromSeq(Seq.fill(streamedPlan.output.length)(null))
+      val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null))
       new Iterator[InternalRow] {
         // An ordering that can be used to compare keys from both sides.
         private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
         // Mutable per row objects.
         private[this] val joinRow = new JoinedRow
-        private[this] var leftElement: InternalRow = _
-        private[this] var rightElement: InternalRow = _
-        private[this] var leftKey: InternalRow = _
-        private[this] var rightKey: InternalRow = _
-        private[this] var rightMatches: CompactBuffer[InternalRow] = _
-        private[this] var rightPosition: Int = -1
+        private[this] var streamedElement: InternalRow = _
+        private[this] var bufferedElement: InternalRow = _
+        private[this] var streamedKey: InternalRow = _
+        private[this] var bufferedKey: InternalRow = _
+        private[this] var bufferedMatches: CompactBuffer[InternalRow] = _
+        private[this] var bufferedPosition: Int = -1
         private[this] var stop: Boolean = false
         private[this] var matchKey: InternalRow = _
+        // When the merge finds a join key with no match, exactly one side lacks a
+        // corresponding row, and we mark which side it is: true means the streamed side has
+        // no match, false means the buffered side. Only set when needed.
+        private[this] var continueStreamed: Boolean = _
+        private[this] var streamNullGenerated: Boolean = false
+        // Tracks whether each element in bufferedMatches has a matching streamedElement.
+        private[this] var bitSet: BitSet = _
+        // Marks whether the row that was found has already been fetched.
+        private[this] var found: Boolean = false
+        private[this] var bufferNullGenerated: Boolean = false

         // initialize iterator
         initialize()

-        override final def hasNext: Boolean = nextMatchingPair()
+        override final def hasNext: Boolean = {
+          val matching = nextMatchingBlock()
+          if (matching && !isBufferEmpty(bufferedMatches)) {
+            // The buffer stores all rows whose keys match, but the condition may not hold.
+            // If no row in the buffer satisfies the condition, fetch the next matching block.
+            findNextInBuffer() || hasNext
+          } else {
+            matching
+          }
+        }
+
+        /**
+         * Runs down the current `bufferedMatches` to find rows that satisfy the condition.
+         * If `joinType` is not `Inner`, `bufferNullGenerated` marks whether we still need
+         * to build a bufferedNullRow for the result.
+         * If `joinType` is `FullOuter`, `streamNullGenerated` marks whether a buffered
+         * element needs to be joined with a streamedNullRow.
+         * The method can be called multiple times, since `found` serves as a guard.
+         */
+        def findNextInBuffer(): Boolean = {
+          while (!found && streamedElement != null
+            && keyOrdering.compare(streamedKey, matchKey) == 0) {
+            while (bufferedPosition < bufferedMatches.size && !boundCondition(
+              joinRow(streamedElement, bufferedMatches(bufferedPosition)))) {
+              bufferedPosition += 1
+            }
+            if (bufferedPosition == bufferedMatches.size) {
+              if (joinType == Inner || bufferNullGenerated) {
+                bufferNullGenerated = false
+                bufferedPosition = 0
+                fetchStreamed()
+              } else {
+                found = true
+              }
+            } else {
+              // Mark as true so we don't generate a null row for the streamed row.
+              bufferNullGenerated = true
+              bitSet.set(bufferedPosition)
+              found = true
+            }
+          }
+          if (!found) {
+            if (joinType == FullOuter && !streamNullGenerated) {
+              streamNullGenerated = true
+            }
+            if (streamNullGenerated) {
+              while (bufferedPosition < bufferedMatches.size && bitSet.get(bufferedPosition)) {
+                bufferedPosition += 1
+              }
+              if (bufferedPosition < bufferedMatches.size) {
+                found = true
+              }
+            }
+          }
+          if (!found) {
+            stop = false
+            bufferedMatches = null
+          }
+          found
+        }

         override final def next(): InternalRow = {
           if (hasNext) {
-            // we are using the buffered right rows and run down left iterator
-            val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
-            rightPosition += 1
-            if (rightPosition >= rightMatches.size) {
-              rightPosition = 0
-              fetchLeft()
-              if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
-                stop = false
-                rightMatches = null
+            if (isBufferEmpty(bufferedMatches)) {
+              // We just found a row with no join match, so produce a result row pairing
+              // it with a standard null row from the other side.
+              if (continueStreamed) {
+                val joinedRow = smartJoinRow(streamedElement, bufferedNullRow)
+                fetchStreamed()
+                joinedRow
+              } else {
+                val joinedRow = smartJoinRow(streamedNullRow, bufferedElement)
+                fetchBuffered()
+                joinedRow
               }
+            } else {
+              // Use the buffered rows while running down the streamed iterator.
+              val joinedRow = if (streamNullGenerated) {
+                val ret = smartJoinRow(streamedNullRow, bufferedMatches(bufferedPosition))
+                bufferedPosition += 1
+                ret
+              } else {
+                if (bufferedPosition == bufferedMatches.size && !bufferNullGenerated) {
+                  val ret = smartJoinRow(streamedElement, bufferedNullRow)
+                  bufferNullGenerated = true
+                  ret
+                } else {
+                  val ret = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition))
+                  bufferedPosition += 1
+                  ret
+                }
+              }
+              found = false
+              joinedRow
             }
-            joinedRow
           } else {
             // no more result
             throw new NoSuchElementException
           }
         }

-        private def fetchLeft() = {
-          if (leftIter.hasNext) {
-            leftElement = leftIter.next()
-            leftKey = leftKeyGenerator(leftElement)
+        private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow =
+          joinType match {
+            case RightOuter => joinRow(bufferedRow, streamedRow)
+            case _ => joinRow(streamedRow, bufferedRow)
+          }
+
+        private def fetchStreamed(): Unit = {
+          if (streamedIter.hasNext) {
+            streamedElement = streamedIter.next()
+            streamedKey = streamedKeyGenerator(streamedElement)
           } else {
-            leftElement = null
+            streamedElement = null
           }
         }

-        private def fetchRight() = {
-          if (rightIter.hasNext) {
-            rightElement = rightIter.next()
-            rightKey = rightKeyGenerator(rightElement)
+        private def fetchBuffered(): Unit = {
+          if (bufferedIter.hasNext) {
+            bufferedElement = bufferedIter.next()
+            bufferedKey = bufferedKeyGenerator(bufferedElement)
           } else {
-            rightElement = null
+            bufferedElement = null
           }
         }

         private def initialize() = {
-          fetchLeft()
-          fetchRight()
+          fetchStreamed()
+          fetchBuffered()
         }

         /**
-         * Searches the right iterator for the next rows that have matches in left side, and store
-         * them in a buffer.
+         * Searches the buffered iterator for the next block of rows whose keys match the
+         * streamed side (key match only), and stores them in a buffer.
+         * The search resumes from the same position until `next()` is called. Until then,
+         * this method can be called multiple times with the same return value and state as
+         * calling it once, since internal guards protect it.
          *
         * @return true if the search is successful, and false if the right iterator runs out of
         *         tuples.
         */
-        private def nextMatchingPair(): Boolean = {
-          if (!stop && rightElement != null) {
-            // run both side to get the first match pair
-            while (!stop && leftElement != null && rightElement != null) {
-              val comparing = keyOrdering.compare(leftKey, rightKey)
+        private def nextMatchingBlock(): Boolean = {
+          if (!stop && streamedElement != null) {
+            // Step 1: run down both sides to find the first matching pair.
+            while (!stop && streamedElement != null && bufferedElement != null) {
+              val comparing = keyOrdering.compare(streamedKey, bufferedKey)
               // for inner join, we need to filter those null keys
-              stop = comparing == 0 && !leftKey.anyNull
-              if (comparing > 0 || rightKey.anyNull) {
-                fetchRight()
-              } else if (comparing < 0 || leftKey.anyNull) {
-                fetchLeft()
+              stop = comparing == 0 && !streamedKey.anyNull
+              if (comparing > 0 || bufferedKey.anyNull) {
+                if (joinType == FullOuter) {
+                  // The join type is full outer and the buffered side has a row with no
+                  // join match, so emit a result row joining a streamed null row with this
+                  // buffered row. Then fetch the next buffered element and continue.
+                  continueStreamed = false
+                  return true
+                } else {
+                  fetchBuffered()
+                }
+              } else if (comparing < 0 || streamedKey.anyNull) {
+                if (joinType == Inner) {
+                  fetchStreamed()
+                } else {
+                  // The join type is not inner and the streamed side has a row with no
+                  // join match, so emit a result row joining this streamed row with a
+                  // buffered null row. Then fetch the next streamed element and continue.
+                  continueStreamed = true
+                  return true
+                }
               }
             }
-            rightMatches = new CompactBuffer[InternalRow]()
+            // Step 2: run down the buffered side to collect all matching rows in a buffer.
+            bufferedMatches = new CompactBuffer[InternalRow]()
             if (stop) {
               stop = false
               // iterate the right side to buffer all rows that matches
               // as the records should be ordered, exit when we meet the first that not match
-              while (!stop && rightElement != null) {
-                rightMatches += rightElement
-                fetchRight()
-                stop = keyOrdering.compare(leftKey, rightKey) != 0
+              while (!stop) {
+                bufferedMatches += bufferedElement
+                fetchBuffered()
+                stop =
+                  keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null
               }
-              if (rightMatches.size > 0) {
-                rightPosition = 0
-                matchKey = leftKey
+              bufferedPosition = 0
+              streamNullGenerated = false
+              bitSet = new BitSet(bufferedMatches.size)
+              matchKey = streamedKey
+            }
+          }
+          // `stop` is false here only if one side finished iterating in step 1;
+          // if we went through step 2, `stop` cannot be false at this point.
+          if (!stop && isBufferEmpty(bufferedMatches)) {
+            if (streamedElement == null && bufferedElement != null) {
+              // The streamed side is exhausted but buffered rows remain.
+              if (joinType == FullOuter) {
+                continueStreamed = false
+                return true
+              }
+            } else if (streamedElement != null && bufferedElement == null) {
+              // The buffered side is exhausted but streamed rows remain.
+              if (joinType != Inner) {
+                continueStreamed = true
+                return true
+              }
             }
           }
-          rightMatches != null && rightMatches.size > 0
+          !isBufferEmpty(bufferedMatches)
         }
+
+        private def isBufferEmpty(buffer: CompactBuffer[InternalRow]): Boolean =
+          buffer == null || buffer.isEmpty
       }
-    }
+    })
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5bef1d8966031..16bf88c961b3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.types.BinaryType
+import org.apache.spark.sql.test.SQLTestUtils

-class JoinSuite extends QueryTest with BeforeAndAfterEach {
+class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {

   // Ensures tables are loaded.
   TestData

+  override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
   lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
   import ctx.logicalPlanToSparkQuery
@@ -66,7 +67,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
   test("join operator selection") {
     ctx.cacheManager.clearCache()

-    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
     Seq(
       ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
       ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
@@ -83,13 +83,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[ShuffledHashOuterJoin]),
+        classOf[SortMergeJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[ShuffledHashOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a",
-        classOf[ShuffledHashOuterJoin]),
+        classOf[SortMergeJoin]),
+      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]),
       ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin]),
       ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
@@ -97,27 +96,29 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
-    try {
-      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
       Seq(
- ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } @@ -125,15 +126,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -141,8 +140,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -152,16 +149,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> 
"false") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", @@ -169,8 +164,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -212,9 +205,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( x.join(y).where($"x.a" === $"y.a"), Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: - Row(1, 2, 1, 1) :: - Row(1, 2, 1, 2) :: Nil + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } @@ -457,25 +450,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -488,6 +480,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } }