From 1c41f6f248f1145c7d730129795e50bdd8a53f2b Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Tue, 28 Oct 2014 16:47:35 -0700
Subject: [PATCH 1/7] initial commit

---
 .../org/apache/spark/sql/execution/Exchange.scala | 14 ++++++++++++--
 .../spark/sql/execution/SparkStrategies.scala     |  4 ++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 927f40063e47e..3727208f58b3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.ShuffledRDD
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
@@ -57,10 +58,19 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(r => mutablePair.update(hashExpressions(r), r))
         }
       }
+
+      val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
+      implicit val ordering = new RowOrdering(sortingExpressions, child.output)
       val part = new HashPartitioner(numPartitions)
-      val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+      val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
+      //val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
       shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-      shuffled.map(_._2)
+      val temp = shuffled.map(_._2)
+      for ( x <- temp.collect()) {
+        println(x)
+      }
+      println("------------")
+      temp

     case RangePartitioning(sortingExpressions, numPartitions) =>
       val rdd = if (sortBasedShuffleOn) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4f1af7234d551..d6ec5295cd032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -83,6 +83,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
+        val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {

From dc6a6840e2d2b1681e70a6a3eeb10d7a9e6437ce Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Tue, 28 Oct 2014 17:17:59 -0700
Subject: [PATCH 2/7] add MergeJoin.scala

---
 .../spark/sql/execution/joins/MergeJoin.scala | 168 ++++++++++++++++++
 1 file changed, 168 insertions(+)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
new file mode 100644
index 0000000000000..ebb6c04884976
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.{Inner, FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+/* : Developer Api :
+   Sort-merge join
+*/
+@DeveloperApi
+case class MergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan
+) extends BinaryNode {
+
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def output = left.output ++ right.output
+
+  //SortOrder meaning?
+  private val leftOrders = leftKeys.map(s => SortOrder(s, Ascending))
+  private val rightOrders = leftKeys.map(s => SortOrder(s, Ascending))
+
+  //Ordered distribution, what order?
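+  // requiredChildDistribution below tells the planner (AddExchange) how each
+  // child's output must be partitioned before this operator can run; clustering
+  // both sides on the join keys sends rows with equal keys to the same partition.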
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  @transient protected lazy val leftKeyGenerator: Projection =
+    newProjection(leftKeys, left.output)
+
+  @transient protected lazy val rightKeyGenerator: Projection =
+    newProjection(rightKeys, right.output)
+
+  val ordering = new RowOrdering(leftOrders, left.output)
+
+  // According to Postgres' merge join
+  // Join {
+  //   get initial outer and inner tuples            INITIALIZE
+  //   do forever {
+  //     while (outer != inner) {                    SKIP_TEST
+  //       if (outer < inner)
+  //         advance outer                           SKIPOUTER_ADVANCE
+  //       else
+  //         advance inner                           SKIPINNER_ADVANCE
+  //     }
+  //     mark inner position                         SKIP_TEST
+  //     do forever {
+  //       while (outer == inner) {
+  //         join tuples                             JOINTUPLES
+  //         advance inner position                  NEXTINNER
+  //       }
+  //       advance outer position                    NEXTOUTER
+  //       if (outer == mark)                        TESTOUTER
+  //         restore inner position to mark          TESTOUTER
+  //       else
+  //         break  // return to top of outer loop
+  //     }
+  //   }
+  // }
+
+  // put maching tuples in rightIter into compact buffer
+  // find a matching tuple between left and right
+  //
+  override def execute() = {
+
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        private[this] val joinRow = new JoinedRow2
+        var currentRow:Row = null
+        var leftElement:Row = null
+        var rightElement:Row = null
+        var leftKey:Row = null
+        var rightKey:Row = null
+        val leftLength = leftIter.length
+        val rightLength = rightIter.length
+        var leftIndex = 0
+        var rightIndex = 0
+        var mark = 0
+        val buffer = new CompactBuffer[Row]()
+
+        private def initialize() = {
+          if (leftIter.hasNext) {
+            leftElement = leftIter.next()
+          }
+          if (rightIter.hasNext) {
+            rightElement = rightIter.next()
+          }
+        }
+
+        initialize()
+
+        override final def hasNext: Boolean = {
+          leftIter.hasNext && rigthIter.hasNext
+
+
+          while(leftIndex < leftLength) {
+            if (leftElement == null) leftElement = leftIter.next()
+            if (rightElement == null) rightElement = rightIter.next()
+            leftKey = leftKeyGenerator(leftElement)
+            rightKey = rightKeyGenerator(rightElement)
+            if (ordering.compare(leftKey, rightKey) == 0) {
+              buffer += rightElement.copy()
+              currentRow = joinRow(leftElement, rightElement)
+              return true
+            }
+
+            if (ordering.compare(leftKey, rightKey) <= 0 && rightIndex < rightLength) {
+              rightIndex += 1
+              rightElement = rightIter.next()
+              rightKey = rightKeyGenerator(rightElement)
+              if (ordering.compare(leftKey, rightKey) < 0) {
+                mark = rightIndex
+              }
+            } else {
+              leftIndex += 1
+              leftElement = leftIter.next()
+              rightIndex = mark
+            }
+          }
+          false
+          // while(leftIter.hasNext && rightIter.hasNext) {
+          //   if (ordering.compare(leftKey, rightKey) == 0) {
+          //     currentRow = joinRow(leftElement, rightElement)
+          //     leftElement = null
+          //     rightElement = null
+          //     return true
+          //   } else if (ordering.compare(leftKey,rightKey) < 0) {
+          //     if (leftIter.hasNext) leftElement = leftIter.next()
+          //   } else {
+          //     if (rightIter.hasNext) rightElement = rightIter.next()
+          //   }
+          // }
+          // false
+        }
+
+        override final def next() = {
+          currentRow.copy()
+        }
+      }
+    }
+  }
+}
\ No newline at end of file

From f5ef4624aea5304ffdcc8daf5fbebc20943c3cf4 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 8 Nov 2014 20:05:56 -0800
Subject: [PATCH 3/7] Merge join working

---
 .../plans/physical/partitioning.scala         |  45 +++++
 .../apache/spark/sql/execution/Exchange.scala |  29 ++-
 .../spark/sql/execution/SparkStrategies.scala |   2 +-
 .../spark/sql/execution/joins/MergeJoin.scala | 181 +++++++++----------
 .../org/apache/spark/sql/JoinSuite.scala      |   1 +
 5 files changed, 155 insertions(+), 103 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ccb0df113c063..5bd2396c46dc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -58,6 +58,20 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
     "a single partition.")
 }

+/**
+ * Represents data where tuples that share the same values for the `clustering`
+ * [[Expression Expressions]] will be co-located in the same partition and, within
+ * each partition, sorted by those expressions.
+ */
+case class ClusteredOrderedDistribution(clustering: Seq[Expression]) extends Distribution {
+  require(
+    clustering != Nil,
+    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+      "An AllTuples should be used to represent a distribution that only has " +
+      "a single partition.")
+}
+
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]].  This is a strictly stronger guarantee than
@@ -162,6 +176,37 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }

+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition. Within each partition, rows are sorted by the values of `expressions`.
+ */
+case class HashSortedPartitioning(expressions: Seq[Expression], numPartitions: Int)
+  extends Expression
+  with Partitioning {
+
+  override def children = expressions
+  override def nullable = false
+  override def dataType = IntegerType
+
+  private[this] lazy val clusteringSet = expressions.toSet
+
+  override def satisfies(required: Distribution): Boolean = required match {
+    case UnspecifiedDistribution => true
+    case ClusteredOrderedDistribution(requiredClustering) =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case _ => false
+  }
+
+  override def compatibleWith(other: Partitioning) = other match {
+    case BroadcastPartitioning => true
+    case h: HashSortedPartitioning if h == this => true
+    case _ => false
+  }
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
 /**
  * Represents a partitioning where rows are split across partitions based on some total ordering of
  * the expressions specified in `ordering`. When data is partitioned in this manner the following
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 3727208f58b3f..443266725ea05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -59,18 +59,31 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         }
       }

+      val part = new HashPartitioner(numPartitions)
+      val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+      shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+      shuffled.map(_._2)
+
+    case HashSortedPartitioning(expressions, numPartitions) =>
+      val rdd = if (sortBasedShuffleOn) {
+        child.execute().mapPartitions { iter =>
+          val hashExpressions = newProjection(expressions, child.output)
+          iter.map(r => (hashExpressions(r), r.copy()))
+        }
+      } else {
+        child.execute().mapPartitions { iter =>
+          val hashExpressions = newMutableProjection(expressions, child.output)()
+          val mutablePair = new MutablePair[Row, Row]()
+          iter.map(r => mutablePair.update(hashExpressions(r), r))
+        }
+      }
+
       val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
       implicit val ordering = new RowOrdering(sortingExpressions, child.output)
       val part = new HashPartitioner(numPartitions)
       val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
-      //val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
       shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
-      val temp = shuffled.map(_._2)
-      for ( x <- temp.collect()) {
-        println(x)
-      }
-      println("------------")
-      temp
+      shuffled.map(_._2)

     case RangePartitioning(sortingExpressions, numPartitions) =>
       val rdd = if (sortBasedShuffleOn) {
@@ -168,6 +181,8 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
         addExchangeIfNecessary(SinglePartition, child)
       case (ClusteredDistribution(clustering), child) =>
         addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
+      case (ClusteredOrderedDistribution(clustering), child) =>
+        addExchangeIfNecessary(HashSortedPartitioning(clustering, numPartitions), child)
       case (OrderedDistribution(ordering), child) =>
         addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
       case (UnspecifiedDistribution, child) => child
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d6ec5295cd032..bbc4b041d6f73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -84,7 +84,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)

       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
-        val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right))
+        val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
index ebb6c04884976..3a4b2d7319064 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
@@ -17,39 +17,33 @@

 package org.apache.spark.sql.execution.joins

-import scala.collection.JavaConversions._
-
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, Partitioning}
-import org.apache.spark.sql.catalyst.plans.{Inner, FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.collection.CompactBuffer
 /* : Developer Api :
-   Sort-merge join
+ * Performs sort-merge join of two child relations by first shuffling the data using the join
+ * keys. The data is also sorted by the join keys as part of the shuffle.
 */
 @DeveloperApi
 case class MergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan
 ) extends BinaryNode {
-
-
+  // Implementation: the tricky part is handling duplicate join keys.
+  // To handle duplicate keys, we use a buffer to store a
   override def outputPartitioning: Partitioning = left.outputPartitioning

   override def output = left.output ++ right.output

-  //SortOrder meaning?
   private val leftOrders = leftKeys.map(s => SortOrder(s, Ascending))
-  private val rightOrders = leftKeys.map(s => SortOrder(s, Ascending))
+  private val rightOrders = rightKeys.map(s => SortOrder(s, Ascending))

-  //Ordered distribution, what order?
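+  // Requiring ClusteredOrderedDistribution instead of ClusteredDistribution
+  // makes the Exchange shuffle with HashSortedPartitioning, so each child's
+  // partitions arrive both hash-partitioned and sorted by the join keys.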
   override def requiredChildDistribution =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil

   @transient protected lazy val leftKeyGenerator: Projection =
     newProjection(leftKeys, left.output)
@@ -57,112 +51,109 @@ case class MergeJoin(
   @transient protected lazy val rightKeyGenerator: Projection =
     newProjection(rightKeys, right.output)

-  val ordering = new RowOrdering(leftOrders, left.output)
-
-  // According to Postgres' merge join
-  // Join {
-  //   get initial outer and inner tuples            INITIALIZE
-  //   do forever {
-  //     while (outer != inner) {                    SKIP_TEST
-  //       if (outer < inner)
-  //         advance outer                           SKIPOUTER_ADVANCE
-  //       else
-  //         advance inner                           SKIPINNER_ADVANCE
-  //     }
-  //     mark inner position                         SKIP_TEST
-  //     do forever {
-  //       while (outer == inner) {
-  //         join tuples                             JOINTUPLES
-  //         advance inner position                  NEXTINNER
-  //       }
-  //       advance outer position                    NEXTOUTER
-  //       if (outer == mark)                        TESTOUTER
-  //         restore inner position to mark          TESTOUTER
-  //       else
-  //         break  // return to top of outer loop
-  //     }
-  //   }
-  // }
+  private val ordering = new RowOrdering(leftOrders, left.output)

-  // put maching tuples in rightIter into compact buffer
-  // find a matching tuple between left and right
-  //
   override def execute() = {

     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       new Iterator[Row] {
         private[this] val joinRow = new JoinedRow2
-        var currentRow:Row = null
-        var leftElement:Row = null
-        var rightElement:Row = null
-        var leftKey:Row = null
-        var rightKey:Row = null
-        val leftLength = leftIter.length
-        val rightLength = rightIter.length
-        var leftIndex = 0
-        var rightIndex = 0
-        var mark = 0
-        val buffer = new CompactBuffer[Row]()
-
+        private[this] var leftElement:Row = _
+        private[this] var rightElement:Row = _
+        private[this] var leftKey:Row = _
+        private[this] var rightKey:Row = _
+        private[this] var buffer:CompactBuffer[Row] = _
+        private[this] var index = -1
+        private[this] var last = false
+
+        // initialize iterator
         private def initialize() = {
           if (leftIter.hasNext) {
             leftElement = leftIter.next()
+            leftKey = leftKeyGenerator(leftElement)
+          } else {
+            last = true
           }
           if (rightIter.hasNext) {
             rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          } else {
+            last = true
           }
         }

         initialize()

         override final def hasNext: Boolean = {
-          leftIter.hasNext && rigthIter.hasNext
-
-
-          while(leftIndex < leftLength) {
-            if (leftElement == null) leftElement = leftIter.next()
-            if (rightElement == null) rightElement = rightIter.next()
-            leftKey = leftKeyGenerator(leftElement)
-            rightKey = rightKeyGenerator(rightElement)
-            if (ordering.compare(leftKey, rightKey) == 0) {
-              buffer += rightElement.copy()
-              currentRow = joinRow(leftElement, rightElement)
-              return true
-            }
-
-            if (ordering.compare(leftKey, rightKey) <= 0 && rightIndex < rightLength) {
-              rightIndex += 1
-              rightElement = rightIter.next()
-              rightKey = rightKeyGenerator(rightElement)
-              if (ordering.compare(leftKey, rightKey) < 0) {
-                mark = rightIndex
-              }
-            } else {
-              leftIndex += 1
-              leftElement = leftIter.next()
-              rightIndex = mark
-            }
-          }
-          false
-          // while(leftIter.hasNext && rightIter.hasNext) {
-          //   if (ordering.compare(leftKey, rightKey) == 0) {
-          //     currentRow = joinRow(leftElement, rightElement)
-          //     leftElement = null
-          //     rightElement = null
-          //     return true
-          //   } else if (ordering.compare(leftKey,rightKey) < 0) {
-          //     if (leftIter.hasNext) leftElement = leftIter.next()
-          //   } else {
-          //     if (rightIter.hasNext) rightElement = rightIter.next()
-          //   }
-          // }
-          // false
+          if (index != -1) return true
+          if (last) return false
+          return nextMatchingPair()
         }
-
-        override final def next() = {
-          currentRow.copy()
+
+        override final def next(): Row = {
+          if (index == -1) {
+            if (!hasNext) return null
+          }
+          val joinedRow = joinRow(leftElement, buffer(index))
+          index += 1
+          if (index == buffer.size) {
+            if (leftIter.hasNext) {
+              val leftElem = leftElement
+              val leftK = leftKeyGenerator(leftElem)
+              leftElement = leftIter.next()
+              leftKey = leftKeyGenerator(leftElement)
+              if (ordering.compare(leftKey,leftK) == 0) {
+                index = 0
+              } else {
+                index = -1
+              }
+            } else {
+              index = -1
+              last = true
+            }
+          }
+          joinedRow
         }
+
+        private def nextMatchingPair(): Boolean = {
+          while (ordering.compare(leftKey, rightKey) != 0) {
+            if (ordering.compare(leftKey, rightKey) < 0) {
+              if (leftIter.hasNext) {
+                leftElement = leftIter.next()
+                leftKey = leftKeyGenerator(leftElement)
+              } else {
+                last = true
+                return false
+              }
+            } else {
+              if (rightIter.hasNext) {
+                rightElement = rightIter.next()
+                rightKey = rightKeyGenerator(rightElement)
+              } else {
+                last = true
+                return false
+              }
+            }
+          }
+          // outer == inner
+          index = 0
+          buffer = null
+          buffer = new CompactBuffer[Row]()
+          buffer += rightElement
+          val rightElem = rightElement
+          val rightK = rightKeyGenerator(rightElem)
+          while(rightIter.hasNext) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+            if (ordering.compare(rightKey,rightK) == 0) {
+              buffer += rightElement
+            } else {
+              return true
+            }
+          }
+          true
+        }
       }
     }
   }
-}
\ No newline at end of file
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 07f4d2946c1b5..f82a1100bd946 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -53,6 +53,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
+      case j: MergeJoin => j
     }

     assert(operators.size === 1)

From d6b6e7b8194682c713400823a4fd17e0419d89e4 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 8 Nov 2014 20:51:04 -0800
Subject: [PATCH 4/7] add inline comments for merge join

---
 .../plans/physical/partitioning.scala         |  2 +-
 .../spark/sql/execution/joins/MergeJoin.scala | 33 ++++++++++++++++---
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 5bd2396c46dc6..3ed83f04039b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -67,7 +67,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
 case class ClusteredOrderedDistribution(clustering: Seq[Expression]) extends Distribution {
   require(
     clustering != Nil,
-    "The clustering expressions of a ClusteredDistribution should not be Nil. " +
+    "The clustering expressions of a ClusteredOrderedDistribution should not be Nil. " +
       "An AllTuples should be used to represent a distribution that only has " +
       "a single partition.")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
index 3a4b2d7319064..aad5c0af3d20e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
@@ -34,13 +34,17 @@ case class MergeJoin(
     right: SparkPlan
 ) extends BinaryNode {
   // Implementation: the tricky part is handling duplicate join keys.
-  // To handle duplicate keys, we use a buffer to store a
+  // To handle duplicate keys, we use a buffer to store all maching tuples
+  // in right relation for a certain join key. This buffer is used by the
+  // merge join iterator to generate join tuples. The buffer is used for
+  // generating join tuples when the join key of the next left element is
+  // is the same as the current join key.
+  // TODO: add outer join support

   override def outputPartitioning: Partitioning = left.outputPartitioning

   override def output = left.output ++ right.output

-  private val leftOrders = leftKeys.map(s => SortOrder(s, Ascending))
-  private val rightOrders = rightKeys.map(s => SortOrder(s, Ascending))
+  private val orders = leftKeys.map(s => SortOrder(s, Ascending))

   override def requiredChildDistribution =
     ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
@@ -51,7 +55,7 @@ case class MergeJoin(
   @transient protected lazy val rightKeyGenerator: Projection =
     newProjection(rightKeys, right.output)

-  private val ordering = new RowOrdering(leftOrders, left.output)
+  private val ordering = new RowOrdering(orders, left.output)

   override def execute() = {
@@ -85,6 +89,14 @@ case class MergeJoin(
         initialize()

         override final def hasNext: Boolean = {
+          // There are two cases in which hasNext returns true:
+          // 1. We are iterating the buffer
+          // 2. We can find tuple pairs that have a matching join key
+          //
+          // hasNext is idempotent: nextMatchingPair() is only called
+          // when index == -1, and it sets index to 0 when it returns
+          // true, so multiple calls to hasNext modify the iterator
+          // state at most once.
           if (index != -1) return true
           if (last) return false
           return nextMatchingPair()
         }
@@ -92,22 +104,33 @@ case class MergeJoin(

         override final def next(): Row = {
           if (index == -1) {
+            // We need this because the client of the join iterator may
+            // call next() without calling hasNext
             if (!hasNext) return null
           }
           val joinedRow = joinRow(leftElement, buffer(index))
           index += 1
           if (index == buffer.size) {
+            // finished iterating the buffer, fetch
+            // next element from left iterator
             if (leftIter.hasNext) {
+              // fetch next element
               val leftElem = leftElement
               val leftK = leftKeyGenerator(leftElem)
               leftElement = leftIter.next()
               leftKey = leftKeyGenerator(leftElement)
               if (ordering.compare(leftKey,leftK) == 0) {
+                // need to go over the buffer again
+                // as we have the same join key for
+                // next left element
                 index = 0
               } else {
+                // need to find a matching element from
+                // right iterator
                 index = -1
               }
             } else {
+              // no next left element, we are done
              index = -1
              last = true
            }
@@ -115,6 +138,8 @@ case class MergeJoin(
           joinedRow
         }

+        // find the next pair of left/right tuples that have a
+        // matching join key
         private def nextMatchingPair(): Boolean = {
           while (ordering.compare(leftKey, rightKey) != 0) {
             if (ordering.compare(leftKey, rightKey) < 0) {

From 837eb081e6382a23b4fd67a5265188aab1c7e305 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 8 Nov 2014 21:02:17 -0800
Subject: [PATCH 5/7] use merge join as inner join operator in JoinSuite

---
 .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8b4cf5bac0187..22c1dd482c36c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
+      case j: MergeJoin => j
     }

     assert(operators.size === 1)
@@ -72,9 +73,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[MergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[MergeJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[MergeJoin]),
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[HashOuterJoin]),

From 5cb98c306f76183e4148d9b0a6b0a8ce4d58368e Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 8 Nov 2014 21:30:52 -0800
Subject: [PATCH 6/7] improve inline comments

---
 .../org/apache/spark/sql/execution/joins/MergeJoin.scala | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
index aad5c0af3d20e..2ca9541ccc307 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/MergeJoin.scala
@@ -34,11 +34,10 @@ case class MergeJoin(
     right: SparkPlan
 ) extends BinaryNode {
   // Implementation: the tricky part is handling duplicate join keys.
-  // To handle duplicate keys, we use a buffer to store all maching tuples
-  // in right relation for a certain join key. This buffer is used by the
-  // merge join iterator to generate join tuples. The buffer is used for
+  // To handle duplicate keys, we use a buffer to store all matching elements
+  // in right iterator for a certain join key. The buffer is used for
   // generating join tuples when the join key of the next left element is
-  // is the same as the current join key.
+  // the same as the current join key.
   // TODO: add outer join support

   override def outputPartitioning: Partitioning = left.outputPartitioning

From cc4647d2332806026fe592c5798f0900ba001744 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 8 Nov 2014 22:03:17 -0800
Subject: [PATCH 7/7] comment out unmatched code

---
 .../spark/sql/execution/SparkStrategies.scala | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5bf6cb17f3cca..03db6d1126519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -87,16 +87,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val mergeJoin = joins.MergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil

-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
-        val buildSide =
-          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
-            joins.BuildRight
-          } else {
-            joins.BuildLeft
-          }
-        val hashJoin = joins.ShuffledHashJoin(
-          leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
-        condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+      // case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
+      //   val buildSide =
+      //     if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+      //       joins.BuildRight
+      //     } else {
+      //       joins.BuildLeft
+      //     }
+      //   val hashJoin = joins.ShuffledHashJoin(
+      //     leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+      //   condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
         joins.HashOuterJoin(