From bc0cb843bef71d03c7002523a567ac39760f65c9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 26 Mar 2014 21:11:01 -0700 Subject: [PATCH 1/4] Rewrite join implementation to allow streaming of one relation. --- .../spark/sql/catalyst/expressions/Row.scala | 10 ++ .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../apache/spark/sql/execution/joins.scala | 133 +++++++++++++----- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 5 files changed, 113 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 31d42b9ee71a0..f53d7a8ec0124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable { s"[${this.mkString(",")}]" def copy(): Row + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + var i = 0 + while(i < length) { + if(isNullAt(i)) return true + i += 1 + } + false + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cf3c06acce5b0..f950ea08ec57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = TopK :: PartialAggregation :: - SparkEquiInnerJoin :: + HashJoin :: ParquetOperations :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 86f9d3e0fa954..e35ac0b6ca95a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => - object SparkEquiInnerJoin extends Strategy { + object HashJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) => logger.debug(s"Considering join: ${predicates ++ condition}") @@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val leftKeys = joinKeys.map(_._1) val rightKeys = joinKeys.map(_._2) - val joinOp = execution.SparkEquiInnerJoin( - leftKeys, rightKeys, planLater(left), planLater(right)) + val joinOp = execution.HashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) // Make sure other conditions are met if present. if (otherPredicates.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index f0d21143ba5d1..62f001a697063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -15,23 +15,33 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.sql +package execution -import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans._ +import catalyst.plans.physical.{ClusteredDistribution, Partitioning} -import org.apache.spark.rdd.PartitionLocalRDDFunctions._ +sealed abstract class BuildSide +case object BuildLeft extends BuildSide +case object BuildRight extends BuildSide -case class SparkEquiInnerJoin( +object InterpretCondition { + def apply(expression: Expression): (Row => Boolean) = { + (r: Row) => expression.apply(r).asInstanceOf[Boolean] + } +} + +case class HashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -40,33 +50,85 @@ case class SparkEquiInnerJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + def output = left.output ++ right.output - def execute() = attachTree(this, "execute") { - val leftWithKeys = left.execute().mapPartitions { iter => - val generateLeftKeys = new Projection(leftKeys, left.output) - iter.map(row => (generateLeftKeys(row), row.copy())) - } + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) - val rightWithKeys = right.execute().mapPartitions { iter => - val generateRightKeys = new Projection(rightKeys, right.output) - iter.map(row => (generateRightKeys(row), row.copy())) - } + def execute() = { - // Do the join. - val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) - // Drop join keys and merge input tuples. - joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } - } + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while(buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } - /** - * Filters any rows where the any of the join keys is null, ensuring three-valued - * logic for the equi-join conditions. 
- */ - protected def filterNulls(rdd: RDD[(Row, Row)]) = - rdd.filter { - case (key: Seq[_], _) => !key.exists(_ == null) + new Iterator[Row] { + private[this] var currentRow: Row = _ + private[this] var currentMatches: ArrayBuffer[Row] = _ + private[this] var currentPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow + + @transient private val joinKeys = streamSideKeyGenerator() + + def hasNext: Boolean = + (currentPosition != -1 && currentPosition < currentMatches.size) || + (streamIter.hasNext && fetchNext()) + + def next() = { + val ret = joinRow(currentRow, currentMatches(currentPosition)) + currentPosition += 1 + ret + } + + private def fetchNext(): Boolean = { + currentMatches = null + currentPosition = -1 + + while (currentMatches == null && streamIter.hasNext) { + currentRow = streamIter.next() + if(!joinKeys(currentRow).anyNull) + currentMatches = hashTable.get(joinKeys.currentValue) + } + + if (currentMatches == null) { + false + } else { + currentPosition = 0 + true + } + } + } } + } } case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -95,17 +157,18 @@ case class BroadcastNestedLoopJoin( def right = broadcast @transient lazy val boundCondition = - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true)) + InterpretCondition( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) def execute() = { val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new mutable.ArrayBuffer[Row] - val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val matchedRows = new ArrayBuffer[Row] + val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => @@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin( while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. 
val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { matchedRows += buildRow(streamedRow ++ broadcastedRow) matched = true includedBroadcastTuples += i diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fc5057b73fe24..197b557cba5f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { DataSinks, Scripts, PartialAggregation, - SparkEquiInnerJoin, + HashJoin, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 1e9fb63349a4748b0cc31a34d8a26406a7ed657e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 26 Mar 2014 23:57:06 -0700 Subject: [PATCH 2/4] style --- .../main/scala/org/apache/spark/sql/execution/joins.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 62f001a697063..4ee833c7321ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -115,8 +115,9 @@ case class HashJoin( while (currentMatches == null && streamIter.hasNext) { currentRow = streamIter.next() - if(!joinKeys(currentRow).anyNull) + if(!joinKeys(currentRow).anyNull) { currentMatches = hashTable.get(joinKeys.currentValue) + } } if (currentMatches == null) { @@ -168,7 +169,8 @@ case class BroadcastNestedLoopJoin( val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] - val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => From 8e6f2a262e33669f91c24416b5e54af8c2a9689d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 30 Mar 2014 21:08:20 -0700 Subject: [PATCH 3/4] Review comments. --- .../spark/sql/catalyst/expressions/Row.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 6 ++ .../apache/spark/sql/execution/joins.scala | 75 ++++++++++--------- 3 files changed, 46 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index f53d7a8ec0124..6f939e6c41f6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -48,8 +48,8 @@ trait Row extends Seq[Any] with Serializable { /** Returns true if there are any NULL values in this row. 
*/ def anyNull: Boolean = { var i = 0 - while(i < length) { - if(isNullAt(i)) return true + while (i < length) { + if (isNullAt(i)) { return true } i += 1 } false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 722ff517d250e..02fedd16b8d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +object InterpretedPredicate { + def apply(expression: Expression): (Row => Boolean) = { + (r: Row) => expression.apply(r).asInstanceOf[Boolean] + } +} + trait Predicate extends Expression { self: Product => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 4ee833c7321ba..78e4814c07b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -15,29 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql -package execution +package org.apache.spark.sql.execution import scala.collection.mutable.{ArrayBuffer, BitSet} -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import catalyst.errors._ -import catalyst.expressions._ -import catalyst.plans._ -import catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} sealed abstract class BuildSide case object BuildLeft extends BuildSide case object BuildRight extends BuildSide -object InterpretCondition { - def apply(expression: Expression): (Row => Boolean) = { - (r: Row) => expression.apply(r).asInstanceOf[Boolean] - } -} - case class HashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -69,11 +60,12 @@ case class HashJoin( def execute() = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + // TODO: Use Spark's HashMap implementation. val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() var currentRow: Row = null // Create a mapping of buildKeys -> rows - while(buildIter.hasNext) { + while (buildIter.hasNext) { currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) if(!rowKey.anyNull) { @@ -90,40 +82,49 @@ case class HashJoin( } new Iterator[Row] { - private[this] var currentRow: Row = _ - private[this] var currentMatches: ArrayBuffer[Row] = _ - private[this] var currentPosition: Int = -1 + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. 
private[this] val joinRow = new JoinedRow - @transient private val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator() - def hasNext: Boolean = - (currentPosition != -1 && currentPosition < currentMatches.size) || - (streamIter.hasNext && fetchNext()) + override final def hasNext: Boolean = + if (currentMatchPosition != -1) { + currentMatchPosition < currentHashMatches.size + } else { + fetchNext() + } - def next() = { - val ret = joinRow(currentRow, currentMatches(currentPosition)) - currentPosition += 1 + override final def next() = { + val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + currentMatchPosition += 1 ret } - private def fetchNext(): Boolean = { - currentMatches = null - currentPosition = -1 - - while (currentMatches == null && streamIter.hasNext) { - currentRow = streamIter.next() - if(!joinKeys(currentRow).anyNull) { - currentMatches = hashTable.get(joinKeys.currentValue) + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) } } - if (currentMatches == null) { + if (currentHashMatches == null) { false } else { - currentPosition = 0 + currentMatchPosition = 0 true } } @@ -158,7 +159,7 @@ case class BroadcastNestedLoopJoin( def right = broadcast @transient lazy val boundCondition = - InterpretCondition( + InterpretedPredicate( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) @@ -169,8 +170,8 @@ case class BroadcastNestedLoopJoin( val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => From 1ad873ef0e4add4a023e3e5bf55b2300a8a9689f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 31 Mar 2014 12:03:02 -0700 Subject: [PATCH 4/4] Change hasNext logic back to the correct version. --- .../main/scala/org/apache/spark/sql/execution/joins.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 78e4814c07b7d..c89dae9358bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -92,11 +92,8 @@ case class HashJoin( private[this] val joinKeys = streamSideKeyGenerator() override final def hasNext: Boolean = - if (currentMatchPosition != -1) { - currentMatchPosition < currentHashMatches.size - } else { - fetchNext() - } + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) override final def next() = { val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
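
For readers following the series, here is a minimal, self-contained sketch of the build/probe pattern the new `HashJoin` operator uses inside each pair of co-partitioned partitions: hash every build-side row by its join key, then stream the other relation and emit one output row per key match, skipping NULL keys on both sides. The `SimpleRow` type, the key functions, and the `flatMap`-based probe below are illustrative simplifications, not the Catalyst/Spark classes; the real operator drives the probe with a hand-rolled `Iterator` and reuses mutable `JoinedRow` and `MutableProjection` objects.

```scala
import scala.collection.mutable.ArrayBuffer

object HashJoinSketch {
  // A row is just a sequence of column values; the key functions pick out the
  // join columns. These are illustrative stand-ins, not Catalyst's Row/Projection.
  type SimpleRow = Seq[Any]

  /** Inner equi-join: build a hash table over `build`, then probe it with `streamed`. */
  def hashJoin(
      build: Iterator[SimpleRow],
      streamed: Iterator[SimpleRow],
      buildKey: SimpleRow => SimpleRow,
      streamKey: SimpleRow => SimpleRow): Iterator[SimpleRow] = {

    // Build phase: group build-side rows by join key, dropping rows whose key
    // contains a NULL so the equi-join keeps three-valued logic.
    val hashTable = new java.util.HashMap[SimpleRow, ArrayBuffer[SimpleRow]]()
    for (row <- build) {
      val key = buildKey(row)
      if (!key.contains(null)) {
        var matches = hashTable.get(key)
        if (matches == null) {
          matches = new ArrayBuffer[SimpleRow]()
          hashTable.put(key, matches)
        }
        matches += row
      }
    }

    // Probe phase: stream the other relation and emit one joined row per match.
    streamed.flatMap { row =>
      val key = streamKey(row)
      val matches = if (key.contains(null)) null else hashTable.get(key)
      if (matches == null) Iterator.empty else matches.iterator.map(row ++ _)
    }
  }

  def main(args: Array[String]): Unit = {
    val people = Iterator[SimpleRow](Seq(1, "alice"), Seq(2, "bob"), Seq(null, "carol"))
    val orders = Iterator[SimpleRow](Seq(1, "book"), Seq(1, "pen"), Seq(3, "mug"))
    // Build on orders, stream people, joining on the first column of each side.
    hashJoin(orders, people, (r: SimpleRow) => Seq(r.head), (r: SimpleRow) => Seq(r.head))
      .foreach(println)  // prints List(1, alice, 1, book) and List(1, alice, 1, pen)
  }
}
```

The hand-rolled probe iterator in the patch is also where the fix in patch 4 lands: its `hasNext` has to both walk any remaining matches for the current streamed row and, once those are exhausted, be willing to pull and hash further streamed rows, which is what the restored condition `(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext())` expresses.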
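
Patch 3 also moves the condition-evaluation helper into Catalyst's predicates.scala as `InterpretedPredicate`, which simply closes over a bound expression and returns a reusable `Row => Boolean`. Below is a rough standalone illustration of that shape, using toy stand-ins rather than Catalyst's `Expression` and `Row` types.

```scala
object PredicateSketch {
  // Toy stand-ins for Catalyst's Row and Expression; only the shape of
  // InterpretedPredicate (Expression => (Row => Boolean)) is the point here.
  type ToyRow = Seq[Any]

  sealed trait ToyExpr { def eval(row: ToyRow): Any }
  case class Column(ordinal: Int) extends ToyExpr {
    def eval(row: ToyRow): Any = row(ordinal)
  }
  case class GreaterThan(child: ToyExpr, bound: Int) extends ToyExpr {
    def eval(row: ToyRow): Any = child.eval(row).asInstanceOf[Int] > bound
  }

  // Mirrors the shape of InterpretedPredicate.apply: close over the expression
  // once and hand back a plain function that can be applied to each row.
  def interpretedPredicate(expression: ToyExpr): ToyRow => Boolean =
    (r: ToyRow) => expression.eval(r).asInstanceOf[Boolean]

  def main(args: Array[String]): Unit = {
    val olderThan30 = interpretedPredicate(GreaterThan(Column(1), 30))
    println(olderThan30(Seq("alice", 42)))  // true
    println(olderThan30(Seq("bob", 7)))     // false
  }
}
```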