From fbefa0059527ae64719d86c5aa66321be6904df9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Sep 2015 14:30:22 -0700 Subject: [PATCH 01/26] Rename HashShuffleReader to BlockStoreShuffleReader. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1..6dc9a16e5853 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b46..d2e2fc4c110a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da..9df4e551669c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block 
store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b83..a5eafb1b5529 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From 80d67015614a29980173ce9f4fc0b0ce0c24a6c0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 18 Sep 2015 12:59:26 -0700 Subject: [PATCH 02/26] WIP towards consolidation of sort shuffle implementations. 
--- .../{unsafe => sort}/PackedRecordPointer.java | 2 +- .../ShuffleExternalSorter.java} | 16 +- .../ShuffleInMemorySorter.java} | 16 +- .../ShuffleSortDataFormat.java} | 8 +- .../shuffle/{unsafe => sort}/SpillInfo.java | 4 +- .../{unsafe => sort}/UnsafeShuffleWriter.java | 12 +- .../shuffle/sort/SortShuffleManager.scala | 138 +++++++-- .../shuffle/unsafe/UnsafeShuffleManager.scala | 202 ------------- .../spark/util/collection/ChainedBuffer.scala | 146 ---------- .../util/collection/ExternalSorter.scala | 16 +- .../PartitionedSerializedPairBuffer.scala | 273 ------------------ .../PackedRecordPointerSuite.java | 5 +- .../ShuffleInMemorySorterSuite.java} | 16 +- .../UnsafeShuffleWriterSuite.java | 6 +- .../org/apache/spark/SortShuffleSuite.scala | 73 +++++ .../SortShuffleManagerSuite.scala} | 8 +- .../shuffle/unsafe/UnsafeShuffleSuite.scala | 105 ------- .../util/collection/ChainedBufferSuite.scala | 144 --------- ...PartitionedSerializedPairBufferSuite.scala | 148 ---------- .../apache/spark/sql/execution/Exchange.scala | 16 +- 20 files changed, 249 insertions(+), 1105 deletions(-) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointer.java (98%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleExternalSorter.java => sort/ShuffleExternalSorter.java} (97%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorter.java => sort/ShuffleInMemorySorter.java} (88%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleSortDataFormat.java => sort/ShuffleSortDataFormat.java} (86%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/SpillInfo.java (90%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriter.java (98%) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointerSuite.java (96%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorterSuite.java => sort/ShuffleInMemorySorterSuite.java} (87%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriterSuite.java (99%) rename core/src/test/scala/org/apache/spark/shuffle/{unsafe/UnsafeShuffleManagerSuite.scala => sort/SortShuffleManagerSuite.scala} (94%) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index 4ee6a82c0423..c11711966fa8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java similarity index 97% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index e73ba3946882..1452b4fdbb1f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.File; @@ -48,7 +48,7 @@ *
* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then * written to a single output file (or multiple files, if we've spilled). The format of the output * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are @@ -59,9 +59,9 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class UnsafeShuffleExternalSorter { +final class ShuffleExternalSorter { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; @@ -94,12 +94,12 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; - public UnsafeShuffleExternalSorter( + public ShuffleExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, @@ -140,7 +140,7 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(initialSize); } /** @@ -166,7 +166,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } // This call performs the actual sort. - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java similarity index 88% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 5bab501da936..a8dee6c6101c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Comparator; import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleInMemorySorter { +final class ShuffleInMemorySorter { private final Sorter sorter; private static final class SortComparator implements Comparator { @@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pointerArrayInsertPosition = 0; - public UnsafeShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.sorter = new Sorter(ShuffleSortDataFormat.INSTANCE); } public void expandPointerArray() { @@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) { /** * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. */ - public static final class UnsafeShuffleSorterIterator { + public static final class ShuffleSorterIterator { private final long[] pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, long[] pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -117,8 +117,8 @@ public void loadNext() { /** * Return an iterator over record pointers in sorted order. */ - public UnsafeShuffleSorterIterator getSortedIterator() { + public ShuffleSorterIterator getSortedIterator() { sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java similarity index 86% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index a66d74ee4478..8a1e5aec6ff0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import org.apache.spark.util.collection.SortDataFormat; -final class UnsafeShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { - public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); - private UnsafeShuffleSortDataFormat() { } + private ShuffleSortDataFormat() { } @Override public PackedRecordPointer getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java similarity index 90% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 7bac0dc0bbeb..df9f7b7abe02 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.File; import org.apache.spark.storage.TempShuffleBlockId; /** - * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + * Metadata for a block of data written by {@link ShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index fdb309e365f6..1403468e149e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -15,13 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; +import org.apache.spark.shuffle.sort.SortShuffleManager; +import org.apache.spark.shuffle.sort.UnsafeShuffleHandle; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -80,7 +82,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; @Nullable private MapStatus mapStatus; - @Nullable private UnsafeShuffleExternalSorter sorter; + @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ @@ -109,10 +111,10 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -195,7 +197,7 @@ public void write(scala.collection.Iterator> records) throws IOEx private void open() throws IOException { assert (sorter == null); - sorter = new UnsafeShuffleExternalSorter( + sorter = new ShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 9df4e551669c..7da9c990440e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,9 +19,97 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark._ +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. 
In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. + * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { if (!conf.getBoolean("spark.shuffle.spill", true)) { @@ -30,8 +118,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } - private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
@@ -40,7 +128,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) + if (SortShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } } /** @@ -58,32 +151,41 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager } /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] - shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) - new SortShuffleWriter( - shuffleBlockResolver, baseShuffleHandle, mapId, context) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + numMapsForShuffle.putIfAbsent( + handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + } } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shuffleMapNumber.containsKey(shuffleId)) { - val numMaps = shuffleMapNumber.remove(shuffleId) - (0 until numMaps).map{ mapId => + Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - indexShuffleBlockResolver - } - /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() } } - diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala deleted file mode 100644 index 75f22f642b9d..000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe - -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. - */ -private[spark] class UnsafeShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object UnsafeShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. - */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. - */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - val shufId = dependency.shuffleId - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") - false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") - false - } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") - true - } - } -} - -/** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - No individual record is larger than 128 MB when serialized. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. - * In sort-based shuffle, incoming records are sorted according to their target partition ids, then - * written to a single map output file. Reducers fetch contiguous regions of this file in order to - * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged - * to produce the final output file. 
- * - * UnsafeShuffleManager optimizes this process in several ways: - * - * - Its sort operates on serialized binary data rather than Java objects, which reduces memory - * consumption and GC overheads. This optimization requires the record serializer to have certain - * properties to allow serialized records to be re-ordered without requiring deserialization. - * See SPARK-4550, where this optimization was first proposed and implemented, for more details. - * - * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts - * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per - * record in the sorting array, this fits more of the array into cache. - * - * - The spill merging procedure operates on blocks of serialized records that belong to the same - * partition and does not need to deserialize records during the merge. - * - * - When the spill compression codec supports concatenation of compressed data, the spill merge - * simply concatenates the serialized and compressed spill partitions to produce the final output - * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used - * and avoids the need to allocate decompression or copying buffers during the merge. - * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. - */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + - "manager; its optimized shuffles will continue to spill to disk when necessary.") - } - - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) - private[this] val shufflesThatFellBackToSortShuffle = - Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) - private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() - - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - sortShuffleManager.getReader(handle, startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. 
*/ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { - handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => - numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) - val env = SparkEnv.get - new UnsafeShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - context.taskMemoryManager(), - env.shuffleMemoryManager, - unsafeShuffleHandle, - mapId, - context, - env.conf) - case other => - shufflesThatFellBackToSortShuffle.add(handle.shuffleId) - sortShuffleManager.getWriter(handle, mapId, context) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { - sortShuffleManager.unregisterShuffle(shuffleId) - } else { - Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - } - - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - sortShuffleManager.shuffleBlockResolver - } - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = { - sortShuffleManager.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala deleted file mode 100644 index ae60f3b0cb55..000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - -/** - * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The - * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts - * of memory and needing to copy the full contents. The disadvantage is that the contents don't - * occupy a contiguous segment of memory. - */ -private[spark] class ChainedBuffer(chunkSize: Int) { - - private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( - java.lang.Long.highestOneBit(chunkSize)) - assert((1 << chunkSizeLog2) == chunkSize, - s"ChainedBuffer chunk size $chunkSize must be a power of two") - private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Long = 0 - - /** - * Feed bytes from this buffer into a DiskBlockObjectWriter. - * - * @param pos Offset in the buffer to read from. - * @param os OutputStream to read into. 
- * @param len Number of bytes to read. - */ - def read(pos: Long, os: OutputStream, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size ${_size} of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - os.write(chunks(chunkIndex), posInChunk, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Read bytes from this buffer into a byte array. - * - * @param pos Offset in the buffer to read from. - * @param bytes Byte array to read into. - * @param offs Offset in the byte array to read to. - * @param len Number of bytes to read. - */ - def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Write bytes from a byte array into this buffer. - * - * @param pos Offset in the buffer to write to. - * @param bytes Byte array to write from. - * @param offs Offset in the byte array to write from. - * @param len Number of bytes to write. - */ - def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos > _size) { - throw new IndexOutOfBoundsException( - s"Write at position $pos starts after end of buffer ${_size}") - } - // Grow if needed - val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt - while (endChunkIndex >= chunks.length) { - chunks += new Array[Byte](chunkSize) - } - - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toWrite: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) - written += toWrite - chunkIndex += 1 - posInChunk = 0 - } - - _size = math.max(_size, pos + len) - } - - /** - * Total size of buffer that can be written to without allocating additional memory. - */ - def capacity: Long = chunks.size.toLong * chunkSize - - /** - * Size of the logical buffer. - */ - def size: Long = _size -} - -/** - * Output stream that writes to a ChainedBuffer. 
- */ -private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos: Long = 0 - - override def write(b: Int): Unit = { - throw new UnsupportedOperationException() - } - - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - chainedBuffer.write(pos, bytes, offs, len) - pos += len - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2a30f751ff03..ead4d480ebdc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -128,23 +128,11 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private val useSerializedPairBuffer = - ordering.isEmpty && - conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB - private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { - if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } - } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = newBuffer() + private var buffer = new PartitionedPairBuffer[K, C] // Total spilling statistics private var _diskBytesSpilled = 0L @@ -236,7 +224,7 @@ private[spark] class ExternalSorter[K, V, C]( } else { estimatedSize = buffer.estimateSize() if (maybeSpill(buffer, estimatedSize)) { - buffer = newBuffer() + buffer = new PartitionedPairBuffer[K, C] } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala deleted file mode 100644 index 87a786b02d65..000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.collection - -import java.io.InputStream -import java.nio.IntBuffer -import java.util.Comparator - -import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.DiskBlockObjectWriter -import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ - -/** - * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes - * its records upon insert and stores them as raw bytes. - * - * We use two data-structures to store the contents. The serialized records are stored in a - * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a - * metadata buffer that stores pointers into the data buffer as well as the partition ID of each - * record. Each entry in the metadata buffer takes up a fixed amount of space. - * - * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not - * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can - * happen without following any pointers, which should minimize cache misses. - * - * Currently, only sorting by partition is supported. - * - * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across - * two integers: - * - * +-------------+------------+------------+-------------+ - * | keyStart | keyValLen | partitionId | - * +-------------+------------+------------+-------------+ - * - * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. - * - * @param metaInitialRecords The initial number of entries in the metadata buffer. - * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. - * @param serializerInstance the serializer used for serializing inserted records. 
- */ -private[spark] class PartitionedSerializedPairBuffer[K, V]( - metaInitialRecords: Int, - kvBlockSize: Int, - serializerInstance: SerializerInstance) - extends WritablePartitionedPairCollection[K, V] with SizeTracker { - - if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { - throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + - " Java-serialized objects.") - } - - require(metaInitialRecords <= MAXIMUM_RECORDS, - s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") - private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) - - private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) - private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) - private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) - - def insert(partition: Int, key: K, value: V): Unit = { - if (metaBuffer.position == metaBuffer.capacity) { - growMetaBuffer() - } - - val keyStart = kvBuffer.size - kvSerializationStream.writeKey[Any](key) - kvSerializationStream.writeValue[Any](value) - kvSerializationStream.flush() - val keyValLen = (kvBuffer.size - keyStart).toInt - - // keyStart, a long, gets split across two ints - metaBuffer.put(keyStart.toInt) - metaBuffer.put((keyStart >> 32).toInt) - metaBuffer.put(keyValLen) - metaBuffer.put(partition) - } - - /** Double the size of the array because we've reached capacity */ - private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { - throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") - } - val newCapacity = - if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { - // Overflow - MAXIMUM_META_BUFFER_CAPACITY - } else { - metaBuffer.capacity * 2 - } - val newMetaBuffer = IntBuffer.allocate(newCapacity) - newMetaBuffer.put(metaBuffer.array) - metaBuffer = newMetaBuffer - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) - : Iterator[((Int, K), V)] = { - sort(keyComparator) - val is = orderedInputStream - val deserStream = serializerInstance.deserializeStream(is) - new Iterator[((Int, K), V)] { - var metaBufferPos = 0 - def hasNext: Boolean = metaBufferPos < metaBuffer.position - def next(): ((Int, K), V) = { - val key = deserStream.readKey[Any]().asInstanceOf[K] - val value = deserStream.readValue[Any]().asInstanceOf[V] - val partition = metaBuffer.get(metaBufferPos + PARTITION) - metaBufferPos += RECORD_SIZE - ((partition, key), value) - } - } - } - - override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity - - override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) - : WritablePartitionedIterator = { - sort(keyComparator) - new WritablePartitionedIterator { - // current position in the meta buffer in ints - var pos = 0 - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - val keyStart = getKeyStartPos(metaBuffer, pos) - val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) - pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, keyValLen) - writer.recordWritten() - } - def nextPartition(): Int = metaBuffer.get(pos + PARTITION) - def hasNext(): Boolean = pos < metaBuffer.position - } - } - - // Visible for testing - def orderedInputStream: OrderedInputStream = { - new OrderedInputStream(metaBuffer, kvBuffer) - } - - private def sort(keyComparator: Option[Comparator[K]]): Unit = { - val comparator = if (keyComparator.isEmpty) { - new Comparator[Int]() { - def compare(partition1: Int, partition2: Int): Int = { - partition1 - partition2 - } - } - } else { - throw new UnsupportedOperationException() - } - - val sorter = new Sorter(new SerializedSortDataFormat) - sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) - } -} - -private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) - extends InputStream { - - import PartitionedSerializedPairBuffer._ - - private var metaBufferPos = 0 - private var kvBufferPos = - if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 - - override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) - - override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { - if (metaBufferPos >= metaBuffer.position) { - return -1 - } - val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - - (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt - val toRead = math.min(bytesRemainingInRecord, len) - kvBuffer.read(kvBufferPos, bytes, offs, toRead) - if (toRead == bytesRemainingInRecord) { - metaBufferPos += RECORD_SIZE - if (metaBufferPos < metaBuffer.position) { - kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) - } - } else { - kvBufferPos += toRead - } - toRead - } - - override def read(): Int = { - throw new UnsupportedOperationException() - } -} - -private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { - - private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) - - /** Return the sort key for the element at the given index. */ - override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { - metaBuffer.get(pos * RECORD_SIZE + PARTITION) - } - - /** Swap two elements. 
*/ - override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { - val iOff = pos0 * RECORD_SIZE - val jOff = pos1 * RECORD_SIZE - System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) - System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) - System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) - } - - /** Copy a single element from src(srcPos) to dst(dstPos). */ - override def copyElement( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) - } - - /** - * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. - * Overlapping ranges are allowed. - */ - override def copyRange( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int, - length: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) - } - - /** - * Allocates a Buffer that can hold up to 'length' elements. - * All elements of the buffer should be considered invalid until data is explicitly copied in. - */ - override def allocate(length: Int): IntBuffer = { - IntBuffer.allocate(length * RECORD_SIZE) - } -} - -private object PartitionedSerializedPairBuffer { - val KEY_START = 0 // keyStart, a long, gets split across two ints - val KEY_VAL_LEN = 2 - val PARTITION = 3 - val RECORD_SIZE = PARTITION + 1 // num ints of metadata - - val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 - val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 - - def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { - val lower32 = metaBuffer.get(metaBufferPos + KEY_START) - val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) - (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java similarity index 96% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 934b7e03050b..232ae4d926bc 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -15,8 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; +import org.apache.spark.shuffle.sort.PackedRecordPointer; import org.junit.Test; import static org.junit.Assert.*; @@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; +import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; public class PackedRecordPointerSuite { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 40fefe2c9d14..1ef3c5ff64ba 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Arrays; import java.util.Random; @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -public class UnsafeShuffleInMemorySorterSuite { +public class ShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; @@ -40,8 +40,8 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -62,7 +62,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -79,7 +79,7 @@ public void testBasicSorting() throws Exception { } // Sort the records - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int prevPartitionId = -1; Arrays.sort(dataToSort); for (int i = 0; i < dataToSort.length; i++) { @@ -103,7 +103,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { @@ -112,7 
+112,7 @@ public void testSortingManyNumbers() throws Exception { } Arrays.sort(numbersToSort); int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int j = 0; while (iter.hasNext()) { iter.loadNext(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java similarity index 99% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a266b0c36e0f..718cb8307f2c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.*; import java.nio.ByteBuffer; import java.util.*; +import org.apache.spark.shuffle.sort.UnsafeShuffleHandle; import scala.*; import scala.collection.Iterator; -import scala.reflect.ClassTag; import scala.runtime.AbstractFunction1; import com.google.common.collect.Iterators; @@ -462,7 +462,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); - final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 63358172ea1f..993649dc4be4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,8 +17,19 @@ package org.apache.spark +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. 
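[Aside on the ShuffleInMemorySorter tests above: the sorter keeps one 8-byte packed pointer per record and iterates them in partition order. As I read PackedRecordPointer, the partition id sits in the upper 24 bits of that word, which is also where the roughly 16-million-partition ceiling quoted later in this series comes from; the exact bit split in this sketch is an assumption for illustration, not a copy of the real class.]

    object PackedPointerSketch {
      // Assumed layout: [24-bit partition id][40-bit compressed record pointer].
      val PARTITION_ID_BITS = 24
      val MAXIMUM_PARTITION_ID = (1 << PARTITION_ID_BITS) - 1  // 16777215, ~16 million

      def pack(partitionId: Int, recordPointer: Long): Long = {
        require(partitionId >= 0 && partitionId <= MAXIMUM_PARTITION_ID)
        (partitionId.toLong << 40) | (recordPointer & 0xFFFFFFFFFFL)
      }

      def partitionId(packed: Long): Int = (packed >>> 40).toInt

      def main(args: Array[String]): Unit = {
        // With the partition id in the most significant bits, a plain numeric sort of
        // the packed values (for the small ids used here) groups records by partition,
        // which is the ordering property the in-memory sorter tests check.
        val packed = Array(pack(3, 7L), pack(0, 99L), pack(2, 1L), pack(0, 5L))
        java.util.Arrays.sort(packed)
        assert(packed.map(partitionId).toSeq == Seq(0, 0, 2, 3))
      }
    }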
@@ -26,4 +37,66 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } + + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles: Set[File] = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } + + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles: Set[File] = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala similarity index 94% rename from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 6727934d8c7c..2dc4fb86712d 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe +package org.apache.spark.shuffle.sort import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { +class SortShuffleManagerSuite extends SparkFunSuite with Matchers { - import UnsafeShuffleManager.canUseUnsafeShuffle + import SortShuffleManager.canUseUnsafeShuffle private class RuntimeExceptionAnswer extends Answer[Object] { override def answer(invocation: InvocationOnMock): Object = { @@ -102,7 +102,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { // We do not support shuffles with more than 16 million output partitions assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + partitioner = new HashPartitioner(SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), serializer = kryo, keyOrdering = None, aggregator = None, diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala deleted file mode 100644 index 6351539e91e9..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe - -import java.io.File - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.Utils - -class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. - - override def beforeAll() { - conf.set("spark.shuffle.manager", "tungsten-sort") - // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort - // shuffle records. 
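[Aside on the cleanup tests added to SortShuffleSuite above (and the near-identical ones being deleted from UnsafeShuffleSuite below): they all follow the same pattern of snapshotting the files under spark.local.dir, running the shuffle, diffing the snapshots, and checking the diff disappears after removeShuffle. A stripped-down sketch of just that snapshot/diff technique, using a throwaway directory instead of a SparkContext:]

    import java.io.File
    import scala.collection.JavaConverters._
    import org.apache.commons.io.FileUtils
    import org.apache.commons.io.filefilter.TrueFileFilter

    object FileDiffSketch {
      // Recursively list every file under `dir`, as the suites above do.
      def allFiles(dir: File): Set[File] =
        FileUtils.listFiles(dir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet

      def main(args: Array[String]): Unit = {
        val dir = java.nio.file.Files.createTempDirectory("file-diff-sketch").toFile
        try {
          val before = allFiles(dir)
          // Stand-in for "run the shuffle": create the files a job would leave behind.
          new File(dir, "shuffle_0_0_0.data").createNewFile()
          new File(dir, "shuffle_0_0_0.index").createNewFile()
          val created = allFiles(dir) -- before
          assert(created.map(_.getName) == Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
          // Stand-in for "removeShuffle": delete them and confirm nothing is left over.
          created.foreach(_.delete())
          assert((allFiles(dir) -- before).isEmpty)
        } finally {
          FileUtils.deleteDirectory(dir)
        }
      }
    }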
- conf.set("spark.shuffle.memoryFraction", "0.5") - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new KryoSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the old SortShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new JavaSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala deleted file mode 100644 index 05306f408847..000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.nio.ByteBuffer - -import org.scalatest.Matchers._ - -import org.apache.spark.SparkFunSuite - -class ChainedBufferSuite extends SparkFunSuite { - test("write and read at start") { - // write from start of source array - val buffer = new ChainedBuffer(8) - buffer.capacity should be (0) - verifyWriteAndRead(buffer, 0, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 0, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 0, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 0, 0, 0, 8) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 0, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 0, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at middle") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 3) - - // write from start of source array - verifyWriteAndRead(buffer, 3, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 3, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 3, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 3, 0, 0, 5) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 3, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 3, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at later buffer") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 11) - - // write from start of source array - verifyWriteAndRead(buffer, 11, 0, 0, 4) - buffer.capacity should be (16) - - // write from middle of source array - verifyWriteAndRead(buffer, 11, 5, 0, 4) - buffer.capacity should be (16) - - // read to middle of target array - verifyWriteAndRead(buffer, 11, 0, 5, 4) - buffer.capacity should be (16) - - // write up to border - verifyWriteAndRead(buffer, 11, 0, 0, 5) - buffer.capacity should be (16) - - // expand into second buffer - verifyWriteAndRead(buffer, 11, 0, 0, 12) - buffer.capacity should be (24) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 11, 0, 0, 28) - buffer.capacity should be (40) - } - - - // Used to make sure we're writing different bytes each time - var rangeStart = 0 - - /** - * @param buffer The buffer to write to and read from. - * @param offsetInBuffer The offset to write to in the buffer. - * @param offsetInSource The offset in the array that the bytes are written from. 
- * @param offsetInTarget The offset in the array to read the bytes into. - * @param length The number of bytes to read and write - */ - def verifyWriteAndRead( - buffer: ChainedBuffer, - offsetInBuffer: Int, - offsetInSource: Int, - offsetInTarget: Int, - length: Int): Unit = { - val source = new Array[Byte](offsetInSource + length) - (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource) - buffer.write(offsetInBuffer, source, offsetInSource, length) - val target = new Array[Byte](offsetInTarget + length) - buffer.read(offsetInBuffer, target, offsetInTarget, length) - ByteBuffer.wrap(source, offsetInSource, length) should be - (ByteBuffer.wrap(target, offsetInTarget, length)) - - rangeStart += 100 - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala deleted file mode 100644 index 3b67f6206495..000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
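[Aside on the ChainedBufferSuite removed above: it exercised a byte buffer that grows in fixed-size chunks, which is why its capacity expectations jump in multiples of 8. The following is a rough sketch of the idea, not the removed implementation itself, showing the offset-to-chunk arithmetic and the whole-chunk capacity growth:]

    object ChunkedBufferSketch {
      final class ChunkedBuffer(chunkSize: Int) {
        private val chunks = scala.collection.mutable.ArrayBuffer.empty[Array[Byte]]

        def capacity: Long = chunks.length.toLong * chunkSize

        private def ensureCapacity(end: Long): Unit =
          while (capacity < end) chunks += new Array[Byte](chunkSize)

        // Copy `length` bytes from `src` into the buffer at logical `offset`,
        // spilling across chunk boundaries as needed.
        def write(offset: Long, src: Array[Byte], srcOffset: Int, length: Int): Unit = {
          ensureCapacity(offset + length)
          var written = 0
          while (written < length) {
            val pos = offset + written
            val chunkIndex = (pos / chunkSize).toInt
            val posInChunk = (pos % chunkSize).toInt
            val toCopy = math.min(length - written, chunkSize - posInChunk)
            System.arraycopy(src, srcOffset + written, chunks(chunkIndex), posInChunk, toCopy)
            written += toCopy
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val buf = new ChunkedBuffer(8)
        buf.write(0, Array.fill(12)(1.toByte), 0, 12)
        assert(buf.capacity == 16)  // 12 bytes force a second 8-byte chunk
        buf.write(0, Array.fill(28)(2.toByte), 0, 28)
        assert(buf.capacity == 32)  // matches the "expand into multiple buffers" cases above
      }
    }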
- */ - -package org.apache.spark.util.collection - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import com.google.common.io.ByteStreams - -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.Mockito.RETURNS_SMART_NULLS -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.Matchers._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.DiskBlockObjectWriter - -class PartitionedSerializedPairBufferSuite extends SparkFunSuite { - test("OrderedInputStream single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - - val bytes = ByteStreams.toByteArray(buffer.orderedInputStream) - - val baos = new ByteArrayOutputStream() - val stream = serializerInstance.serializeStream(baos) - stream.writeObject(10) - stream.writeObject(struct) - stream.close() - - baos.toByteArray should be (bytes) - } - - test("insert single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (1) - elements.head should be (((4, 10), struct)) - } - - test("insert multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (3) - elements(0) should be (((4, 2), struct2)) - elements(1) should be (((5, 3), struct3)) - elements(2) should be (((6, 1), struct1)) - } - - test("write single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - it.nextPartition should be (4) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - stream.readObject[AnyRef]() should be (10) - stream.readObject[AnyRef]() should be (struct) - } - - test("write multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - 
it.nextPartition should be (4) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (5) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (6) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - val iter = stream.asIterator - iter.next() should be (2) - iter.next() should be (struct2) - iter.next() should be (3) - iter.next() should be (struct3) - iter.next() should be (1) - iter.next() should be (struct1) - assert(!iter.hasNext) - } - - def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { - val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) - val baos = new ByteArrayOutputStream() - when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - val args = invocationOnMock.getArguments - val bytes = args(0).asInstanceOf[Array[Byte]] - val offset = args(1).asInstanceOf[Int] - val length = args(2).asInstanceOf[Int] - baos.write(bytes, offset, length) - } - }) - (writer, baos) - } -} - -case class SomeStruct(str: String, num: Int) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 029f2264a6a2..f0a9b099eb14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree @@ -88,10 +87,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf val shuffleManager = SparkEnv.get.shuffleManager - val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || - shuffleManager.isInstanceOf[UnsafeShuffleManager] + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { @@ -100,11 +97,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { - // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting - // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an aggregator or ordering and the record serializer - // has certain properties. If this optimization is enabled, we can safely avoid the copy. 
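[Aside on the Exchange branch being rewritten here: it keys off Serializer.supportsRelocationOfSerializedObjects, which roughly means that records serialized independently can be concatenated and still read back as a single valid stream, letting the serialized shuffle path move raw bytes around without an extra copy. A rough standalone illustration of that property using KryoSerializer; this is a sketch of the concept, not of Spark's internal check, and it assumes Kryo's stream framing is header-free as described.]

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    object RelocationSketch {
      def serializeAlone(value: Int): Array[Byte] = {
        val ser = new KryoSerializer(new SparkConf()).newInstance()
        val bytes = new ByteArrayOutputStream()
        val stream = ser.serializeStream(bytes)
        stream.writeObject(value)
        stream.close()
        bytes.toByteArray
      }

      def main(args: Array[String]): Unit = {
        // Serialize two records in two completely separate streams...
        val concatenated = serializeAlone(1) ++ serializeAlone(2)
        // ...then read the concatenation back as if it had been written as one stream.
        val ser = new KryoSerializer(new SparkConf()).newInstance()
        val in = ser.deserializeStream(new ByteArrayInputStream(concatenated))
        assert(in.readObject[Int]() == 1)
        assert(in.readObject[Int]() == 2)
        in.close()
      }
    }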
+ } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. // // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only // need to check whether the optimization is enabled and supported by our serializer. @@ -130,7 +127,6 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf private val serializer: Serializer = { - val rowDataTypes = child.output.map(_.dataType).toArray if (tungstenMode) { new UnsafeRowSerializer(child.output.size) } else { From 26ecf5cfda50238aa86e8971296ebf1774d2a176 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Sep 2015 14:51:58 -0700 Subject: [PATCH 03/26] Fix MiMa. --- project/MimaExcludes.scala | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b2e6be706637..023f38489e18 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -78,6 +78,28 @@ object MimaExcludes { "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.regression.LeastSquaresCostFun.this") + ) ++ Seq( + // SPARK-10708: Consolidate sort shuffle implementations + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.OrderedInputStream"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.ChainedBuffer"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PartitionedSerializedPairBuffer"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.ChainedBufferOutputStream"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PartitionedSerializedPairBuffer$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.SerializedSortDataFormat"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.unsafe.UnsafeShuffleWriter"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.unsafe.UnsafeShuffleHandle"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager$") ) case v if v.startsWith("1.5") => Seq( From 67de10aaa16aa6eb5971aaa8add0172f35e1194a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Sep 2015 15:00:18 -0700 Subject: [PATCH 04/26] Update comment. 
--- .../org/apache/spark/util/collection/ExternalSorter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ead4d480ebdc..9320e545c30d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -69,8 +69,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we - * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. * To avoid calling the partitioner multiple times with each key, we store the partition ID * alongside each record. * From af2794ce590137c07a0e88445e621dfb3fa67ccf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Sep 2015 15:19:30 -0700 Subject: [PATCH 05/26] Comment updates. --- .../shuffle/sort/UnsafeShuffleWriter.java | 8 ++-- .../shuffle/sort/SortShuffleManager.scala | 38 ++++++++++--------- .../sort/UnsafeShuffleWriterSuite.java | 6 +-- .../org/apache/spark/SortShuffleSuite.scala | 4 +- .../sort/SortShuffleManagerSuite.scala | 26 +++++++------ 5 files changed, 44 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 1403468e149e..ed9e9452d284 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -23,7 +23,7 @@ import java.util.Iterator; import org.apache.spark.shuffle.sort.SortShuffleManager; -import org.apache.spark.shuffle.sort.UnsafeShuffleHandle; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -106,15 +106,15 @@ public UnsafeShuffleWriter( IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, - UnsafeShuffleHandle handle, + SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 7da9c990440e..14a17376daf5 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -24,9 +24,10 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ /** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. */ -private[spark] class UnsafeShuffleHandle[K, V]( +private[spark] class SerializedShuffleHandle[K, V]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, V]) @@ -36,30 +37,34 @@ private[spark] class UnsafeShuffleHandle[K, V]( private[spark] object SortShuffleManager extends Logging { /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions val serializer = Serializer.getSerializer(dependency.serializer) if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${serializer.getClass.getName}, does not support object relocation") false } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") false } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + log.debug(s"Can use serialized shuffle for shuffle $shufId") true } } @@ -128,8 +133,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( + if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + new SerializedShuffleHandle[K, V]( shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { new BaseShuffleHandle(shuffleId, numMaps, dependency) @@ -145,7 +150,6 @@ 
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - // We currently use the same block store shuffle fetcher as the hash-based shuffle. new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } @@ -158,7 +162,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => val env = SparkEnv.get new UnsafeShuffleWriter( env.blockManager, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 718cb8307f2c..2228612fb3db 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -21,7 +21,7 @@ import java.nio.ByteBuffer; import java.util.*; -import org.apache.spark.shuffle.sort.UnsafeShuffleHandle; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import scala.*; import scala.collection.Iterator; import scala.runtime.AbstractFunction1; @@ -205,7 +205,7 @@ private UnsafeShuffleWriter createWriter( shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, conf @@ -517,7 +517,7 @@ public void testPeakMemoryUsed() throws Exception { shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle<>(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf); diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 993649dc4be4..7898eba23211 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -49,7 +49,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) .setSerializer(new KryoSerializer(myConf)) val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(SortShuffleManager.canUseUnsafeShuffle(shuffleDep)) + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) def getAllFiles: Set[File] = FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet val filesBeforeShuffle = getAllFiles @@ -80,7 +80,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) .setSerializer(new JavaSerializer(myConf)) val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(!SortShuffleManager.canUseUnsafeShuffle(shuffleDep)) + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) def getAllFiles: Set[File] = FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet val filesBeforeShuffle = getAllFiles diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 2dc4fb86712d..d08c3905235d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} */ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { - import SortShuffleManager.canUseUnsafeShuffle + import SortShuffleManager.canUseSerializedShuffle private class RuntimeExceptionAnswer extends Answer[Object] { override def answer(invocation: InvocationOnMock): Object = { @@ -55,10 +55,10 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { dep } - test("supported shuffle dependencies") { + test("supported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, @@ -68,7 +68,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) when(rangePartitioner.numPartitions).thenReturn(2) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = rangePartitioner, serializer = kryo, keyOrdering = None, @@ -77,7 +77,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // Shuffles with key orderings are supported as long as no aggregator is specified - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), @@ -87,12 +87,12 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { } - test("unsupported shuffle dependencies") { + test("unsupported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) val java = Some(new JavaSerializer(new SparkConf())) // We only support serializers that support object relocation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = java, keyOrdering = None, @@ -100,9 +100,11 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles with more than 16 million output partitions - assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + // The serialized shuffle path do not support shuffles with more than 16 million output + // partitions, due to a limitation in its sorter implementatino. 
+ assert(!canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner( + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1), serializer = kryo, keyOrdering = None, aggregator = None, @@ -110,14 +112,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // We do not support shuffles that perform aggregation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), mapSideCombine = false ))) - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), From 7a13b994f709a9eca43e049283fba0007ce5bb6f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Sep 2015 15:21:08 -0700 Subject: [PATCH 06/26] Typo fix. --- .../org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index d08c3905235d..8744a072cb3f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -101,7 +101,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // The serialized shuffle path do not support shuffles with more than 16 million output - // partitions, due to a limitation in its sorter implementatino. + // partitions, due to a limitation in its sorter implementation. assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner( SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1), From e6fdd46e1d3af0720df1ca0d94166787dfc3d601 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 15:07:30 -0700 Subject: [PATCH 07/26] [SPARK-10403] Allow UnsafeRowSerializer to work with tungsten-sort ShuffleManager. --- .../sql/execution/UnsafeRowSerializer.scala | 22 +++++++++---------- .../execution/UnsafeRowSerializerSuite.scala | 15 ++++++++++++- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index e060c06d9e2a..7e981268de39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -45,16 +45,9 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S } private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { - - /** - * Marks the end of a stream written with [[serializeStream()]]. - */ - private[this] val EOF: Int = -1 - /** * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. - * The end of the stream is denoted by a record with the special length `EOF` (-1). 
*/ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) @@ -92,7 +85,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null - dOut.writeInt(EOF) dOut.close() } } @@ -104,12 +96,20 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() private[this] var rowTuple: (Int, UnsafeRow) = (0, row) + private[this] val EOF: Int = -1 override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { - private[this] var rowSize: Int = dIn.readInt() - if (rowSize == EOF) dIn.close() + private[this] def readSize(): Int = try { + dIn.readInt() + } catch { + case e: EOFException => + dIn.close() + EOF + } + + private[this] var rowSize: Int = readSize() override def hasNext: Boolean = rowSize != EOF override def next(): (Int, UnsafeRow) = { @@ -118,7 +118,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) - rowSize = dIn.readInt() // read the next row's size + rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream dIn.close() val _rowTuple = rowTuple diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 0113d052e338..deddecb6ef96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.rdd.RDD import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.Utils @@ -41,7 +42,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -143,4 +144,16 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { } } } + + test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") { + val conf = new SparkConf() + .set("spark.shuffle.manager", "tungsten-sort") + sc = new SparkContext("local", "test", conf) + val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) + .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2) + shuffled.count() + } } From 5419ca4d4c81f9fab79b39e5b2c8ba771432c68c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 15:51:05 -0700 Subject: [PATCH 08/26] Fix. 
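[Aside on PATCH 07 above: UnsafeRowSerializer keeps writing each record as a 4-byte length (high byte first) followed by the row bytes, but the reader now treats an EOFException from the underlying stream as end-of-data instead of looking for a -1 sentinel, which is also what lets an empty input stream deserialize cleanly. A small standalone sketch of that reader convention, using plain data streams rather than Spark's classes:]

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream,
      DataOutputStream, EOFException}

    object LengthPrefixedFramingSketch {
      val EOF = -1

      def writeRecords(records: Seq[Array[Byte]]): Array[Byte] = {
        val bytes = new ByteArrayOutputStream()
        val out = new DataOutputStream(bytes)
        records.foreach { r =>
          out.writeInt(r.length)  // 4-byte length prefix, high byte first
          out.write(r)
        }
        out.close()               // note: no trailing -1 sentinel any more
        bytes.toByteArray
      }

      def readRecords(data: Array[Byte]): Seq[Array[Byte]] = {
        val in = new DataInputStream(new ByteArrayInputStream(data))
        // Mirrors the patched readSize(): map EOFException to an explicit EOF marker.
        def readSize(): Int = try in.readInt() catch { case _: EOFException => in.close(); EOF }
        Iterator.continually(readSize()).takeWhile(_ != EOF).map { size =>
          val buf = new Array[Byte](size)
          in.readFully(buf)
          buf
        }.toVector
      }

      def main(args: Array[String]): Unit = {
        val rows = Seq(Array[Byte](1, 2, 3), Array[Byte](4, 5))
        assert(readRecords(writeRecords(rows)).map(_.toSeq) == rows.map(_.toSeq))
        assert(readRecords(Array.empty[Byte]).isEmpty)  // empty input hits EOF immediately
      }
    }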
--- .../spark/sql/execution/UnsafeRowSerializerSuite.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index deddecb6ef96..f7d48bc53ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.RDD @@ -88,11 +88,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("close empty input stream") { - val baos = new ByteArrayOutputStream() - val dout = new DataOutputStream(baos) - dout.writeInt(-1) // EOF - dout.flush() - val input = new ClosableByteArrayInputStream(baos.toByteArray) + val input = new ClosableByteArrayInputStream(Array.empty) val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator assert(!deserializerIter.hasNext) From 3ffa137a3510ae1d5443d664f9e93756b4c06da9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 24 Sep 2015 13:11:05 -0700 Subject: [PATCH 09/26] Fix NPEs in DAGScheduler suite. --- .../spark/scheduler/DAGSchedulerSuite.scala | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6b5bcf0574de..96f72c563c4c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -351,7 +351,7 @@ class DAGSchedulerSuite */ test("getMissingParentStages should consider all ancestor RDDs' cache statuses") { val rddA = new MyRDD(sc, 1, Nil) - val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null))) + val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, new HashPartitioner(2)))) val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache() val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC))) cacheLocations(rddC.id -> 0) = @@ -458,7 +458,7 @@ class DAGSchedulerSuite test("run trivial shuffle") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -474,7 +474,7 @@ class DAGSchedulerSuite test("run trivial shuffle with fetch failure") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -590,7 +590,7 @@ class DAGSchedulerSuite val parts = 8 val shuffleMapRdd = new MyRDD(sc, parts, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) 
val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, parts, List(shuffleDep)) submit(reduceRdd, (0 until parts).toArray) @@ -625,7 +625,7 @@ class DAGSchedulerSuite setupStageAbortTest(sc) val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -668,9 +668,9 @@ class DAGSchedulerSuite setupStageAbortTest(sc) val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(2)) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) @@ -717,9 +717,9 @@ class DAGSchedulerSuite setupStageAbortTest(sc) val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(2)) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) @@ -777,7 +777,7 @@ class DAGSchedulerSuite test("trivial shuffle with multiple fetch failures") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -818,7 +818,7 @@ class DAGSchedulerSuite */ test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -886,7 +886,7 @@ class DAGSchedulerSuite test("extremely late fetch failures don't cause multiple concurrent attempts for " + "the same stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -949,7 +949,7 @@ class DAGSchedulerSuite test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -1018,7 +1018,7 @@ class DAGSchedulerSuite test("run shuffle with map stage failure") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new 
ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) @@ -1043,10 +1043,10 @@ class DAGSchedulerSuite */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1156,7 +1156,7 @@ class DAGSchedulerSuite */ test("register map outputs correctly after ExecutorLost and task Resubmitted") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) submit(reduceRdd, Array(0)) @@ -1221,9 +1221,9 @@ class DAGSchedulerSuite */ test("failure of stage used by two jobs") { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) - val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) val shuffleMapRdd2 = new MyRDD(sc, 2, Nil) - val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) val reduceRdd1 = new MyRDD(sc, 2, List(shuffleDep1)) val reduceRdd2 = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2)) @@ -1258,7 +1258,7 @@ class DAGSchedulerSuite test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1281,9 +1281,9 @@ class DAGSchedulerSuite test("recursive shuffle failures") { val shuffleOneRdd = new MyRDD(sc, 2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(2)) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) // have the first stage complete normally @@ -1310,9 +1310,9 @@ class DAGSchedulerSuite test("cached post-shuffle") { val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, new HashPartitioner(2)) val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, new HashPartitioner(2)) val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) submit(finalRdd, Array(0)) cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) @@ -1419,7 +1419,7 @@ class DAGSchedulerSuite test("reduce tasks should be placed locally with map output") { // Create an 
shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1440,7 +1440,7 @@ class DAGSchedulerSuite val numMapTasks = 4 // Create an shuffleMapRdd with more partitions val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1464,7 +1464,7 @@ class DAGSchedulerSuite // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. for a join) val rdd1 = new MyRDD(sc, 1, Nil) val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB"))) - val shuffleDep = new ShuffleDependency(rdd1, null) + val shuffleDep = new ShuffleDependency(rdd1, new HashPartitioner(2)) val narrowDep = new OneToOneDependency(rdd2) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep)) From 0e8b05ef22174297b36771f42deea3c4282a3690 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 24 Sep 2015 16:09:41 -0700 Subject: [PATCH 10/26] Fix ExternalSorterSuite --- .../apache/spark/util/collection/ExternalSorterSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index bdb0f4d507a7..b74387bfce6d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -174,7 +174,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def testSpillingInLocalCluster(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -252,7 +252,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { - conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) From 8fe9094e038c92a06d44c456b8e640294fd877cb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 24 Sep 2015 16:18:19 -0700 Subject: [PATCH 11/26] Fix tests in UnsafeRowSerializerSuite --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/sql/execution/UnsafeRowSerializerSuite.scala | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c6fef7f91f00..261ead4f1f4d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -318,7 +318,7 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> 
"org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index f7d48bc53ebb..e28e6313c15e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten Utils.tryWithSafeFinally { val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.spill.initialMemoryThreshold", "1") .set("spark.shuffle.sort.bypassMergeThreshold", "0") .set("spark.shuffle.memoryFraction", "0.0001") @@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") // prepare data val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 1000).iterator.map { i => + val data = (1 to 10000).iterator.map { i => (i, converter(Row(i))) } val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( @@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } } - test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") { - val conf = new SparkConf() - .set("spark.shuffle.manager", "tungsten-sort") + test("SPARK-10403: unsafe row serializer with SortShuffleManager") { + val conf = new SparkConf().set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) From 18bd9b7406b6ca214720ed1ffc08b1afdce0da4c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Oct 2015 13:49:33 -0700 Subject: [PATCH 12/26] Remove documentation reference to tungsten-sort --- docs/configuration.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 771d93be04b0..392bcd09bd14 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.manager sort - Implementation to use for shuffling data. There are three implementations available: - sort, hash and the new (1.5+) tungsten-sort. + Implementation to use for shuffling data. There are two implementations available: + sort and hash. Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. - Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly - implementation with a fall back to regular sort based shuffle if its requirements are not - met. From ddf858c3ae9f5e4a87e9e26ce773b78d1e63e9b8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Oct 2015 14:00:39 -0700 Subject: [PATCH 13/26] Update comments in Exchange. 
--- .../scala/org/apache/spark/sql/execution/Exchange.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 5470c06243fc..f68822b3792d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -107,14 +107,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only // need to check whether the optimization is enabled and supported by our serializer. - // - // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code - // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls - // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In - // both cases, we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. true } } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { From 37aedce8b32b3b38092ec62af1b70b4d4f7a762a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Oct 2015 14:02:15 -0700 Subject: [PATCH 14/26] More comment updates --- core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 7898eba23211..037d92d69248 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -44,7 +44,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { val myConf = conf.clone() .set("spark.local.dir", tmpDir.getAbsolutePath) sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + // Create a shuffled RDD and verify that it actually uses the new serialized map output path val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) .setSerializer(new KryoSerializer(myConf)) @@ -75,7 +75,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { val myConf = conf.clone() .set("spark.local.dir", tmpDir.getAbsolutePath) sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) .setSerializer(new JavaSerializer(myConf)) From e2792759788f2476da634453dd9f6c3c3f0fd2ba Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 15 Oct 2015 16:54:29 -0700 Subject: [PATCH 15/26] Update bypass merge sort behavior. 
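The comment updates above distinguish the serialized and deserialized map output paths. As a hedged sketch of what steers a shuffle between them, mirroring the SortShuffleSuite code shown above (an existing SparkContext `sc` and its SparkConf `conf` are assumed to be in scope):

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}

    val pairs = sc.parallelize(1 to 10, 1).map(x => (x, x))

    // Kryo supports relocation of serialized objects, so a shuffle with no aggregation or
    // key ordering is eligible for the serialized (formerly "unsafe") write path.
    val serializedPathRdd = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))

    // Java serialization does not support relocation, so this shuffle falls back to the
    // deserialized, ExternalSorter-based write path.
    val deserializedPathRdd = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(4))
      .setSerializer(new JavaSerializer(conf))

    // (The test suites above assert this choice via the package-private
    // SortShuffleManager.canUseSerializedShuffle(dep).)
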
--- .../sort/BypassMergeSortShuffleWriter.java | 100 +++++++++++++----- .../shuffle/sort/SortShuffleFileWriter.java | 53 ---------- .../shuffle/sort/UnsafeShuffleWriter.java | 2 - .../shuffle/sort/SortShuffleManager.scala | 35 +++++- .../shuffle/sort/SortShuffleWriter.scala | 12 +-- .../util/collection/ExternalSorter.scala | 9 +- .../BypassMergeSortShuffleWriterSuite.scala | 64 ++++++----- 7 files changed, 151 insertions(+), 124 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index f5d80bbcf355..c97b3785f40d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; +import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; @@ -31,14 +32,21 @@ import org.slf4j.LoggerFactory; import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; +import javax.annotation.Nullable; + /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these @@ -62,7 +70,7 @@ *
per-partition files to form a single output file, regions of which are served to reducers.
* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -72,31 +80,52 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final BlockManager blockManager; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; private final Serializer serializer; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; public BypassMergeSortShuffleWriter( - SparkConf conf, BlockManager blockManager, - Partitioner partitioner, - ShuffleWriteMetrics writeMetrics, - Serializer serializer) { + IndexShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); - this.numPartitions = partitioner.numPartitions(); this.blockManager = blockManager; - this.partitioner = partitioner; - this.writeMetrics = writeMetrics; - this.serializer = serializer; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.serializer = Serializer.getSerializer(dep.serializer()); + this.shuffleBlockResolver = shuffleBlockResolver; } @Override - public void insertAll(Iterator> records) throws IOException { + public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -124,13 +153,19 @@ public void insertAll(Iterator> records) throws IOException { for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } + + partitionLengths = + writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - @Override - public long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException { + // Exposed for testing + long[] getPartitionLengths() { + return partitionLengths; + } + + private long[] 
writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { @@ -165,18 +200,33 @@ public long[] writePartitionedFile( } @Override - public void stop() throws IOException { - if (partitionWriters != null) { - try { - for (DiskBlockObjectWriter writer : partitionWriters) { - // This method explicitly does _not_ throw exceptions: - File file = writer.revertPartialWritesAndClose(); - if (!file.delete()) { - logger.error("Error while deleting file {}", file.getAbsolutePath()); + public Option stop(boolean success) { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; } } - } finally { - partitionWriters = null; + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java deleted file mode 100644 index 656ea0401a14..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.File; -import java.io.IOException; - -import scala.Product2; -import scala.collection.Iterator; - -import org.apache.spark.annotation.Private; -import org.apache.spark.TaskContext; -import org.apache.spark.storage.BlockId; - -/** - * Interface for objects that {@link SortShuffleWriter} uses to write its output files. - */ -@Private -public interface SortShuffleFileWriter { - - void insertAll(Iterator> records) throws IOException; - - /** - * Write all the data added into this shuffle sorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. 
- * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException; - - void stop() throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index ed9e9452d284..e8f050cb2dab 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -22,8 +22,6 @@ import java.nio.channels.FileChannel; import java.util.Iterator; -import org.apache.spark.shuffle.sort.SortShuffleManager; -import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 14a17376daf5..35cd8ee91a6c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -34,6 +34,17 @@ private[spark] class SerializedShuffleHandle[K, V]( extends BaseShuffleHandle(shuffleId, numMaps, dependency) { } +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. + */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + private[spark] object SortShuffleManager extends Logging { /** @@ -133,7 +144,19 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, + dependency.partitioner.numPartitions, + aggregator = None, + keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. 
+ new BypassMergeSortShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { new SerializedShuffleHandle[K, V]( shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { @@ -161,9 +184,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context: TaskContext): ShuffleWriter[K, V] = { numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val env = SparkEnv.get handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => - val env = SparkEnv.get new UnsafeShuffleWriter( env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], @@ -173,6 +196,14 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId, context, env.conf) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + bypassMergeSortHandle, + mapId, + context, + env.conf) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5865e7640c1c..6b38007a9fdc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: SortShuffleFileWriter[K, V] = null + private var sorter: ExternalSorter[K, V, _] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C]( require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - } else if (SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. 
- new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, - writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4981bc44c933..754d7705d127 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,7 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} +import org.apache.spark.shuffle.sort.SortShuffleWriter import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -93,8 +93,7 @@ private[spark] class ExternalSorter[K, V, C]( ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] - with SortShuffleFileWriter[K, V] { + with Spillable[WritablePartitionedPairCollection[K, C]] { private val conf = SparkEnv.get.conf @@ -180,7 +179,7 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - override def insertAll(records: Iterator[Product2[K, V]]): Unit = { + def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -647,7 +646,7 @@ private[spark] class ExternalSorter[K, V, C]( * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. 
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - override def writePartitionedFile( + def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 341f56df2daf..3899856c5cc7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private var taskMetrics: TaskMetrics = _ - private var shuffleWriteMetrics: ShuffleWriteMetrics = _ private var tempDir: File = _ private var outputFile: File = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] - private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) - private val serializer: Serializer = new JavaSerializer(conf) + private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) - shuffleWriteMetrics = new ShuffleWriteMetrics taskMetrics = new TaskMetrics - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) MockitoAnnotations.initMocks(this) + shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( + shuffleId = 0, + numMaps = 2, + dependency = dependency + ) + when(dependency.partitioner).thenReturn(new HashPartitioner(7)) + when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(Iterator.empty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, 
outputFile) - assert(partitionLengths.sum === 0) + writer.write(Iterator.empty) + writer.stop(true) + assert(writer.getPartitionLengths.sum === 0) assert(outputFile.exists()) assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === 0) assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) @@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(records) + writer.write(records) + writer.stop(true) assert(temporaryFilesCreated.nonEmpty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.sum === outputFile.length()) assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) @@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) intercept[SparkException] { - writer.insertAll((0 until 100000).iterator.map(i => { + writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte })) } assert(temporaryFilesCreated.nonEmpty) - writer.stop() + writer.stop(false) assert(temporaryFilesCreated.count(_.exists()) === 0) } From 104fb989a7be00b134f43047b32aae37f540d06c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 16 Oct 2015 11:47:45 -0700 Subject: [PATCH 16/26] Fix MiMa --- project/MimaExcludes.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index e6d22ee25276..60eef612eba9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -121,7 +121,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.unsafe.UnsafeShuffleHandle"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager$") + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.sort.SortShuffleFileWriter") ) case v if v.startsWith("1.5") => Seq( From b3d539077f38c8e05e2ca6d46add5c31eca30c4c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 16 Oct 2015 11:56:22 -0700 Subject: [PATCH 17/26] Re-order classes so Scaladoc appears at top of file. 
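For readers skimming the suite changes above, the reworked writer lifecycle can be summarized as below. This is a sketch only: BypassMergeSortShuffleWriter and getPartitionLengths are package-private, so outside Spark's own tests the collaborators (blockManager, blockResolver, shuffleHandle, taskContext, conf) would be the mocked instances set up in BypassMergeSortShuffleWriterSuite.

    // Construct the writer against the handle created by SortShuffleManager.registerShuffle.
    val writer = new BypassMergeSortShuffleWriter[Int, Int](
      blockManager,
      blockResolver,
      shuffleHandle,
      0, // mapId
      taskContext,
      conf)

    // write() streams records into one temporary file per reduce partition, then concatenates
    // them into the single data file tracked by IndexShuffleBlockResolver.
    writer.write(Iterator((1, 1), (2, 2), (3, 3)))

    // stop(true) returns Some(MapStatus) built from the per-partition lengths; stop(false)
    // deletes the partial output instead.
    val status = writer.stop(true)
    assert(writer.getPartitionLengths.sum > 0)
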
--- .../shuffle/sort/SortShuffleManager.scala | 117 +++++++++--------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 35cd8ee91a6c..686fe2cc345c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -23,64 +23,6 @@ import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the - * serialized shuffle. - */ -private[spark] class SerializedShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the - * bypass merge sort shuffle path. - */ -private[spark] class BypassMergeSortShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object SortShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that SortShuffleManager supports when - * - */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = - PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use an optimized serialized shuffle - * path or whether it should fall back to the original path that operates on deserialized objects. - */ - def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { - val shufId = dependency.shuffleId - val numPartitions = dependency.partitioner.numPartitions - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") - false - } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { - log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") - false - } else { - log.debug(s"Can use serialized shuffle for shuffle $shufId") - true - } - } -} - /** * A shuffle implementation that uses directly-managed memory to implement several performance * optimizations for certain types of shuffles. In cases where the new performance optimizations @@ -224,3 +166,62 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleBlockResolver.stop() } } + + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. 
+ */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + log.debug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. + */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} From e1d7d59527dc937838cd052a27d6aee29bbf0df3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 16 Oct 2015 12:41:42 -0700 Subject: [PATCH 18/26] Update SortShuffleManager Scaladoc --- .../shuffle/sort/SortShuffleManager.scala | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 686fe2cc345c..ae0cab3a3af1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -24,30 +24,28 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ /** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. 
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then * written to a single map output file. Reducers fetch contiguous regions of this file in order to * read their portion of the map output. In cases where the map output data is too large to fit in * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged * to produce the final output file. * - * UnsafeShuffleManager optimizes this process in several ways: + * Sort-based shuffle has two different write paths for producing its map output files: + * + * - Serialized sorting: used when all three of the following conditions hold: + * 1. The shuffle dependency specifies no aggregation or output ordering. + * 2. The shuffle serializer supports relocation of serialized values (this is currently + * supported by KryoSerializer and Spark SQL's custom serializers). + * 3. The shuffle produces fewer than 16777216 output partitions. + * - Deserialized sorting: used to handle all other cases. + * + * ----------------------- + * Serialized sorting mode + * ----------------------- + * + * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the + * shuffle writer and are buffered in a serialized form during sorting. This write path implements + * several optimizations: * * - Its sort operates on serialized binary data rather than Java objects, which reduces memory * consumption and GC overheads. This optimization requires the record serializer to have certain @@ -66,7 +64,7 @@ import org.apache.spark.shuffle._ * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used * and avoids the need to allocate decompression or copying buffers during the merge. * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. + * For more details on these optimizations, see SPARK-7081. 
*/ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { From d3b091e0fde83f65df4e8c9c01a22e4f3f2b54b9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 16 Oct 2015 12:54:48 -0700 Subject: [PATCH 19/26] Fix bug in planning of bypass merge sort path --- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index ae0cab3a3af1..2ffc5e01b4e4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -84,7 +84,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort( + if (!dependency.mapSideCombine && SortShuffleWriter.shouldBypassMergeSort( SparkEnv.get.conf, dependency.partitioner.numPartitions, aggregator = None, From b3af7a8168ca496bc53c6b78d72b8a766587d5fe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 17 Oct 2015 18:48:26 -0700 Subject: [PATCH 20/26] Fix shuffle spilling test --- .../spark/shuffle/sort/ShuffleExternalSorter.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 1452b4fdbb1f..85fdaa8115fa 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -76,6 +76,10 @@ final class ShuffleExternalSorter { private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; + private long numRecordsInsertedSinceLastSpill = 0; + + /** Force this sorter to spill when there are this many elements in memory. For testing only */ + private final long numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -117,6 +121,8 @@ public ShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.numElementsForSpillThreshold = + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; @@ -141,6 +147,7 @@ private void initializeForWriting() throws IOException { } this.inMemSorter = new ShuffleInMemorySorter(initialSize); + numRecordsInsertedSinceLastSpill = 0; } /** @@ -406,6 +413,10 @@ public void insertRecord( int lengthInBytes, int partitionId) throws IOException { + if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + spill(); + } + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. 
final int totalSpaceRequired = lengthInBytes + 4; @@ -453,6 +464,7 @@ public void insertRecord( recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); + numRecordsInsertedSinceLastSpill += 1; } /** From bf0170e82d2555996ca5170297b392bf760a92b1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 20 Oct 2015 15:10:27 -0700 Subject: [PATCH 21/26] Simplify MiMa excludes. --- project/MimaExcludes.scala | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0ab2b986dec1..b5e661d3ecfa 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,6 +37,7 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("unsafe"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in 1.3. excludePackage("org.spark-project.jetty"), @@ -44,7 +45,11 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar") + excludePackage("org.apache.spark.sql.columnar"), + // The shuffle package is considered private. + excludePackage("org.apache.spark.shuffle"), + // The collections utlities are considered pricate. + excludePackage("org.apache.spark.util.collection") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ @@ -103,30 +108,6 @@ object MimaExcludes { ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.preferredNodeLocationData_=") - ) ++ Seq( - // SPARK-10708: Consolidate sort shuffle implementations - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.OrderedInputStream"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.ChainedBuffer"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PartitionedSerializedPairBuffer"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.ChainedBufferOutputStream"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PartitionedSerializedPairBuffer$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.SerializedSortDataFormat"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.unsafe.UnsafeShuffleWriter"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.unsafe.UnsafeShuffleHandle"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.sort.SortShuffleFileWriter") ) case v if v.startsWith("1.5") => Seq( From 26f5f6da0afecff46921e6b25034298d6169bd12 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 20 Oct 2015 15:37:30 -0700 Subject: [PATCH 22/26] Improve clarity of logic for choosing to bypass merge sort. 
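Before the bypass merge sort planning cleanup that follows, a note on the test hook added in the ShuffleExternalSorter hunk above: spark.shuffle.spill.numElementsForceSpillThreshold (default Long.MaxValue) forces a spill after a fixed number of inserted records. A hedged sketch of how a test might use it; the job is illustrative, and Kryo is configured so the shuffle qualifies for the serialized path that ShuffleExternalSorter backs.

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

    val testConf = new SparkConf(loadDefaults = false)
      .setAppName("force-spill-demo") // placeholder
      .setMaster("local[2]")          // placeholder
      .set("spark.shuffle.manager", "sort")
      // A relocatable serializer makes the shuffle eligible for the serialized write path.
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // Tiny threshold: ShuffleExternalSorter spills every 10 records, exercising spill merging.
      .set("spark.shuffle.spill.numElementsForceSpillThreshold", "10")

    val sc = new SparkContext(testConf)
    try {
      // partitionBy defines no aggregator or key ordering, so the serialized path applies.
      sc.parallelize(1 to 1000, 2).map(i => (i, i)).partitionBy(new HashPartitioner(4)).count()
    } finally {
      sc.stop()
    }
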
--- .../shuffle/sort/SortShuffleManager.scala | 10 ++--- .../shuffle/sort/SortShuffleWriter.scala | 16 ++++--- .../util/collection/ExternalSorter.scala | 8 ---- .../shuffle/sort/SortShuffleWriterSuite.scala | 45 ------------------- 4 files changed, 13 insertions(+), 66 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 2ffc5e01b4e4..e8cec055c063 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -84,22 +84,20 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (!dependency.mapSideCombine && SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, - dependency.partitioner.numPartitions, - aggregator = None, - keyOrdering = None)) { + if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate + // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge // together the spilled files, which would happen with the normal code path. The downside is // having multiple files open at a time and thus more memory allocated to buffers. new BypassMergeSortShuffleHandle[K, V]( shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: new SerializedShuffleHandle[K, V]( shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { + // Otherwise, buffer map outputs in a deserialized form: new BaseShuffleHandle(shuffleId, numMaps, dependency) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 6b38007a9fdc..bbd9c1ab53cd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -101,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C]( } private[spark] object SortShuffleWriter { - def shouldBypassMergeSort( - conf: SparkConf, - numPartitions: Int, - aggregator: Option[Aggregator[_, _, _]], - keyOrdering: Option[Ordering[_]]): Boolean = { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. 
+ if (dep.mapSideCombine) { + require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") + false + } else { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + dep.partitioner.numPartitions <= bypassMergeThreshold + } } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 754d7705d127..c48c453a90d0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.shuffle.sort.SortShuffleWriter import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -103,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } - // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. - // As a sanity check, make sure that we're not handling a shuffle which should use that path. - if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { - throw new IllegalArgumentException("ExternalSorter should not be used to handle " - + " a sort that the BypassMergeSortShuffleWriter should handle") - } - private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala deleted file mode 100644 index 34b4984f12c0..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort - -import org.mockito.Mockito._ - -import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} - -class SortShuffleWriterSuite extends SparkFunSuite { - - import SortShuffleWriter._ - - test("conditions for bypassing merge-sort") { - val conf = new SparkConf(loadDefaults = false) - val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high - assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) - assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) - - // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) - } -} From 71d67fe956f530d415d7e5b8fdd30d4c69cbb7ab Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 20 Oct 2015 15:55:16 -0700 Subject: [PATCH 23/26] Reduce duplication in SortShuffleSuite. --- .../org/apache/spark/SortShuffleSuite.scala | 97 +++++++++---------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 037d92d69248..419a54b2a058 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -34,69 +34,60 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. 
+ private var tempDir: File = _ + override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } - test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { - val tmpDir = Utils.createTempDir() + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + conf.set("spark.local.dir", tempDir.getAbsolutePath) + sc = new SparkContext("local", "test", conf) + } + + override def afterEach(): Unit = { try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it actually uses the new serialized map output path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new KryoSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } + Utils.deleteRecursively(tempDir) } finally { - Utils.deleteRecursively(tmpDir) + super.afterEach() } } + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + // Create a shuffled RDD and verify that it actually uses the new serialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it actually uses the old deserialized map output path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new JavaSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", 
"shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = { + def getAllFiles: Set[File] = + FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") } } } From f7c620c8f6a26d33a33a074b45a66e735a0ba3bb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 20 Oct 2015 15:56:00 -0700 Subject: [PATCH 24/26] Address a number of minor review comments --- .../sort/BypassMergeSortShuffleWriter.java | 16 +++++++++++----- .../spark/shuffle/sort/SortShuffleManager.scala | 9 +++++++-- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 2 +- .../sort/BypassMergeSortShuffleWriterSuite.scala | 6 +++--- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index c97b3785f40d..ee82d679935c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,12 +21,15 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.annotation.Nullable; +import scala.None$; import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,8 +48,6 @@ import org.apache.spark.storage.*; import org.apache.spark.util.Utils; -import javax.annotation.Nullable; - /** * This class implements sort-based shuffle's hash-style shuffle fallback path. 
This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these @@ -160,11 +161,16 @@ public void write(Iterator> records) throws IOException { mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - // Exposed for testing + @VisibleForTesting long[] getPartitionLengths() { return partitionLengths; } + /** + * Concatenate all of the per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; @@ -202,7 +208,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException { @Override public Option stop(boolean success) { if (stopping) { - return Option.apply(null); + return None$.empty(); } else { stopping = true; if (success) { @@ -226,7 +232,7 @@ public Option stop(boolean success) { } } shuffleBlockResolver.removeDataByMap(shuffleId, mapId); - return Option.apply(null); + return None$.empty(); } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index e8cec055c063..1105167d39d8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -74,7 +74,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } + /** + * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -168,8 +172,9 @@ private[spark] object SortShuffleManager extends Logging { /** * The maximum number of shuffle output partitions that SortShuffleManager supports when - * - */ + * buffering map outputs in a serialized form. This is an extreme defensive programming measure, + * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. 
+ * */ val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a83693feb98b..29d9823b1f71 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -21,7 +21,6 @@ import java.nio.ByteBuffer; import java.util.*; -import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import scala.*; import scala.collection.Iterator; import scala.runtime.AbstractFunction1; @@ -56,6 +55,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import org.apache.spark.storage.*; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 3899856c5cc7..b92a302806f7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -122,7 +122,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte conf ) writer.write(Iterator.empty) - writer.stop(true) + writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) assert(outputFile.exists()) assert(outputFile.length() === 0) @@ -146,7 +146,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte conf ) writer.write(records) - writer.stop(true) + writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) assert(writer.getPartitionLengths.sum === outputFile.length()) assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted @@ -175,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte })) } assert(temporaryFilesCreated.nonEmpty) - writer.stop(false) + writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } From 322fd131ae7db9cc5b56ee3950c18b578401d892 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 21 Oct 2015 22:12:45 -0700 Subject: [PATCH 25/26] Call superclass's withFixture method to try to fix tests --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9be9db01c7de..e8bb2732cae5 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -39,7 +39,7 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") try { logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") - test() + super.withFixture(test) } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") } From db0cd28015a2c790911fc7bbdf81cd65b973d2fc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 21 
Oct 2015 22:17:32 -0700 Subject: [PATCH 26/26] Correct test fix --- core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 3 ++- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 419a54b2a058..b8ab227517cc 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -43,7 +43,6 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeEach(): Unit = { tempDir = Utils.createTempDir() conf.set("spark.local.dir", tempDir.getAbsolutePath) - sc = new SparkContext("local", "test", conf) } override def afterEach(): Unit = { @@ -55,6 +54,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { } test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + sc = new SparkContext("local", "test", conf) // Create a shuffled RDD and verify that it actually uses the new serialized map output path val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) @@ -65,6 +65,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { } test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + sc = new SparkContext("local", "test", conf) // Create a shuffled RDD and verify that it actually uses the old deserialized map output path val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e8bb2732cae5..9be9db01c7de 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -39,7 +39,7 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging { val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") try { logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") - super.withFixture(test) + test() } finally { logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") }
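
A note for readers reusing the ensureFilesAreCleanedUp helper from PATCH 23 above: because of Scala's semicolon inference, the two-line form `filesCreatedByShuffle.map(_.getName) should be` followed by `Set(...)` on the next line compiles to a complete (unapplied) matcher expression plus a discarded Set literal, so no comparison is actually performed. A minimal sketch of an equivalent check that fails loudly on a mismatch, using the matchers and assertions this suite already has in scope:

    // Parenthesize the expected value so the matcher is actually applied:
    filesCreatedByShuffle.map(_.getName) should be (Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
    // or, with a plain ScalaTest assertion:
    assert(filesCreatedByShuffle.map(_.getName) === Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))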
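
The javadoc added to writePartitionedFile in PATCH 24 describes the core of the bypass-merge-sort write path: concatenate the per-reduce-partition temporary files into one combined data file and return each partition's length in bytes for the map output tracker. As illustration only — written in Scala rather than the Java of the patch, and not the actual Spark implementation — a minimal sketch of that concatenate-and-record-lengths step:

    import java.io.{File, FileOutputStream}
    import java.nio.file.Files

    // Append each per-partition temporary file to the combined output file and
    // record how many bytes each partition contributed, so partitions can later
    // be located in the combined file by offset.
    def concatPartitionFiles(partitionFiles: Seq[File], combined: File): Array[Long] = {
      val lengths = new Array[Long](partitionFiles.length)
      val out = new FileOutputStream(combined, /* append = */ true)
      try {
        for ((file, i) <- partitionFiles.zipWithIndex) {
          lengths(i) = Files.copy(file.toPath, out)  // returns the number of bytes copied
        }
      } finally {
        out.close()
      }
      lengths
    }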
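
PATCH 25 routes SparkFunSuite's logging wrapper through super.withFixture(test), and PATCH 26 reverts it to a direct test() call while moving SparkContext creation into the individual tests. For readers unfamiliar with ScalaTest's fixture hook, here is a minimal standalone sketch (assuming ScalaTest 2.x; LoggedSuite is a hypothetical name, not part of the patches) of a logging wrapper that chains to super — invoking test() directly still runs the test body, but skips any withFixture behavior defined further up the super chain:

    import org.scalatest.{FunSuite, Outcome}

    // Hypothetical sketch of a logging fixture wrapper (ScalaTest 2.x style).
    abstract class LoggedSuite extends FunSuite {
      override protected def withFixture(test: NoArgTest): Outcome = {
        println(s"===== TEST OUTPUT FOR '${test.name}' =====")
        try {
          // Chain to super so overrides further up the super chain (and the
          // default implementation, which actually invokes the test) still run.
          super.withFixture(test)
        } finally {
          println(s"===== FINISHED '${test.name}' =====")
        }
      }
    }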