From b8dd7eb765df0e48991eb231a4898a513100f133 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 9 Sep 2015 17:59:39 -0700
Subject: [PATCH 1/3] simplify the unit test

---
 .../util/collection/ExternalSorter.scala      |  6 +++
 .../sql/execution/UnsafeRowSerializer.scala   |  2 +-
 .../execution/UnsafeRowSerializerSuite.scala  | 45 ++++++++++++++++++-
 3 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 19287edbaf166..138c05dff19e4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -188,6 +188,12 @@ private[spark] class ExternalSorter[K, V, C](
 
   private val spills = new ArrayBuffer[SpilledFile]
 
+  /**
+   * Number of files this sorter has spilled so far.
+   * Exposed for testing.
+   */
+  private[spark] def numSpills: Int = spills.size
+
   override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 5c18558f9bde7..e060c06d9e2a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -72,7 +72,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
     override def writeKey[T: ClassTag](key: T): SerializationStream = {
       // The key is only needed on the map side when computing partition ids. It does not need to
      // be shuffled.
-      assert(key.isInstanceOf[Int])
+      assert(null == key || key.isInstanceOf[Int])
       this
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index bd02c73a26ace..8ad1515a17697 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
+import org.apache.spark._
 
 
 /**
@@ -87,4 +90,42 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
     assert(!deserializerIter.hasNext)
     assert(input.closed)
   }
+
+  test("SPARK-10466: external sorter spilling with unsafe row serializer") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      .set("spark.shuffle.memoryFraction", "0.0001")
+    var sc: SparkContext = null
+    var outputFile: File = null
+    try {
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      val data = (1 to 1000).iterator.map { i =>
+        (i, toUnsafeRow(Row(i), Array(IntegerType)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)
+
+      // Merging spilled files should not throw assertion error
+      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+
+    } finally {
+      // Clean up
+      if (sc != null) {
+        sc.stop()
+      }
+      if (outputFile != null) {
+        outputFile.delete()
+      }
+    }
+  }
 }

From 68ff3d38a39753c9e45b9222d4c32c541030f19c Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 9 Sep 2015 19:42:08 -0700
Subject: [PATCH 2/3] restore the SparkEnv after SparkContext.stop()

---
 .../execution/UnsafeRowSerializerSuite.scala  | 63 ++++++++++++-------
 1 file changed, 39 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 8ad1515a17697..e24cf69ff42f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -22,6 +22,7 @@ import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStr
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -43,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
 class UnsafeRowSerializerSuite extends SparkFunSuite {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
-    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val converter = unsafeRowConverter(schema)
+    converter(row)
+  }
+
+  private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = {
     val converter = UnsafeProjection.create(schema)
-    converter.apply(internalRow)
+    (row: Row) => {
+      converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
+    }
   }
 
   test("toUnsafeRow() test helper method") {
@@ -92,37 +99,45 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
-      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-      .set("spark.shuffle.memoryFraction", "0.0001")
     var sc: SparkContext = null
     var outputFile: File = null
-    try {
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      val data = (1 to 1000).iterator.map { i =>
-        (i, toUnsafeRow(Row(i), Array(IntegerType)))
-      }
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
+    Utils.tryWithSafeFinally {
+      val conf = new SparkConf()
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+        .set("spark.shuffle.memoryFraction", "0.0001")
 
-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+    sc = new SparkContext("local", "test", conf)
+    outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+    // prepare data
+    val converter = unsafeRowConverter(Array(IntegerType))
+    val data = (1 to 1000).iterator.map { i =>
+      (i, converter(Row(i)))
+    }
+    val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+      partitioner = Some(new HashPartitioner(10)),
+      serializer = Some(new UnsafeRowSerializer(numFields = 1)))
 
-      // Merging spilled files should not throw assertion error
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
-      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+    // Ensure we spilled something and have to merge them later
+    assert(sorter.numSpills === 0)
+    sorter.insertAll(data)
+    assert(sorter.numSpills > 0)
 
-    } finally {
+    // Merging spilled files should not throw assertion error
+    val taskContext =
+      new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+    taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+    sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+    } {
       // Clean up
       if (sc != null) {
         sc.stop()
       }
+
+      // restore the spark env
+      SparkEnv.set(oldEnv)
+
       if (outputFile != null) {
         outputFile.delete()
       }

From e8b27b5515251c856b7711c0c253e3b92e354fab Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Wed, 9 Sep 2015 19:53:51 -0700
Subject: [PATCH 3/3] code style

---
 .../execution/UnsafeRowSerializerSuite.scala  | 42 +++++++++----------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index e24cf69ff42f5..0113d052e338d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -108,27 +108,27 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
         .set("spark.shuffle.sort.bypassMergeThreshold", "0")
         .set("spark.shuffle.memoryFraction", "0.0001")
 
-    sc = new SparkContext("local", "test", conf)
-    outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-    // prepare data
-    val converter = unsafeRowConverter(Array(IntegerType))
-    val data = (1 to 1000).iterator.map { i =>
-      (i, converter(Row(i)))
-    }
-    val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-      partitioner = Some(new HashPartitioner(10)),
-      serializer = Some(new UnsafeRowSerializer(numFields = 1)))
-
-    // Ensure we spilled something and have to merge them later
-    assert(sorter.numSpills === 0)
-    sorter.insertAll(data)
-    assert(sorter.numSpills > 0)
-
-    // Merging spilled files should not throw assertion error
-    val taskContext =
-      new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
-    taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-    sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      // prepare data
+      val converter = unsafeRowConverter(Array(IntegerType))
+      val data = (1 to 1000).iterator.map { i =>
+        (i, converter(Row(i)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)
+
+      // Merging spilled files should not throw assertion error
+      val taskContext =
+        new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
     } {
       // Clean up
       if (sc != null) {