[SQL] Rewrite join implementation to allow streaming of one relation. #250
@@ -17,21 +17,22 @@
 package org.apache.spark.sql.execution

-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}

-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext

 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}

-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide

-case class SparkEquiInnerJoin(
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }

+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }

   def output = left.output ++ right.output

-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)

-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {

-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
Contributor comment (inline, on the hash table line above):
Use the hash map here: org.apache.spark.util.collection.OpenHashMap. It is way faster than Java's, uses less memory, and its changeValue function lets you do the lookup and the update with a single probe (avoiding the two hash lookups: one in get and one in put).
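A rough sketch of that suggestion (not part of this patch), reusing the names from the surrounding execute() and assuming OpenHashMap's changeValue(key, defaultValue, mergeValue) signature; it would stand in for the java.util.HashMap declaration above and the build loop that follows:

```scala
// Rough sketch of the OpenHashMap suggestion (not part of this patch).
// changeValue installs the default value on a miss or rewrites the existing
// value on a hit, so each build row costs a single hash probe.
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.collection.OpenHashMap

val hashTable = new OpenHashMap[Row, ArrayBuffer[Row]]()

while (buildIter.hasNext) {
  val currentRow = buildIter.next()
  val rowKey = buildSideKeyGenerator(currentRow)
  if (!rowKey.anyNull) {
    val copiedRow = currentRow.copy()
    hashTable.changeValue(
      rowKey,
      ArrayBuffer(copiedRow),   // first row seen for this key: start a new match list
      _ += copiedRow)           // key already present: append to the existing list
  }
}
```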
+      var currentRow: Row = null

+      // Create a mapping of buildKeys -> rows
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }

+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1

-  /**
-   * Filters any rows where the any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow

+        private[this] val joinKeys = streamSideKeyGenerator()

+        override final def hasNext: Boolean =
+          (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+          (streamIter.hasNext && fetchNext())

+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
+          ret
+        }

+        /**
+         * Searches the streamed iterator for the next row that has at least one match in hashtable.
+         *
+         * @return true if the search is successful, and false the streamed iterator runs out of
+         *         tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1

+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
+            }
+          }

+          if (currentHashMatches == null) {
+            false
+          } else {
+            currentMatchPosition = 0
+            true
+          }
+        }
+      }
+    }
+  }
 }
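As an aside, the core of the rewrite is that only the build side of each co-partitioned pair of partitions is materialized in memory, while the streamed side is consumed lazily. A small self-contained sketch of that pattern on plain RDDs (illustrative only, not taken from this patch; the object name and sample data are made up):

```scala
// Standalone sketch: build one side of each co-partitioned pair into a hash
// table, then stream the other side through it with zipPartitions.
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // pair-RDD implicits on Spark 1.x

object StreamedHashJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("streamed-hash-join").setMaster("local[2]"))
    val partitioner = new HashPartitioner(4)

    // Both inputs are partitioned the same way, so matching keys are co-located.
    val build  = sc.parallelize(Seq(1 -> "a", 2 -> "b", 2 -> "c")).partitionBy(partitioner)
    val stream = sc.parallelize(Seq(1 -> "x", 2 -> "y", 3 -> "z")).partitionBy(partitioner)

    val joined = build.zipPartitions(stream) { (buildIter, streamIter) =>
      // Materialize only the build side of this partition.
      val table = new java.util.HashMap[Int, ArrayBuffer[String]]()
      buildIter.foreach { case (k, v) =>
        val existing = table.get(k)
        val matches = if (existing == null) {
          val fresh = new ArrayBuffer[String]()
          table.put(k, fresh)
          fresh
        } else existing
        matches += v
      }
      // Stream the other side, emitting one output tuple per match.
      streamIter.flatMap { case (k, v) =>
        val matches = table.get(k)
        if (matches == null) Iterator.empty
        else matches.iterator.map(b => (k, (v, b)))
      }
    }

    joined.collect().foreach(println)  // e.g. (1,(x,a)), (2,(y,b)), (2,(y,c))
    sc.stop()
  }
}
```

Here zipPartitions plays the same role it does in HashJoin.execute() above: it pairs up the co-located partitions of the two inputs without any further shuffle.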

 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast

   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))

   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow

       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
             matchedRows += buildRow(streamedRow ++ broadcastedRow)
             matched = true
             includedBroadcastTuples += i
Review comment:
Seems to me this only handles inner joins, right? If so, maybe reflect that in the name. Technically we can do outer joins with a hash join too ...
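To illustrate that last point (again, not part of this patch): a streamed-side-preserving outer variant of the same hash join could emit unmatched streamed rows padded with nulls on the build side. The sketch below reuses hashTable, streamSideKeyGenerator, buildPlan, and streamIter from HashJoin.execute() above, and assumes Catalyst's GenericRow(Array[Any]) can serve as the all-null placeholder row:

```scala
// Hypothetical left-outer-style variant, inside the same zipPartitions closure,
// after the build-side hash table has been populated.
val buildSideNullRow = new GenericRow(Array.fill[Any](buildPlan.output.size)(null))
val joinKeys = streamSideKeyGenerator()
val joinRow = new JoinedRow

streamIter.flatMap { streamedRow =>
  val matches =
    if (joinKeys(streamedRow).anyNull) null else hashTable.get(joinKeys.currentValue)
  if (matches == null) {
    // No build-side match: keep the streamed row, null-padded on the build side.
    Iterator(joinRow(streamedRow, buildSideNullRow))
  } else {
    matches.iterator.map(m => joinRow(streamedRow, m))
  }
}
```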