Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class ExternalAppendOnlyMap[K, V, C](
if (it.hasNext) {
var kc = it.next()
kcPairs += kc
val minHash = kc._1.hashCode()
val minHash = getKeyHashCode(kc)
while (it.hasNext && it.head._1.hashCode() == minHash) {
kc = it.next()
kcPairs += kc
Expand Down Expand Up @@ -294,8 +294,9 @@ class ExternalAppendOnlyMap[K, V, C](
// Select a key from the StreamBuffer that holds the lowest key hash
val minBuffer = mergeHeap.dequeue()
val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
var (minKey, minCombiner) = minPairs.remove(0)
assert(minKey.hashCode() == minHash)
val minPair = minPairs.remove(0)
var (minKey, minCombiner) = minPair
assert(getKeyHashCode(minPair) == minHash)

// For all other streams that may have this key (i.e. have the same minimum key hash),
// merge in the corresponding value (if any) from that stream
Expand Down Expand Up @@ -327,15 +328,16 @@ class ExternalAppendOnlyMap[K, V, C](
* StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
*/
private class StreamBuffer(
val iterator: BufferedIterator[(K, C)], val pairs: ArrayBuffer[(K, C)])
val iterator: BufferedIterator[(K, C)],
val pairs: ArrayBuffer[(K, C)])
extends Comparable[StreamBuffer] {

def isEmpty = pairs.length == 0

// Invalid if there are no more pairs in this stream
def minKeyHash = {
def minKeyHash: Int = {
assert(pairs.length > 0)
pairs.head._1.hashCode()
getKeyHashCode(pairs.head)
}

override def compareTo(other: StreamBuffer): Int = {
Expand Down Expand Up @@ -422,10 +424,22 @@ class ExternalAppendOnlyMap[K, V, C](
}

private[spark] object ExternalAppendOnlyMap {

/**
 * Compute the hash code used to order a (key, combiner) pair.
 * A null key maps to 0; any other key maps to its own hashCode().
 */
private def getKeyHashCode[K, C](kc: (K, C)): Int = {
  val key = kc._1
  if (key != null) key.hashCode() else 0
}

/**
 * A comparator for (key, combiner) pairs that orders them by key hash code.
 * Hash codes are obtained through getKeyHashCode, so a pair whose key is null
 * is compared without dereferencing the null key.
 */
private class KCComparator[K, C] extends Comparator[(K, C)] {
  def compare(kc1: (K, C), kc2: (K, C)): Int = {
    val hash1 = getKeyHashCode(kc1)
    val hash2 = getKeyHashCode(kc2)
    // Explicit three-way comparison; subtracting the hashes could overflow.
    if (hash1 < hash2) -1 else if (hash1 == hash2) 0 else 1
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
createCombiner, mergeValue, mergeCombiners)

(1 to 100000).foreach { i => map.insert(i, i) }
map.insert(Int.MaxValue, Int.MaxValue)
Expand All @@ -346,11 +346,32 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
it.next()
}
}

// Inserts many ordinary entries plus entries written with null keys/values,
// then walks the whole iterator to check that no exception is thrown.
test("spilling with null keys and values") {
  // NOTE(review): the very small memory fraction is presumably what forces the
  // map to spill during the inserts below — confirm against the map's spill logic.
  val sparkConf = new SparkConf(true)
  sparkConf.set("spark.shuffle.memoryFraction", "0.001")
  sc = new SparkContext("local-cluster[1,1,512]", "test", sparkConf)

  val spillMap = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
    createCombiner, mergeValue, mergeCombiners)

  for (i <- 1 to 100000) {
    spillMap.insert(i, i)
  }

  // NOTE(review): the casts are deliberately kept inline at the call sites.
  // Whether a real null (rather than a boxed 0) reaches the generic insert
  // depends on how the compiler boxes null.asInstanceOf[Int] — TODO confirm.
  spillMap.insert(null.asInstanceOf[Int], 1)
  spillMap.insert(1, null.asInstanceOf[Int])
  spillMap.insert(null.asInstanceOf[Int], null.asInstanceOf[Int])

  val entries = spillMap.iterator
  while (entries.hasNext) {
    // Draining the iterator must not raise a NullPointerException
    entries.next()
  }
}

}

/**
 * A dummy class whose hash code is fixed to the supplied `h`, regardless of `v`,
 * which makes it easy to construct distinct keys that collide on hash code.
 * Equality is still the case-class default, so it compares both `v` and `h`.
 */
case class FixedHashObject(v: Int, h: Int) extends Serializable {
  override def hashCode(): Int = h
}