diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
index d0d25b43d047..8f83262b63ff 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -30,8 +30,7 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
 
   def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
     : Iterator[((Int, K), V)] = {
-    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
-    destructiveSortedIterator(comparator)
+    destructiveSortedIterator(getComparator(keyComparator))
   }
 
   def insert(partition: Int, key: K, value: V): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index f5844d5353be..e43bc78b3f08 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -74,7 +74,7 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
   /** Iterate through the data in a given order. For this class this is not really destructive. */
   override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
     : Iterator[((Int, K), V)] = {
-    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+    val comparator = getComparator(keyComparator)
     new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
     iterator
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index 5232c2bd8d6f..cb50d446011d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -75,16 +75,22 @@ private[spark] object WritablePartitionedPairCollection {
   }
 
   /**
-   * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
+   * Returns a comparator for (Int, K) pairs. If keyComparator is provided, pairs are ordered
+   * by partition ID and then by key within each partition; otherwise by partition ID alone.
    */
-  def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
-    new Comparator[(Int, K)] {
-      override def compare(a: (Int, K), b: (Int, K)): Int = {
-        val partitionDiff = a._1 - b._1
-        if (partitionDiff != 0) {
-          partitionDiff
-        } else {
-          keyComparator.compare(a._2, b._2)
+  def getComparator[K](keyComparator: Option[Comparator[K]]): Comparator[(Int, K)] = {
+    if (!keyComparator.isDefined) return partitionComparator
+    else {
+      val theKeyComp = keyComparator.get
+      new Comparator[(Int, K)] {
+        // keyComparator is defined on this branch, so the .get above is safe
+        override def compare(a: (Int, K), b: (Int, K)): Int = {
+          val partitionDiff = a._1 - b._1
+          if (partitionDiff != 0) {
+            partitionDiff
+          } else {
+            theKeyComp.compare(a._2, b._2)
+          }
         }
       }
     }
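
For reference, here is a minimal, self-contained sketch of an equivalent formulation of getComparator using Option combinators instead of an early return. The enclosing object ComparatorSketch, its partitionComparator stub, and the main demo are illustrative stand-ins for this sketch only, not part of the patch or of Spark's API.

import java.util.Comparator

object ComparatorSketch {
  // Stand-in for WritablePartitionedPairCollection.partitionComparator:
  // orders (partition, key) pairs by partition ID only.
  def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
    override def compare(a: (Int, K), b: (Int, K)): Int = a._1 - b._1
  }

  // Same behavior as the patch's getComparator, expressed with map/getOrElse
  // instead of an early return on the None case.
  def getComparator[K](keyComparator: Option[Comparator[K]]): Comparator[(Int, K)] =
    keyComparator.map { keyComp =>
      new Comparator[(Int, K)] {
        override def compare(a: (Int, K), b: (Int, K)): Int = {
          // Spark partition IDs are non-negative, so the subtraction cannot overflow.
          val partitionDiff = a._1 - b._1
          if (partitionDiff != 0) partitionDiff else keyComp.compare(a._2, b._2)
        }
      }
    }.getOrElse(partitionComparator)

  def main(args: Array[String]): Unit = {
    val data = Seq((1, "b"), (0, "z"), (1, "a"), (0, "y"))
    val cmp = getComparator(Some(Ordering.String: Comparator[String]))
    // Sorts by partition first, then key: List((0,y), (0,z), (1,a), (1,b))
    println(data.sortWith((x, y) => cmp.compare(x, y) < 0))
  }
}

Either formulation returns the partition-only comparator when keyComparator is None, which is what lets the two call sites above collapse their keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) chains into a single getComparator(keyComparator) call.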