Skip to content

Commit c480537

Browse files
andrewor14 authored and aarondav committed
[SPARK] Fix NPE for ExternalAppendOnlyMap
It did not handle null keys very gracefully before. Author: Andrew Or <[email protected]> Closes apache#1288 from andrewor14/fix-external and squashes the following commits: 312b8d8 [Andrew Or] Abstract key hash code ed5adf9 [Andrew Or] Fix NPE for ExternalAppendOnlyMap
1 parent 3bbeca6 commit c480537

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class ExternalAppendOnlyMap[K, V, C](
252252
if (it.hasNext) {
253253
var kc = it.next()
254254
kcPairs += kc
255-
val minHash = kc._1.hashCode()
255+
val minHash = getKeyHashCode(kc)
256256
while (it.hasNext && it.head._1.hashCode() == minHash) {
257257
kc = it.next()
258258
kcPairs += kc
@@ -294,8 +294,9 @@ class ExternalAppendOnlyMap[K, V, C](
294294
// Select a key from the StreamBuffer that holds the lowest key hash
295295
val minBuffer = mergeHeap.dequeue()
296296
val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
297-
var (minKey, minCombiner) = minPairs.remove(0)
298-
assert(minKey.hashCode() == minHash)
297+
val minPair = minPairs.remove(0)
298+
var (minKey, minCombiner) = minPair
299+
assert(getKeyHashCode(minPair) == minHash)
299300

300301
// For all other streams that may have this key (i.e. have the same minimum key hash),
301302
// merge in the corresponding value (if any) from that stream
@@ -327,15 +328,16 @@ class ExternalAppendOnlyMap[K, V, C](
327328
* StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
328329
*/
329330
private class StreamBuffer(
330-
val iterator: BufferedIterator[(K, C)], val pairs: ArrayBuffer[(K, C)])
331+
val iterator: BufferedIterator[(K, C)],
332+
val pairs: ArrayBuffer[(K, C)])
331333
extends Comparable[StreamBuffer] {
332334

333335
def isEmpty = pairs.length == 0
334336

335337
// Invalid if there are no more pairs in this stream
336-
def minKeyHash = {
338+
def minKeyHash: Int = {
337339
assert(pairs.length > 0)
338-
pairs.head._1.hashCode()
340+
getKeyHashCode(pairs.head)
339341
}
340342

341343
override def compareTo(other: StreamBuffer): Int = {
@@ -422,10 +424,22 @@ class ExternalAppendOnlyMap[K, V, C](
422424
}
423425

424426
private[spark] object ExternalAppendOnlyMap {
427+
428+
/**
429+
* Return the key hash code of the given (key, combiner) pair.
430+
* If the key is null, return a special hash code.
431+
*/
432+
private def getKeyHashCode[K, C](kc: (K, C)): Int = {
433+
if (kc._1 == null) 0 else kc._1.hashCode()
434+
}
435+
436+
/**
437+
* A comparator for (key, combiner) pairs based on their key hash codes.
438+
*/
425439
private class KCComparator[K, C] extends Comparator[(K, C)] {
426440
def compare(kc1: (K, C), kc2: (K, C)): Int = {
427-
val hash1 = kc1._1.hashCode()
428-
val hash2 = kc2._1.hashCode()
441+
val hash1 = getKeyHashCode(kc1)
442+
val hash2 = getKeyHashCode(kc2)
429443
if (hash1 < hash2) -1 else if (hash1 == hash2) 0 else 1
430444
}
431445
}

core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
334334
conf.set("spark.shuffle.memoryFraction", "0.001")
335335
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
336336

337-
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
338-
mergeValue, mergeCombiners)
337+
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
338+
createCombiner, mergeValue, mergeCombiners)
339339

340340
(1 to 100000).foreach { i => map.insert(i, i) }
341341
map.insert(Int.MaxValue, Int.MaxValue)
@@ -346,11 +346,32 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
346346
it.next()
347347
}
348348
}
349+
350+
test("spilling with null keys and values") {
351+
val conf = new SparkConf(true)
352+
conf.set("spark.shuffle.memoryFraction", "0.001")
353+
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
354+
355+
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
356+
createCombiner, mergeValue, mergeCombiners)
357+
358+
(1 to 100000).foreach { i => map.insert(i, i) }
359+
map.insert(null.asInstanceOf[Int], 1)
360+
map.insert(1, null.asInstanceOf[Int])
361+
map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int])
362+
363+
val it = map.iterator
364+
while (it.hasNext) {
365+
// Should not throw NullPointerException
366+
it.next()
367+
}
368+
}
369+
349370
}
350371

351372
/**
352373
* A dummy class that always returns the same hash code, to easily test hash collisions
353374
*/
354-
case class FixedHashObject(val v: Int, val h: Int) extends Serializable {
375+
case class FixedHashObject(v: Int, h: Int) extends Serializable {
355376
override def hashCode(): Int = h
356377
}

0 commit comments

Comments (0)