From d95417ebb9f4a70e945615af93afb478cc1ac135 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 18 Jun 2015 21:47:30 -0700 Subject: [PATCH 01/10] use sort merge join for outer join --- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../sql/execution/joins/SortMergeJoin.scala | 246 ++++++++++++++---- .../org/apache/spark/sql/JoinSuite.scala | 8 +- 3 files changed, 201 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 952ba7d45c13e..d14501bfc68e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -96,13 +96,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin. + // And for outer join, we can not put conditions outside of the join + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - val mergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + joins.SortMergeJoin( + leftKeys, rightKeys, joinType, planLater(left), planLater(right), condition) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4ae23c186cf7b..889de189e4904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -35,26 +36,65 @@ import org.apache.spark.util.collection.CompactBuffer case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan, + condition: Option[Expression] = None) extends BinaryNode { override protected[sql] val trackNumOfRowsEnabled = true - override def output: Seq[Attribute] = left.output ++ right.output + val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match { + case RightOuter => (right, left, rightKeys, leftKeys) + case _ => (left, right, leftKeys, rightKeys) + } + + override def output: Seq[Attribute] = joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") + } - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = joinType match { + case FullOuter => + // when doing Full Outer join, NULL rows from both sides are not so partitioned. + UnknownPartitioning(streamed.outputPartitioning.numPartitions) + case Inner => + PartitioningCollection(Seq(streamed.outputPartitioning, buffered.outputPartitioning)) + case LeftOuter | rightOuter => + streamed.outputPartitioning + case x => + throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") + } override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + case FullOuter => Nil // when doing Full Outer join, NULL rows from both sides are not ordered. + case _ => requiredOrders(streamedKeys) + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + @transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output) + @transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output) + + // standard null rows + @transient private[this] lazy val streamedNullRow = new GenericRow(streamed.output.length) + @transient private[this] lazy val bufferedNullRow = new GenericRow(buffered.output.length) + + // checks if the joinedRow can meet condition requirements + @transient private[this] lazy val boundCondition = + condition.map( + newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row: InternalRow) => true) private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. @@ -62,23 +102,34 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val streamResults = streamed.execute().map(_.copy()) + val bufferResults = buffered.execute().map(_.copy()) - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) => new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. - private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 + private[this] val joinRow = new JoinedRow5 + private[this] var streamedElement: InternalRow = _ + private[this] var bufferedElement: InternalRow = _ + private[this] var streamedKey: InternalRow = _ + private[this] var bufferedKey: InternalRow = _ + private[this] var bufferedMatches: CompactBuffer[InternalRow] = _ + private[this] var bufferedPosition: Int = -1 private[this] var stop: Boolean = false private[this] var matchKey: InternalRow = _ + // when we do merge algorithm and find some not matched join key, there must be a side + // that do not have a corresponding match. So we need to mark which side it is. True means + // streamed side not have match, and False means the buffered side. Only set when needed. + private[this] var continueStreamed: Boolean = _ + // when we do full outer join and find all matched keys, we put a null stream row into + // this to tell next() that we need to combine null stream row with all rows that not match + // conditions. + private[this] var secondStreamedElement: InternalRow = _ + // Stores rows that match the join key but not match conditions. + // These rows will be useful when we are doing Full Outer Join. + private[this] var secondBufferedMatches: CompactBuffer[InternalRow] = _ // initialize iterator initialize() @@ -87,84 +138,165 @@ case class SortMergeJoin( override final def next(): InternalRow = { if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null + if (bufferedMatches == null || bufferedMatches.size == 0) { + // we just found a row with no join match and we are here to produce a row + // with this row with a standard null row from the other side. + if (continueStreamed) { + val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy()) + fetchStreamed() + joinedRow + } else { + val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement) + fetchBuffered() + joinedRow + } + } else { + // we are using the buffered right rows and run down left iterator + val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + if (bufferedPosition >= bufferedMatches.size) { + bufferedPosition = 0 + if (joinType != FullOuter || secondStreamedElement == null) { + fetchStreamed() + if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) { + stop = false + bufferedMatches = null + } + } else { + // in FullOuter join and the first time we finish the match buffer, + // we still want to generate all rows with streamed null row and buffered + // rows that match the join key but not the conditions. + streamedElement = secondStreamedElement + bufferedMatches = secondBufferedMatches + secondStreamedElement = null + secondBufferedMatches = null + } } + joinedRow } - joinedRow } else { // no more result throw new NoSuchElementException } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) + private def smartJoinRow(streamedRow: InternalRow, bufferedRow: InternalRow): InternalRow = + joinType match { + case RightOuter => joinRow(bufferedRow, streamedRow) + case _ => joinRow(streamedRow, bufferedRow) + } + + private def fetchStreamed() = { + if (streamedIter.hasNext) { + streamedElement = streamedIter.next() + streamedKey = streamedKeyGenerator(streamedElement) } else { - leftElement = null + streamedElement = null } } - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + private def fetchBuffered() = { + if (bufferedIter.hasNext) { + bufferedElement = bufferedIter.next() + bufferedKey = bufferedKeyGenerator(bufferedElement) } else { - rightElement = null + bufferedElement = null } } private def initialize() = { - fetchLeft() - fetchRight() + fetchStreamed() + fetchBuffered() } /** * Searches the right iterator for the next rows that have matches in left side, and store * them in a buffer. + * When this is not a Inner join, we will also return true when we get a row with no match + * on the other side. This search will jump out every time from the same position until + * `next()` is called. * * @return true if the search is successful, and false if the right iterator runs out of * tuples. */ private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) + if (!stop && streamedElement != null) { + // step 1: run both side to get the first match pair + while (!stop && streamedElement != null && bufferedElement != null) { + val comparing = keyOrdering.compare(streamedKey, bufferedKey) // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() + stop = comparing == 0 && !streamedKey.anyNull + if (comparing > 0 || bufferedKey.anyNull) { + if (joinType == FullOuter) { + // the join type is full outer and the buffered side has a row with no + // join match, so we have a result row with streamed null with buffered + // side as this row. Then we fetch next buffered element and go back. + continueStreamed = false + return true + } else { + fetchBuffered() + } + } else if (comparing < 0 || streamedKey.anyNull) { + if (joinType == Inner) { + fetchStreamed() + } else { + // the join type is not inner and the streamed side has a row with no + // join match, so we have a result row with this streamed row with buffered + // null row. Then we fetch next streamed element and go back. + continueStreamed = true + return true + } } } - rightMatches = new CompactBuffer[InternalRow]() + // step 2: run down the buffered side to put all matched rows in a buffer + bufferedMatches = new CompactBuffer[InternalRow]() + secondBufferedMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 + while (!stop) { + if (boundCondition(joinRow(streamedElement, bufferedElement))) { + bufferedMatches += bufferedElement + } else if (joinType == FullOuter) { + bufferedMatches += bufferedNullRow.copy() + secondBufferedMatches += bufferedElement + } + fetchBuffered() + stop = + keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null + } + if (bufferedMatches.size == 0 && joinType != Inner) { + bufferedMatches += bufferedNullRow.copy() + } + if (bufferedMatches.size > 0) { + bufferedPosition = 0 + matchKey = streamedKey + // secondBufferedMatches.size cannot be larger than bufferedMatches + if (secondBufferedMatches.size > 0) { + secondStreamedElement = streamedNullRow.copy() + } + } + } + } + // `stop` is false iff left or right has finished iteration in step 1. + // if we get into step 2, `stop` cannot be false. + if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) { + if (streamedElement == null && bufferedElement != null) { + // streamedElement == null but bufferedElement != null + if (joinType == FullOuter) { + continueStreamed = false + return true } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey + } else if (streamedElement != null && bufferedElement == null) { + // bufferedElement == null but streamedElement != null + if (joinType != Inner) { + continueStreamed = true + return true } } } - rightMatches != null && rightMatches.size > 0 + bufferedMatches != null && bufferedMatches.size > 0 } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d8966031..471c29258f161 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -102,7 +102,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[SortMergeJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[SortMergeJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) From 71ff4e910b574e2e9ef0b839558abc32569eb193 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 18 Jun 2015 23:22:46 -0700 Subject: [PATCH 02/10] rebase --- .../sql/execution/joins/SortMergeJoin.scala | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 889de189e4904..884cdb21b34c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -43,7 +43,7 @@ case class SortMergeJoin( override protected[sql] val trackNumOfRowsEnabled = true - val (streamed, buffered, streamedKeys, bufferedKeys) = joinType match { + val (streamedPlan, bufferedPlan, streamedKeys, bufferedKeys) = joinType match { case RightOuter => (right, left, rightKeys, leftKeys) case _ => (left, right, leftKeys, rightKeys) } @@ -64,11 +64,11 @@ case class SortMergeJoin( override def outputPartitioning: Partitioning = joinType match { case FullOuter => // when doing Full Outer join, NULL rows from both sides are not so partitioned. - UnknownPartitioning(streamed.outputPartitioning.numPartitions) + UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions) case Inner => - PartitioningCollection(Seq(streamed.outputPartitioning, buffered.outputPartitioning)) + PartitioningCollection(Seq(streamedPlan.outputPartitioning, bufferedPlan.outputPartitioning)) case LeftOuter | rightOuter => - streamed.outputPartitioning + streamedPlan.outputPartitioning case x => throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") } @@ -84,17 +84,15 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val streamedKeyGenerator = newProjection(streamedKeys, streamed.output) - @transient protected lazy val bufferedKeyGenerator = newProjection(bufferedKeys, buffered.output) - - // standard null rows - @transient private[this] lazy val streamedNullRow = new GenericRow(streamed.output.length) - @transient private[this] lazy val bufferedNullRow = new GenericRow(buffered.output.length) + @transient protected lazy val streamedKeyGenerator = + newProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val bufferedKeyGenerator = + newProjection(bufferedKeys, bufferedPlan.output) // checks if the joinedRow can meet condition requirements @transient private[this] lazy val boundCondition = - condition.map( - newPredicate(_, streamed.output ++ buffered.output)).getOrElse((row: InternalRow) => true) + condition.map(newPredicate(_, streamedPlan.output ++ bufferedPlan.output)).getOrElse( + (row: InternalRow) => true) private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. @@ -102,10 +100,13 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { - val streamResults = streamed.execute().map(_.copy()) - val bufferResults = buffered.execute().map(_.copy()) + val streamResults = streamedPlan.execute().map(_.copy()) + val bufferResults = bufferedPlan.execute().map(_.copy()) - streamResults.zipPartitions(bufferResults) { (streamedIter, bufferedIter) => + streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { + // standard null rows + val streamedNullRow = new GenericRow(streamedPlan.output.length) + val bufferedNullRow = new GenericRow(bufferedPlan.output.length) new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -140,13 +141,13 @@ case class SortMergeJoin( if (hasNext) { if (bufferedMatches == null || bufferedMatches.size == 0) { // we just found a row with no join match and we are here to produce a row - // with this row with a standard null row from the other side. + // with this row and a standard null row from the other side. if (continueStreamed) { - val joinedRow = smartJoinRow(streamedElement, bufferedNullRow.copy()) + val joinedRow = smartJoinRow(streamedElement, bufferedNullRow) fetchStreamed() joinedRow } else { - val joinedRow = smartJoinRow(streamedNullRow.copy(), bufferedElement) + val joinedRow = smartJoinRow(streamedNullRow, bufferedElement) fetchBuffered() joinedRow } @@ -186,7 +187,7 @@ case class SortMergeJoin( case _ => joinRow(streamedRow, bufferedRow) } - private def fetchStreamed() = { + private def fetchStreamed(): Unit = { if (streamedIter.hasNext) { streamedElement = streamedIter.next() streamedKey = streamedKeyGenerator(streamedElement) @@ -195,7 +196,7 @@ case class SortMergeJoin( } } - private def fetchBuffered() = { + private def fetchBuffered(): Unit = { if (bufferedIter.hasNext) { bufferedElement = bufferedIter.next() bufferedKey = bufferedKeyGenerator(bufferedElement) @@ -215,6 +216,8 @@ case class SortMergeJoin( * When this is not a Inner join, we will also return true when we get a row with no match * on the other side. This search will jump out every time from the same position until * `next()` is called. + * Unless we call `next()`, this function can be called multiple times, with the same + * return value and result as running it once, since we have set guardians in it. * * @return true if the search is successful, and false if the right iterator runs out of * tuples. @@ -259,7 +262,7 @@ case class SortMergeJoin( if (boundCondition(joinRow(streamedElement, bufferedElement))) { bufferedMatches += bufferedElement } else if (joinType == FullOuter) { - bufferedMatches += bufferedNullRow.copy() + bufferedMatches += bufferedNullRow secondBufferedMatches += bufferedElement } fetchBuffered() @@ -267,14 +270,14 @@ case class SortMergeJoin( keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null } if (bufferedMatches.size == 0 && joinType != Inner) { - bufferedMatches += bufferedNullRow.copy() + bufferedMatches += bufferedNullRow } if (bufferedMatches.size > 0) { bufferedPosition = 0 matchKey = streamedKey // secondBufferedMatches.size cannot be larger than bufferedMatches if (secondBufferedMatches.size > 0) { - secondStreamedElement = streamedNullRow.copy() + secondStreamedElement = streamedNullRow } } } @@ -299,6 +302,6 @@ case class SortMergeJoin( bufferedMatches != null && bufferedMatches.size > 0 } } - } + }) } } From a8d1ff7262f744a5c76fd78c4eea4f3e5b3ba01d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 00:12:49 -0700 Subject: [PATCH 03/10] bring it up to date --- .../apache/spark/sql/execution/joins/SortMergeJoin.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 884cdb21b34c4..3dc3592aa617f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -105,13 +105,13 @@ case class SortMergeJoin( streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { // standard null rows - val streamedNullRow = new GenericRow(streamedPlan.output.length) - val bufferedNullRow = new GenericRow(bufferedPlan.output.length) + val streamedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) + val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. - private[this] val joinRow = new JoinedRow5 + private[this] val joinRow = new JoinedRow private[this] var streamedElement: InternalRow = _ private[this] var bufferedElement: InternalRow = _ private[this] var streamedKey: InternalRow = _ From fdea91d0f6691757bfb78e98a77215b00b5480ae Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 01:04:04 -0700 Subject: [PATCH 04/10] fix default setting change --- .../org/apache/spark/sql/JoinSuite.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 471c29258f161..010ed322db6f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -83,13 +83,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -98,17 +95,20 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]) + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) @@ -160,14 +160,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) + classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", From a4cf5cdb89674ce3fff805beeb60957284cbe893 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 22:18:04 -0700 Subject: [PATCH 05/10] fix style --- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 010ed322db6f3..2f0baa1213ab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -84,8 +84,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[SortMergeJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[SortMergeJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), From 53c2bdb459b0d27ae76fe0351707ad40db3b9d9a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 3 Aug 2015 23:21:15 -0700 Subject: [PATCH 06/10] fix comments from @jeanlyn --- .../joins/ShuffledHashOuterJoin.scala | 1 - .../sql/execution/joins/SortMergeJoin.scala | 150 +++++++++++------- 2 files changed, 96 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index eee8ad800f98e..1dcfd04bb8665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -55,7 +55,6 @@ case class ShuffledHashOuterJoin( protected override def doExecute(): RDD[InternalRow] = { val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => val hashed = HashedRelation(rightIter, buildKeyGenerator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 3dc3592aa617f..44e1839ce2c8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{BitSet, CompactBuffer} /** * :: DeveloperApi :: @@ -105,7 +105,7 @@ case class SortMergeJoin( streamResults.zipPartitions(bufferResults) ( (streamedIter, bufferedIter) => { // standard null rows - val streamedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) + val streamedNullRow = InternalRow.fromSeq(Seq.fill(streamedPlan.output.length)(null)) val bufferedNullRow = InternalRow.fromSeq(Seq.fill(bufferedPlan.output.length)(null)) new Iterator[InternalRow] { // An ordering that can be used to compare keys from both sides. @@ -124,22 +124,80 @@ case class SortMergeJoin( // that do not have a corresponding match. So we need to mark which side it is. True means // streamed side not have match, and False means the buffered side. Only set when needed. private[this] var continueStreamed: Boolean = _ - // when we do full outer join and find all matched keys, we put a null stream row into - // this to tell next() that we need to combine null stream row with all rows that not match - // conditions. - private[this] var secondStreamedElement: InternalRow = _ - // Stores rows that match the join key but not match conditions. - // These rows will be useful when we are doing Full Outer Join. - private[this] var secondBufferedMatches: CompactBuffer[InternalRow] = _ + private[this] var streamNullGenerated: Boolean = false + // Tracks if each element in bufferedMatches have a matched streamedElement. + private[this] var bitSet: BitSet = _ + // marks if the found result has been fetched. + private[this] var found: Boolean = false + private[this] var bufferNullGenerated: Boolean = false // initialize iterator initialize() - override final def hasNext: Boolean = nextMatchingPair() + override final def hasNext: Boolean = { + val matching = nextMatchingBlock() + if (matching && !isBufferEmpty(bufferedMatches)) { + // The buffer stores all rows that match key, but condition may not be matched. + // If none of rows in the buffer match condition, we'll fetch next matching block. + findNextInBuffer() || hasNext + } else { + matching + } + } + + /** + * Run down the current `bufferedMatches` to find rows that match conditions. + * If `joinType` is not `Inner`, we will use `bufferNullGenerated` to mark if + * we need to build a bufferedNullRow for result. + * If `joinType` is `FullOuter`, we will use `streamNullGenerated` to mark if + * a buffered element need to join with a streamedNullRow. + * The method can be called multiple times since `found` serves as a guardian. + */ + def findNextInBuffer(): Boolean = { + while (!found && streamedElement != null + && keyOrdering.compare(streamedKey, matchKey) == 0) { + while (bufferedPosition < bufferedMatches.size && !boundCondition( + smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)))) { + bufferedPosition += 1 + } + if (bufferedPosition == bufferedMatches.size) { + if (joinType == Inner || bufferNullGenerated) { + bufferNullGenerated = false + bufferedPosition = 0 + fetchStreamed() + } else { + found = true + } + } else { + // mark as true so we don't generate null row for streamed row. + bufferNullGenerated = true + bitSet.set(bufferedPosition) + found = true + } + } + if (!found) { + if (joinType == FullOuter && !streamNullGenerated) { + streamNullGenerated = true + } + if (streamNullGenerated) { + while (bufferedPosition < bufferedMatches.size && bitSet.get(bufferedPosition)) { + bufferedPosition += 1 + } + if (bufferedPosition < bufferedMatches.size) { + found = true + } + } + } + if (!found) { + stop = false + bufferedMatches = null + } + found + } override final def next(): InternalRow = { if (hasNext) { - if (bufferedMatches == null || bufferedMatches.size == 0) { + if (isBufferEmpty(bufferedMatches)) { // we just found a row with no join match and we are here to produce a row // with this row and a standard null row from the other side. if (continueStreamed) { @@ -153,26 +211,22 @@ case class SortMergeJoin( } } else { // we are using the buffered right rows and run down left iterator - val joinedRow = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) - bufferedPosition += 1 - if (bufferedPosition >= bufferedMatches.size) { - bufferedPosition = 0 - if (joinType != FullOuter || secondStreamedElement == null) { - fetchStreamed() - if (streamedElement == null || keyOrdering.compare(streamedKey, matchKey) != 0) { - stop = false - bufferedMatches = null - } + val joinedRow = if (streamNullGenerated) { + val ret = smartJoinRow(streamedNullRow, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + ret + } else { + if (bufferedPosition == bufferedMatches.size && !bufferNullGenerated) { + val ret = smartJoinRow(streamedElement, bufferedNullRow) + bufferNullGenerated = true + ret } else { - // in FullOuter join and the first time we finish the match buffer, - // we still want to generate all rows with streamed null row and buffered - // rows that match the join key but not the conditions. - streamedElement = secondStreamedElement - bufferedMatches = secondBufferedMatches - secondStreamedElement = null - secondBufferedMatches = null + val ret = smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)) + bufferedPosition += 1 + ret } } + found = false joinedRow } } else { @@ -211,18 +265,16 @@ case class SortMergeJoin( } /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * When this is not a Inner join, we will also return true when we get a row with no match - * on the other side. This search will jump out every time from the same position until - * `next()` is called. + * Searches the right iterator for the next rows that have matches in left side (only check + * key match), and stores them in a buffer. + * This search will jump out every time from the same position until `next()` is called. * Unless we call `next()`, this function can be called multiple times, with the same * return value and result as running it once, since we have set guardians in it. * * @return true if the search is successful, and false if the right iterator runs out of * tuples. */ - private def nextMatchingPair(): Boolean = { + private def nextMatchingBlock(): Boolean = { if (!stop && streamedElement != null) { // step 1: run both side to get the first match pair while (!stop && streamedElement != null && bufferedElement != null) { @@ -253,38 +305,25 @@ case class SortMergeJoin( } // step 2: run down the buffered side to put all matched rows in a buffer bufferedMatches = new CompactBuffer[InternalRow]() - secondBufferedMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches // as the records should be ordered, exit when we meet the first that not match while (!stop) { - if (boundCondition(joinRow(streamedElement, bufferedElement))) { - bufferedMatches += bufferedElement - } else if (joinType == FullOuter) { - bufferedMatches += bufferedNullRow - secondBufferedMatches += bufferedElement - } + bufferedMatches += bufferedElement fetchBuffered() stop = keyOrdering.compare(streamedKey, bufferedKey) != 0 || bufferedElement == null } - if (bufferedMatches.size == 0 && joinType != Inner) { - bufferedMatches += bufferedNullRow - } - if (bufferedMatches.size > 0) { - bufferedPosition = 0 - matchKey = streamedKey - // secondBufferedMatches.size cannot be larger than bufferedMatches - if (secondBufferedMatches.size > 0) { - secondStreamedElement = streamedNullRow - } - } + bufferedPosition = 0 + streamNullGenerated = false + bitSet = new BitSet(bufferedMatches.size) + matchKey = streamedKey } } // `stop` is false iff left or right has finished iteration in step 1. // if we get into step 2, `stop` cannot be false. - if (!stop && (bufferedMatches == null || bufferedMatches.size == 0)) { + if (!stop && isBufferEmpty(bufferedMatches)) { if (streamedElement == null && bufferedElement != null) { // streamedElement == null but bufferedElement != null if (joinType == FullOuter) { @@ -299,8 +338,11 @@ case class SortMergeJoin( } } } - bufferedMatches != null && bufferedMatches.size > 0 + !isBufferEmpty(bufferedMatches) } + + private def isBufferEmpty(buffer: CompactBuffer[InternalRow]): Boolean = + buffer == null || buffer.isEmpty } }) } From 6a771e9d9aa771ccbf6ef7cc594246956687f8e5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 3 Aug 2015 23:43:56 -0700 Subject: [PATCH 07/10] Use withSQLConf in JoinSuite --- .../org/apache/spark/sql/JoinSuite.scala | 65 +++++++------------ 1 file changed, 24 insertions(+), 41 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2f0baa1213ab3..f8b98781f2b6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach { +class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Ensures tables are loaded. TestData + override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ import ctx.logicalPlanToSparkQuery @@ -66,7 +67,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -96,8 +96,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", @@ -112,20 +111,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } @@ -133,15 +126,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -149,8 +140,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -160,7 +149,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", @@ -168,8 +156,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, false) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", @@ -177,8 +164,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -220,9 +205,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( x.join(y).where($"x.a" === $"y.a"), Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: - Row(1, 2, 1, 1) :: - Row(1, 2, 1, 2) :: Nil + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } @@ -465,25 +450,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -496,6 +480,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } -} +} \ No newline at end of file From 13f86bdea7ca713263f446548f89a16e2ee8c372 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 4 Aug 2015 00:08:14 -0700 Subject: [PATCH 08/10] minor fixes --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 2 +- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 44e1839ce2c8d..2b31d98bce85c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -67,7 +67,7 @@ case class SortMergeJoin( UnknownPartitioning(streamedPlan.outputPartitioning.numPartitions) case Inner => PartitioningCollection(Seq(streamedPlan.outputPartitioning, bufferedPlan.outputPartitioning)) - case LeftOuter | rightOuter => + case LeftOuter | RightOuter => streamedPlan.outputPartitioning case x => throw new IllegalStateException(s"SortMergeJoin should not take $x as the JoinType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f8b98781f2b6c..463e7ba988745 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -481,4 +481,4 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(3, 1) :: Row(3, 2) :: Nil) } -} \ No newline at end of file +} From d2a1d12422a15bf0c9b3209f7383ca541f4ee635 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 4 Aug 2015 01:02:04 -0700 Subject: [PATCH 09/10] bug fix --- .../org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 2b31d98bce85c..ec001108d2717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -157,7 +157,7 @@ case class SortMergeJoin( while (!found && streamedElement != null && keyOrdering.compare(streamedKey, matchKey) == 0) { while (bufferedPosition < bufferedMatches.size && !boundCondition( - smartJoinRow(streamedElement, bufferedMatches(bufferedPosition)))) { + joinRow(streamedElement, bufferedMatches(bufferedPosition)))) { bufferedPosition += 1 } if (bufferedPosition == bufferedMatches.size) { From d02f6bbab969a20e8a3cd9d6b065db39462d6ff5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 5 Aug 2015 00:30:51 -0700 Subject: [PATCH 10/10] fix broadcast selection --- .../spark/sql/execution/SparkStrategies.scala | 20 +++++++++---------- .../org/apache/spark/sql/JoinSuite.scala | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d14501bfc68e6..3324393329d0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -96,6 +96,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin. // And for outer join, we can not put conditions outside of the join case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) @@ -114,16 +124,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil - case ExtractEquiJoinKeys( - LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil - - case ExtractEquiJoinKeys( - RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 463e7ba988745..16bf88c961b3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -152,9 +152,9 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeJoin]), + classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeJoin]) + classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq(