@@ -22,6 +22,7 @@ import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStr
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -43,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
 class UnsafeRowSerializerSuite extends SparkFunSuite {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
-    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val converter = unsafeRowConverter(schema)
+    converter(row)
+  }
+
+  private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = {
     val converter = UnsafeProjection.create(schema)
-    converter.apply(internalRow)
+    (row: Row) => {
+      converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
+    }
   }
 
   test("toUnsafeRow() test helper method") {
@@ -92,37 +99,45 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
-      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-      .set("spark.shuffle.memoryFraction", "0.0001")
     var sc: SparkContext = null
     var outputFile: File = null
-    try {
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      val data = (1 to 1000).iterator.map { i =>
-        (i, toUnsafeRow(Row(i), Array(IntegerType)))
-      }
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
+    Utils.tryWithSafeFinally {
+      val conf = new SparkConf()
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+        .set("spark.shuffle.memoryFraction", "0.0001")
 
-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      // prepare data
+      val converter = unsafeRowConverter(Array(IntegerType))
+      val data = (1 to 1000).iterator.map { i =>
+        (i, converter(Row(i)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
 
-      // Merging spilled files should not throw assertion error
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
-      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)
 
-    } finally {
+      // Merging spilled files should not throw assertion error
+      val taskContext =
+        new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+    } {
       // Clean up
       if (sc != null) {
         sc.stop()
       }
+
+      // restore the spark env
+      SparkEnv.set(oldEnv)
+
       if (outputFile != null) {
         outputFile.delete()
       }