@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
s"[${this.mkString(",")}]"

def copy(): Row

/** Returns true if there are any NULL values in this row. */
def anyNull: Boolean = {
var i = 0
while (i < length) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
}

/**

@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}

object InterpretedPredicate {
def apply(expression: Expression): (Row => Boolean) = {
(r: Row) => expression.apply(r).asInstanceOf[Boolean]
}
}

trait Predicate extends Expression {
self: Product =>


@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TopK ::
PartialAggregation ::
- SparkEquiInnerJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
CartesianProduct ::

@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>

- object SparkEquiInnerJoin extends Strategy {
object HashJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val leftKeys = joinKeys.map(_._1)
val rightKeys = joinKeys.map(_._2)

- val joinOp = execution.SparkEquiInnerJoin(
-   leftKeys, rightKeys, planLater(left), planLater(right))
val joinOp = execution.HashJoin(
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))

// Make sure other conditions are met if present.
if (otherPredicates.nonEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala (95 additions, 32 deletions)
@@ -17,21 +17,22 @@

package org.apache.spark.sql.execution

- import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, BitSet}

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext

import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}

- import org.apache.spark.rdd.PartitionLocalRDDFunctions._
sealed abstract class BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

- case class SparkEquiInnerJoin(
case class HashJoin(
> Contributor comment: Seems to me this only handles inner join right? If yes, maybe reflect that in the name. Technically we can do outer joins with hash join too ...

leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
}

val (buildKeys, streamedKeys) = buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}

def output = left.output ++ right.output

- def execute() = attachTree(this, "execute") {
- val leftWithKeys = left.execute().mapPartitions { iter =>
- val generateLeftKeys = new Projection(leftKeys, left.output)
- iter.map(row => (generateLeftKeys(row), row.copy()))
- }
-
- val rightWithKeys = right.execute().mapPartitions { iter =>
- val generateRightKeys = new Projection(rightKeys, right.output)
- iter.map(row => (generateRightKeys(row), row.copy()))
- }
-
- // Do the join.
- val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
- // Drop join keys and merge input tuples.
- joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
- }

@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)

def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
// TODO: Use Spark's HashMap implementation.
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
> Contributor comment: Use the hash map here: org.apache.spark.util.collection.OpenHashMap. It is way faster than Java's, uses less memory, and the changeValue function allows you to use a single lookup to do lookup and update the value (so you can avoid the two hash lookups: one in get and one in put).
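
For illustration only (not part of this PR), a rough sketch of what that suggestion could look like for the build loop below. It assumes OpenHashMap's changeValue(key, defaultValue, mergeValue) signature and that the class, which is private[spark], is accessible from this package; buildIter, buildSideKeyGenerator, and Row are the names used in the diff.

```scala
// Hypothetical variant of the build phase using OpenHashMap instead of
// java.util.HashMap, per the review comment above. changeValue needs only
// one probe per row: it inserts the default value when the key is absent
// and returns the existing value (via identity) when it is present.
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.util.collection.OpenHashMap

val hashTable = new OpenHashMap[Row, ArrayBuffer[Row]]()
while (buildIter.hasNext) {
  val currentRow = buildIter.next()
  val rowKey = buildSideKeyGenerator(currentRow)
  if (!rowKey.anyNull) {
    // Single lookup replaces the get-then-put pair in the PR's loop.
    val matchList = hashTable.changeValue(rowKey, new ArrayBuffer[Row](), identity)
    matchList += currentRow.copy()
  }
}
```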

var currentRow: Row = null

// Create a mapping of buildKeys -> rows
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
existingMatchList
}
matchList += currentRow.copy()
}
}

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: ArrayBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1

- /**
-  * Filters any rows where the any of the join keys is null, ensuring three-valued
-  * logic for the equi-join conditions.
-  */
- protected def filterNulls(rdd: RDD[(Row, Row)]) =
-   rdd.filter {
-     case (key: Seq[_], _) => !key.exists(_ == null)
// Mutable per row objects.
private[this] val joinRow = new JoinedRow

private[this] val joinKeys = streamSideKeyGenerator()

override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
(streamIter.hasNext && fetchNext())

override final def next() = {
val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
currentMatchPosition += 1
ret
}

/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false if the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatches = null
currentMatchPosition = -1

while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatches = hashTable.get(joinKeys.currentValue)
}
}

if (currentHashMatches == null) {
false
} else {
currentMatchPosition = 0
true
}
}
}
}
}
}

case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
def right = broadcast

@transient lazy val boundCondition =
- condition
-   .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-   .getOrElse(Literal(true))
InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))


def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new mutable.ArrayBuffer[Row]
- val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
val matchedRows = new ArrayBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow

streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
while (i < broadcastedRelation.value.size) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matchedRows += buildRow(streamedRow ++ broadcastedRow)
matched = true
includedBroadcastTuples += i

@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
DataSinks,
Scripts,
PartialAggregation,
- SparkEquiInnerJoin,
HashJoin,
BasicOperators,
CartesianProduct,
BroadcastNestedLoopJoin