@@ -707,4 +707,53 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
707707 Some (agg), Some (new HashPartitioner (FEW_PARTITIONS )), None , None )
708708 assertDidNotBypassMergeSort(sorter4)
709709 }
710+
test("sort without breaking sorting contracts") {
  // Very small shuffle memory fraction; presumably this forces the sorter to spill to disk
  // for the large aggregation input below (TODO confirm against ExternalSorter).
  val conf = createSparkConf(true)
  conf.set("spark.shuffle.memoryFraction", "0.001")
  conf.set("spark.shuffle.manager", "sort")
  sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

  // Keys listed in ascending String.hashCode order (hash shown next to each key). The spread
  // between the smallest and largest hash exceeds Int.MaxValue, so subtracting them overflows.
  val testData = Array[String](
    "hierarch", // -1732884796
    "variants", // -1249574770
    "inwork",   // -1183663690
    "isohel",   // -1179291542
    "misused"   // 1069518484
  )
  val countPerKey = 200000
  val expected = testData.map(word => (word, countPerKey))

  // Combiner functions that simply collect every value seen for a key.
  def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer(i)
  def mergeValue(c: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = c += i
  def mergeCombiners(c1: ArrayBuffer[Int], c2: ArrayBuffer[Int]): ArrayBuffer[Int] = c1 ++= c2

  val agg = new Aggregator[String, Int, ArrayBuffer[Int]](
    createCombiner, mergeValue, mergeCombiners)

  // Deliberately broken comparator: the subtraction overflows for hashes that are far apart
  // (e.g. "hierarch" vs "misused"), so it reports the wrong sign and yields a wrong sort
  // result. Do NOT "fix" this - the test relies on the overflow.
  val wrongHashOrdering = new Ordering[String] {
    override def compare(a: String, b: String): Int = a.hashCode() - b.hashCode()
  }
  val sorter1 = new ExternalSorter[String, Int, ArrayBuffer[Int]](
    None, None, Some(wrongHashOrdering), None)
  sorter1.insertAll(expected.iterator)

  // With the overflowing comparator the output must differ from the hash-ordered input.
  val wronglySorted = sorter1.iterator.toArray
  assert(wronglySorted !== expected)

  // Use aggregation without an explicit ordering, on an input large enough that it is
  // expected to spill, to make sure ExternalSorter uses its partitionKeyComparator.
  val sorter2 = new ExternalSorter[String, Int, ArrayBuffer[Int]](
    Some(agg), None, None, None)
  // Emit each key `count` times with value 1, lazily, so the full input is never materialised.
  sorter2.insertAll(expected.iterator.flatMap { case (key, count) =>
    Iterator.fill(count)((key, 1))
  })

  // Summing the collected 1s per key must reproduce `expected`, in the same (hash) order.
  val aggregated = sorter2.iterator.map(kv => (kv._1, kv._2.sum)).toArray
  assert(aggregated === expected)
}
710759}
0 commit comments