Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 33 additions & 34 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution

import java.util.{HashMap => JavaHashMap}

import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.collection.CompactBuffer

@DeveloperApi
sealed abstract class BuildSide
Expand Down Expand Up @@ -67,7 +66,7 @@ trait HashJoin {
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.

val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
var currentRow: Row = null

// Create a mapping of buildKeys -> rows
Expand All @@ -77,7 +76,7 @@ trait HashJoin {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
val newMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
Expand All @@ -89,7 +88,7 @@ trait HashJoin {

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: ArrayBuffer[Row] = _
private[this] var currentHashMatches: CompactBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1

// Mutable per row objects.
Expand Down Expand Up @@ -140,7 +139,7 @@ trait HashJoin {

/**
* :: DeveloperApi ::
* Performs a hash based outer join for two child relations by shuffling the data using
* Performs a hash based outer join for two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partition in both side into memory.
*/
@DeveloperApi
Expand Down Expand Up @@ -179,26 +178,26 @@ case class HashOuterJoin(
@transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]

// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
// iterator for performance purpose.

private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

leftIter.iterator.flatMap { l =>
leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all of the
// as we don't know whether we need to append it until finish iterating all of the
// records in right side.
// If we didn't get any proper row, then append a single row with empty right
joinedRow.withRight(rightNullRow).copy
Expand All @@ -210,20 +209,20 @@ case class HashOuterJoin(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val boundCondition =
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

rightIter.iterator.flatMap { r =>
rightIter.iterator.flatMap { r =>
joinedRow.withRight(r)
var matched = false
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all of the
// as we don't know whether we need to append it until finish iterating all of the
// records in left side.
// If we didn't get any proper row, then append a single row with empty left.
joinedRow.withLeft(leftNullRow).copy
Expand All @@ -236,7 +235,7 @@ case class HashOuterJoin(
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)
val boundCondition =
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)

if (!key.anyNull) {
Expand All @@ -246,8 +245,8 @@ case class HashOuterJoin(
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
rightIter.zipWithIndex.collect {
// 1. For those matched (satisfy the join condition) records with both sides filled,
rightIter.zipWithIndex.collect {
// 1. For those matched (satisfy the join condition) records with both sides filled,
// append them directly

case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
Expand All @@ -260,16 +259,16 @@ case class HashOuterJoin(
// 2. For those unmatched records in left, append additional records with empty right.

// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all
// as we don't know whether we need to append it until finish iterating all
// of the records in right side.
// If we didn't get any proper row, then append a single row with empty right.
joinedRow.withRight(rightNullRow).copy
})
} ++ rightIter.zipWithIndex.collect {
// 3. For those unmatched records in right, append additional records with empty left.

// Re-visiting the records in right, and append additional row with empty left, if its not
// in the matched set.
// Re-visiting the records in right, and append additional row with empty left, if its not
// in the matched set.
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
Expand All @@ -284,15 +283,15 @@ case class HashOuterJoin(
}

private[this] def buildHashTable(
iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = {
val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]()
iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)

var existingMatchList = hashTable.get(rowKey)
if (existingMatchList == null) {
existingMatchList = new ArrayBuffer[Row]()
existingMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, existingMatchList)
}

Expand All @@ -311,20 +310,20 @@ case class HashOuterJoin(
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))

import scala.collection.JavaConversions._
val boundCondition =
val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key,
leftHashTable.getOrElse(key, EMPTY_LIST),
fullOuterIterator(key,
leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
Expand Down Expand Up @@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin(

/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
val matchedRows = new CompactBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
Expand Down Expand Up @@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin(
val rightNulls = new GenericMutableRow(right.output.size)
/** Rows from broadcasted joined with nulls. */
val broadcastRowsWithNulls: Seq[Row] = {
val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer()
val buf: CompactBuffer[Row] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i))
case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls)
case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
case _ =>
}
}
i += 1
}
arrBuf.toSeq
buf.toSeq
}

// TODO: Breaks lineage.
Expand Down