diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/MixedShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/MixedShuffleReader.scala new file mode 100644 index 0000000000000..1a98c4e1f7f20 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/MixedShuffleReader.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.shuffle.hash.HashShuffleReader + +/** + * ShuffleReader that chooses SortShuffleReader or HashShuffleReader depending on whether there is + * a key ordering. + */ +private[spark] class MixedShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] with Logging { + + private val shuffleReader = if (handle.dependency.keyOrdering.isDefined) { + new SortShuffleReader[K, C](handle, startPartition, endPartition, context) + } else { + new HashShuffleReader[K, C](handle, startPartition, endPartition, context) + } + + override def read(): Iterator[Product2[K, C]] = shuffleReader.read() +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0497036192154..5338cc9881482 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { @@ -48,7 +47,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new MixedShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleReader.scala new file mode 100644 index 0000000000000..53d48536fe050 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleReader.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.File +import java.io.FileOutputStream +import java.nio.ByteBuffer +import java.util.Comparator + +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} +import scala.util.{Failure, Success, Try} + +import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{BaseShuffleHandle, FetchFailedException, ShuffleReader} +import org.apache.spark.storage._ +import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.util.collection.{MergeUtil, TieredDiskMerger} + +/** + * SortShuffleReader merges and aggregates shuffle data that has already been sorted within each + * map output block. + * + * As blocks are fetched, we store them in memory until we fail to acquire space from the + * ShuffleMemoryManager. When this occurs, we merge some in-memory blocks to disk and go back to + * fetching. + * + * TieredDiskMerger is responsible for managing the merged on-disk blocks and for supplying an + * iterator with their merged contents. The final iterator that is passed to user code merges this + * on-disk iterator with the in-memory blocks that have not yet been spilled. + */ +private[spark] class SortShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] with Logging { + + /** Manage the fetched in-memory shuffle block and related buffer */ + case class MemoryShuffleBlock(blockId: BlockId, blockData: ManagedBuffer) + + require(endPartition == startPartition + 1, + "Sort shuffle currently only supports fetching one partition") + + private val dep = handle.dependency + private val conf = SparkEnv.get.conf + private val blockManager = SparkEnv.get.blockManager + private val ser = Serializer.getSerializer(dep.serializer) + private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + /** Queue to store in-memory shuffle blocks */ + private val inMemoryBlocks = new Queue[MemoryShuffleBlock]() + + /** + * Maintain block manager and reported size of each shuffle block. The block manager is used for + * error reporting. The reported size, which, because of size compression, may be slightly + * different than the size of the actual fetched block, is used for calculating how many blocks + * to spill. 
+   */
+  private val shuffleBlockMap = new HashMap[ShuffleBlockId, (BlockManagerId, Long)]()
+
+  /** Comparator used for the merge sort. If keyOrdering is not available, fall back to
+   *  comparing keys by their hash codes. */
+  private val keyComparator: Comparator[K] = dep.keyOrdering.getOrElse(new Comparator[K] {
+    override def compare(a: K, b: K) = {
+      val h1 = if (a == null) 0 else a.hashCode()
+      val h2 = if (b == null) 0 else b.hashCode()
+      if (h1 < h2) -1 else if (h1 == h2) 0 else 1
+    }
+  })
+
+  /** A merge thread to merge on-disk blocks */
+  private val tieredMerger = new TieredDiskMerger(conf, dep, keyComparator, context)
+
+  /** Shuffle block fetcher iterator */
+  private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _
+
+  /** Number of bytes spilled in memory and on disk */
+  private var _memoryBytesSpilled: Long = 0L
+  private var _diskBytesSpilled: Long = 0L
+
+  /** Number of bytes left to fetch */
+  private var unfetchedBytes: Long = 0L
+
+  def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+  def diskBytesSpilled: Long = _diskBytesSpilled + tieredMerger.diskBytesSpilled
+
+  override def read(): Iterator[Product2[K, C]] = {
+    tieredMerger.start()
+
+    computeShuffleBlocks()
+
+    for ((blockId, blockOption) <- fetchShuffleBlocks()) {
+      val blockData = blockOption match {
+        case Success(b) => b
+        case Failure(e) =>
+          blockId match {
+            case b @ ShuffleBlockId(shuffleId, mapId, _) =>
+              val address = shuffleBlockMap(b)._1
+              throw new FetchFailedException(address, shuffleId.toInt, mapId.toInt, startPartition,
+                Utils.exceptionString(e))
+            case _ =>
+              throw new SparkException(
+                s"Failed to get block $blockId, which is not a shuffle block", e)
+          }
+      }
+
+      shuffleRawBlockFetcherItr.currentResult = null
+
+      // Try to fit the block in memory. If this fails, merge in-memory blocks to disk.
+      val blockSize = blockData.size
+      val granted = shuffleMemoryManager.tryToAcquire(blockSize)
+      if (granted >= blockSize) {
+        if (blockData.isDirect) {
+          // If the shuffle block is allocated on a direct buffer, copy it to an on-heap buffer;
+          // otherwise off-heap memory usage would grow to the size of the shuffle memory.
+          val onHeapBuffer = ByteBuffer.allocate(blockSize.toInt)
+          onHeapBuffer.put(blockData.nioByteBuffer)
+
+          inMemoryBlocks += MemoryShuffleBlock(blockId, new NioManagedBuffer(onHeapBuffer))
+          blockData.release()
+        } else {
+          inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)
+        }
+      } else {
+        logDebug(s"Granted $granted memory is not enough to store shuffle block (id: $blockId, " +
+          s"size: $blockSize), spilling in-memory blocks to release the memory")
+
+        shuffleMemoryManager.release(granted)
+        spillInMemoryBlocks(MemoryShuffleBlock(blockId, blockData))
+      }
+
+      unfetchedBytes -= shuffleBlockMap(blockId.asInstanceOf[ShuffleBlockId])._2
+    }
+
+    // Make sure all the blocks have been fetched.
+    assert(unfetchedBytes == 0L)
+
+    tieredMerger.doneRegisteringOnDiskBlocks()
+
+    // Merge on-disk blocks with in-memory blocks to directly feed to the reducer.
+    val finalItrGroup = inMemoryBlocksToIterators(inMemoryBlocks) ++ Seq(tieredMerger.readMerged())
+    val mergedItr =
+      MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
+
+    // Update the spill metrics and do cleanup work when the task is finished.
+ context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + + def releaseFinalShuffleMemory(): Unit = { + inMemoryBlocks.foreach { block => + block.blockData.release() + shuffleMemoryManager.release(block.blockData.size) + } + inMemoryBlocks.clear() + } + context.addTaskCompletionListener(_ => releaseFinalShuffleMemory()) + + // Release the in-memory block when iteration is completed. + val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + mergedItr, releaseFinalShuffleMemory()) + + new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2))) + } + + /** + * Called when we've failed to acquire memory for a block we've just fetched. Figure out how many + * blocks to spill and then spill them. + */ + private def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = { + val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock() + + // If the remaining unfetched data would fit inside our current allocation, we don't want to + // waste time spilling blocks beyond the space needed for it. + // Note that the number of unfetchedBytes is not exact, because of the compression used on the + // sizes of map output blocks. + var bytesToSpill = unfetchedBytes + val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]() + blocksToSpill += tippingBlock + bytesToSpill -= tippingBlock.blockData.size + while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) { + val block = inMemoryBlocks.dequeue() + blocksToSpill += block + bytesToSpill -= block.blockData.size + } + + _memoryBytesSpilled += blocksToSpill.map(_.blockData.size()).sum + + if (blocksToSpill.size > 1) { + spillMultipleBlocks(file, tmpBlockId, blocksToSpill, tippingBlock) + } else { + spillSingleBlock(file, blocksToSpill.head) + } + + tieredMerger.registerOnDiskBlock(tmpBlockId, file) + + logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}") + } + + private def spillSingleBlock(file: File, block: MemoryShuffleBlock): Unit = { + val fos = new FileOutputStream(file) + val buffer = block.blockData.nioByteBuffer() + var channel = fos.getChannel + var success = false + + try { + while (buffer.hasRemaining) { + channel.write(buffer) + } + success = true + } finally { + if (channel != null) { + channel.close() + channel = null + } + if (!success) { + if (file.exists()) { + file.delete() + } + } else { + _diskBytesSpilled += file.length() + } + // When we spill a single block, it's the single tipping block that we never acquired memory + // from the shuffle memory manager for, so we don't need to release any memory from there. + block.blockData.release() + } + } + + /** + * Merge multiple in-memory blocks to a single on-disk file. 
+ */ + private def spillMultipleBlocks(file: File, tmpBlockId: BlockId, + blocksToSpill: Seq[MemoryShuffleBlock], tippingBlock: MemoryShuffleBlock): Unit = { + val itrGroup = inMemoryBlocksToIterators(blocksToSpill) + val partialMergedItr = + MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator) + val curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics) + var success = false + + try { + partialMergedItr.foreach(writer.write) + success = true + } finally { + if (!success) { + if (writer != null) { + writer.revertPartialWritesAndClose() + writer = null + } + if (file.exists()) { + file.delete() + } + } else { + writer.commitAndClose() + writer = null + } + for (block <- blocksToSpill) { + block.blockData.release() + if (block != tippingBlock) { + shuffleMemoryManager.release(block.blockData.size) + } + } + } + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + } + + private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock]) + : Seq[Iterator[Product2[K, C]]] = { + blocks.map{ case MemoryShuffleBlock(id, buf) => + blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser) + .asInstanceOf[Iterator[Product2[K, C]]] + } + } + + /** + * Utility function to compute the shuffle blocks and related BlockManagerID, block size, + * also the total request shuffle size before starting to fetch the shuffle blocks. + */ + private def computeShuffleBlocks(): Unit = { + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition) + + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]() + for (((address, size), index) <- statuses.zipWithIndex) { + splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) + } + + splitsByAddress.foreach { case (id, blocks) => + blocks.foreach { case (idx, len) => + shuffleBlockMap.put(ShuffleBlockId(handle.shuffleId, idx, startPartition), (id, len)) + unfetchedBytes += len + } + } + } + + private def fetchShuffleBlocks(): Iterator[(BlockId, Try[ManagedBuffer])] = { + val blocksByAddress = new HashMap[BlockManagerId, ArrayBuffer[(ShuffleBlockId, Long)]]() + + shuffleBlockMap.foreach { case (block, (id, len)) => + blocksByAddress.getOrElseUpdate(id, + ArrayBuffer[(ShuffleBlockId, Long)]()) += ((block, len)) + } + + shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator( + context, + SparkEnv.get.blockManager.shuffleClient, + blockManager, + blocksByAddress.toSeq, + conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + + val completionItr = CompletionIterator[ + (BlockId, Try[ManagedBuffer]), + Iterator[(BlockId, Try[ManagedBuffer])]](shuffleRawBlockFetcherItr, + context.taskMetrics.updateShuffleReadMetrics()) + + new InterruptibleIterator[(BlockId, Try[ManagedBuffer])](context, completionItr) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index a066435df6fb0..f3d15d261e782 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -54,6 +54,10 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter = new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) sorter.insertAll(records) + } else if (dep.keyOrdering.isDefined) { + sorter = new ExternalSorter[K, V, V]( + None, 
Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.insertAll(records) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8f28ef49a8a6f..12a01429e565d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -30,6 +30,41 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.Serializer import org.apache.spark.util.{CompletionIterator, Utils} +private[spark] +final class ShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: ShuffleClient, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + serializer: Serializer, + maxBytesInFlight: Long) + extends Iterator[(BlockId, Try[Iterator[Any]])] { + + val shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator( + context, + shuffleClient, + blockManager, + blocksByAddress, + maxBytesInFlight) + + def hasNext: Boolean = shuffleRawBlockFetcherItr.hasNext + + def next(): (BlockId, Try[Iterator[Any]]) = { + val (blockId, block) = shuffleRawBlockFetcherItr.next() + val completedItr = block.map { buf => + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val itr = serializer.newInstance().deserializeStream(is).asIterator + CompletionIterator[Any, Iterator[Any]](itr, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + buf.release() + shuffleRawBlockFetcherItr.currentResult = null + }) + } + (blockId, completedItr) + } +} + /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. @@ -46,20 +81,18 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] -final class ShuffleBlockFetcherIterator( +final class ShuffleRawBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[ManagedBuffer])] with Logging { - import ShuffleBlockFetcherIterator._ + import ShuffleRawBlockFetcherIterator._ /** * Total number of blocks to fetch. This can be smaller than the total number of blocks @@ -93,7 +126,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. 
*/ - @volatile private[this] var currentResult: FetchResult = null + @volatile private[spark] var currentResult: FetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that @@ -272,7 +305,7 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + override def next(): (BlockId, Try[ManagedBuffer]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,32 +323,18 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializer.newInstance().deserializeStream(is).asIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) - } + val bufferTry: Try[ManagedBuffer] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => Success(buf) } - (result.blockId, iteratorTry) + (result.blockId, bufferTry) } } private[storage] -object ShuffleBlockFetcherIterator { +object ShuffleRawBlockFetcherIterator { /** * A request to fetch blocks from a remote BlockManager. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 035f3767ff554..74475d489b5e7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -21,7 +21,6 @@ import java.io._ import java.util.Comparator import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable import com.google.common.io.ByteStreams @@ -395,128 +394,18 @@ private[spark] class ExternalSorter[K, V, C]( val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator) if (aggregator.isDefined) { // Perform partial aggregation across partitions - (p, mergeWithAggregation( + (p, MergeUtil.mergeWithAggregation( iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined)) } else if (ordering.isDefined) { // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey); // sort the elements without trying to merge them - (p, mergeSort(iterators, ordering.get)) + (p, MergeUtil.mergeSort(iterators, ordering.get)) } else { (p, iterators.iterator.flatten) } } } - /** - * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. 
- */ - private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) - : Iterator[Product2[K, C]] = - { - val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) - type Iter = BufferedIterator[Product2[K, C]] - val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) - }) - heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true - new Iterator[Product2[K, C]] { - override def hasNext: Boolean = !heap.isEmpty - - override def next(): Product2[K, C] = { - if (!hasNext) { - throw new NoSuchElementException - } - val firstBuf = heap.dequeue() - val firstPair = firstBuf.next() - if (firstBuf.hasNext) { - heap.enqueue(firstBuf) - } - firstPair - } - } - } - - /** - * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each - * iterator is sorted by key with a given comparator. If the comparator is not a total ordering - * (e.g. when we sort objects by hash code and different keys may compare as equal although - * they're not), we still merge them by doing equality tests for all keys that compare as equal. - */ - private def mergeWithAggregation( - iterators: Seq[Iterator[Product2[K, C]]], - mergeCombiners: (C, C) => C, - comparator: Comparator[K], - totalOrder: Boolean) - : Iterator[Product2[K, C]] = - { - if (!totalOrder) { - // We only have a partial ordering, e.g. comparing the keys by hash code, which means that - // multiple distinct keys might be treated as equal by the ordering. To deal with this, we - // need to read all keys considered equal by the ordering at once and compare them. - new Iterator[Iterator[Product2[K, C]]] { - val sorted = mergeSort(iterators, comparator).buffered - - // Buffers reused across elements to decrease memory allocation - val keys = new ArrayBuffer[K] - val combiners = new ArrayBuffer[C] - - override def hasNext: Boolean = sorted.hasNext - - override def next(): Iterator[Product2[K, C]] = { - if (!hasNext) { - throw new NoSuchElementException - } - keys.clear() - combiners.clear() - val firstPair = sorted.next() - keys += firstPair._1 - combiners += firstPair._2 - val key = firstPair._1 - while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { - val pair = sorted.next() - var i = 0 - var foundKey = false - while (i < keys.size && !foundKey) { - if (keys(i) == pair._1) { - combiners(i) = mergeCombiners(combiners(i), pair._2) - foundKey = true - } - i += 1 - } - if (!foundKey) { - keys += pair._1 - combiners += pair._2 - } - } - - // Note that we return an iterator of elements since we could've had many keys marked - // equal by the partial order; we flatten this below to get a flat iterator of (K, C). - keys.iterator.zip(combiners.iterator) - } - }.flatMap(i => i) - } else { - // We have a total ordering, so the objects with the same key are sequential. - new Iterator[Product2[K, C]] { - val sorted = mergeSort(iterators, comparator).buffered - - override def hasNext: Boolean = sorted.hasNext - - override def next(): Product2[K, C] = { - if (!hasNext) { - throw new NoSuchElementException - } - val elem = sorted.next() - val k = elem._1 - var c = elem._2 - while (sorted.hasNext && sorted.head._1 == k) { - c = mergeCombiners(c, sorted.head._2) - } - (k, c) - } - } - } - } /** * An internal class for reading a spilled file partition by partition. 
Expects all the diff --git a/core/src/main/scala/org/apache/spark/util/collection/MergeUtil.scala b/core/src/main/scala/org/apache/spark/util/collection/MergeUtil.scala new file mode 100644 index 0000000000000..246e36b15ab82 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MergeUtil.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import org.apache.spark.Aggregator + +private[spark] object MergeUtil { + def mergeSort[K, C]( + iterators: Seq[Iterator[Product2[K, C]]], + comparator: Comparator[K], + keyOrdering: Option[Ordering[K]], + aggregator: Option[Aggregator[K, _, C]]) + : Iterator[Product2[K, C]] = { + if (aggregator.isDefined) { + mergeWithAggregation(iterators, aggregator.get.mergeCombiners, + comparator, keyOrdering.isDefined) + } else { + mergeSort(iterators, comparator) + } + } + + /** + * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. + */ + def mergeSort[K, C](iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) + : Iterator[Product2[K, C]] = { + val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) + type Iter = BufferedIterator[Product2[K, C]] + val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { + // Use the reverse of comparator.compare because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + }) + heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true + new Iterator[Product2[K, C]] { + override def hasNext: Boolean = !heap.isEmpty + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val firstBuf = heap.dequeue() + val firstPair = firstBuf.next() + if (firstBuf.hasNext) { + heap.enqueue(firstBuf) + } + firstPair + } + } + } + + /** + * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each + * iterator is sorted by key with a given comparator. If the comparator is not a total ordering + * (e.g. when we sort objects by hash code and different keys may compare as equal although + * they're not), we still merge them by doing equality tests for all keys that compare as equal. + */ + def mergeWithAggregation[K, C]( + iterators: Seq[Iterator[Product2[K, C]]], + mergeCombiners: (C, C) => C, + comparator: Comparator[K], + totalOrder: Boolean) + : Iterator[Product2[K, C]] = { + if (!totalOrder) { + // We only have a partial ordering, e.g. comparing the keys by hash code, which means that + // multiple distinct keys might be treated as equal by the ordering. 
To deal with this, we + // need to read all keys considered equal by the ordering at once and compare them. + new Iterator[Iterator[Product2[K, C]]] { + val sorted = mergeSort(iterators, comparator).buffered + + // Buffers reused across elements to decrease memory allocation + val keys = new ArrayBuffer[K] + val combiners = new ArrayBuffer[C] + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Iterator[Product2[K, C]] = { + if (!hasNext) { + throw new NoSuchElementException + } + keys.clear() + combiners.clear() + val firstPair = sorted.next() + keys += firstPair._1 + combiners += firstPair._2 + val key = firstPair._1 + while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { + val pair = sorted.next() + var i = 0 + var foundKey = false + while (i < keys.size && !foundKey) { + if (keys(i) == pair._1) { + combiners(i) = mergeCombiners(combiners(i), pair._2) + foundKey = true + } + i += 1 + } + if (!foundKey) { + keys += pair._1 + combiners += pair._2 + } + } + + // Note that we return an iterator of elements since we could've had many keys marked + // equal by the partial order; we flatten this below to get a flat iterator of (K, C). + keys.iterator.zip(combiners.iterator) + } + }.flatMap(i => i) + } else { + // We have a total ordering, so the objects with the same key are sequential. + new Iterator[Product2[K, C]] { + val sorted = mergeSort(iterators, comparator).buffered + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = sorted.next() + val k = elem._1 + var c = elem._2 + while (sorted.hasNext && sorted.head._1 == k) { + c = mergeCombiners(c, sorted.head._2) + } + (k, c) + } + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/TieredDiskMerger.scala b/core/src/main/scala/org/apache/spark/util/collection/TieredDiskMerger.scala new file mode 100644 index 0000000000000..d55413a8d31be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/TieredDiskMerger.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.File +import java.util.Comparator +import java.util.concurrent.{PriorityBlockingQueue, CountDownLatch} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.BlockId +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.CompletionIterator + +/** + * Manages blocks of sorted data on disk that need to be merged together. Carries out a tiered + * merge that will never merge more than spark.shuffle.maxMergeFactor segments at a time. 
+ * Except for the final merge, which merges disk blocks to a returned iterator, TieredDiskMerger + * merges blocks from disk to disk. + * + * TieredDiskMerger carries out disk-to-disk merges in a background thread that can run concurrently + * with blocks being deposited on disk. + * + * When deciding which blocks to merge, it first tries to minimize the number of blocks, and then + * the size of the blocks chosen. + */ +private[spark] class TieredDiskMerger[K, C]( + conf: SparkConf, + dep: ShuffleDependency[K, _, C], + keyComparator: Comparator[K], + context: TaskContext) extends Logging { + + /** Manage the on-disk shuffle block and related file, file size */ + case class DiskShuffleBlock(blockId: BlockId, file: File, len: Long) + extends Comparable[DiskShuffleBlock] { + def compareTo(o: DiskShuffleBlock): Int = len.compare(o.len) + } + + private val maxMergeFactor = conf.getInt("spark.shuffle.maxMergeFactor", 100) + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + private val blockManager = SparkEnv.get.blockManager + private val ser = Serializer.getSerializer(dep.serializer) + + /** PriorityQueue to store the on-disk merging blocks, blocks are merged by size ordering */ + private val onDiskBlocks = new PriorityBlockingQueue[DiskShuffleBlock]() + + /** A merging thread to merge on-disk blocks */ + private val diskToDiskMerger = new DiskToDiskMerger + + /** Signal to block/signal the merge action */ + private val mergeReadyMonitor = new AnyRef() + + private val mergeFinished = new CountDownLatch(1) + + /** Whether more on-disk blocks may come in */ + @volatile private var doneRegistering = false + + /** Number of bytes spilled on disk */ + private var _diskBytesSpilled: Long = 0L + + def diskBytesSpilled: Long = _diskBytesSpilled + + def registerOnDiskBlock(blockId: BlockId, file: File): Unit = { + assert(!doneRegistering) + onDiskBlocks.put(new DiskShuffleBlock(blockId, file, file.length())) + + mergeReadyMonitor.synchronized { + if (shouldMergeNow()) { + mergeReadyMonitor.notify() + } + } + } + + /** + * Notify the merger that no more on disk blocks will be registered. + */ + def doneRegisteringOnDiskBlocks(): Unit = { + mergeReadyMonitor.synchronized { + doneRegistering = true + mergeReadyMonitor.notify() + } + } + + def readMerged(): Iterator[Product2[K, C]] = { + mergeFinished.await() + + // Merge the final group for combiner to directly feed to the reducer + val finalMergedBlocks = onDiskBlocks.toArray(new Array[DiskShuffleBlock](onDiskBlocks.size())) + val finalItrGroup = onDiskBlocksToIterators(finalMergedBlocks) + val mergedItr = + MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator) + + onDiskBlocks.clear() + + // Release the on-disk file when iteration is completed. + val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + mergedItr, releaseShuffleBlocks(finalMergedBlocks)) + + new InterruptibleIterator(context, completionItr) + } + + def start() { + diskToDiskMerger.start() + } + + /** + * Release the left in-memory buffer or on-disk file after merged. 
+ */ + private def releaseShuffleBlocks(onDiskShuffleGroup: Array[DiskShuffleBlock]): Unit = { + onDiskShuffleGroup.map { case DiskShuffleBlock(_, file, _) => + try { + logDebug(s"Deleting the unused temp shuffle file: ${file.getName}") + file.delete() + } catch { + // Swallow the exception + case e: Exception => logWarning(s"Unexpected errors when deleting file: ${ + file.getAbsolutePath}", e) + } + } + } + + private def onDiskBlocksToIterators(shufflePartGroup: Seq[DiskShuffleBlock]) + : Seq[Iterator[Product2[K, C]]] = { + shufflePartGroup.map { case DiskShuffleBlock(id, _, _) => + blockManager.diskStore.getValues(id, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + }.toSeq + } + + /** + * Whether we should carry out a disk-to-disk merge now or wait for more blocks or a done + * registering notification to come in. + * + * We want to avoid merging more blocks than we need to. Our last disk-to-disk merge may + * merge fewer than maxMergeFactor blocks, as its only requirement is that, after it has been + * carried out, <= maxMergeFactor blocks remain. E.g., if maxMergeFactor is 10, no more blocks + * will come in, and we have 13 on-disk blocks, the optimal number of blocks to include in the + * last disk-to-disk merge is 4. + * + * While blocks are still coming in, we don't know the optimal number, so we hold off until we + * either receive the notification that no more blocks are coming in, or until maxMergeFactor + * merge is required no matter what. + * + * E.g. if maxMergeFactor is 10 and we have 19 or more on-disk blocks, a 10-block merge will put + * us at 10 or more blocks, so we might as well carry it out now. + */ + private def shouldMergeNow(): Boolean = doneRegistering || + onDiskBlocks.size() >= maxMergeFactor * 2 - 1 + + private final class DiskToDiskMerger extends Thread { + setName(s"tiered-merge-thread-${Thread.currentThread().getId}") + setDaemon(true) + + override def run() { + // Each iteration of this loop carries out a disk-to-disk merge. We remain in this loop until + // no more disk-to-disk merges need to be carried out, i.e. when no more blocks are coming in + // and the final merge won't need to merge more than maxMergeFactor blocks. + while (!doneRegistering || onDiskBlocks.size() > maxMergeFactor) { + while (!shouldMergeNow()) { + mergeReadyMonitor.synchronized { + if (!shouldMergeNow()) { + mergeReadyMonitor.wait() + } + } + } + + if (onDiskBlocks.size() > maxMergeFactor) { + val blocksToMerge = new ArrayBuffer[DiskShuffleBlock]() + // Try to pick the smallest merge width that will result in the next merge being the final + // merge. 
+ var mergeFactor = math.min(onDiskBlocks.size - maxMergeFactor + 1, maxMergeFactor) + while (mergeFactor > 0) { + blocksToMerge += onDiskBlocks.take() + mergeFactor -= 1 + } + + // Merge the blocks + val itrGroup = onDiskBlocksToIterators(blocksToMerge) + val partialMergedItr = + MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator) + // Write merged blocks to disk + val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock() + val curWriteMetrics = new ShuffleWriteMetrics() + var writer = + blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics) + var success = false + + try { + partialMergedItr.foreach(writer.write) + success = true + } finally { + if (!success) { + if (writer != null) { + writer.revertPartialWritesAndClose() + writer = null + } + if (file.exists()) { + file.delete() + } + } else { + writer.commitAndClose() + writer = null + } + releaseShuffleBlocks(blocksToMerge.toArray) + } + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + + logInfo(s"Merged ${blocksToMerge.size} on-disk blocks into file ${file.getName}") + + onDiskBlocks.add(DiskShuffleBlock(tmpBlockId, file, file.length())) + } + } + + mergeFinished.countDown() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index df42faab64505..a675c5b9008f8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.rdd import java.io.{ObjectInputStream, ObjectOutputStream, IOException} -import com.esotericsoftware.kryo.KryoException - import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import com.esotericsoftware.kryo.KryoException import org.scalatest.FunSuite import org.apache.spark._ @@ -726,8 +725,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { val repartitioned = data.repartitionAndSortWithinPartitions(partitioner) val partitions = repartitioned.glom().collect() - assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6))) - assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) + assert(partitions(0).toSet === Set((0, 5), (0, 8), (2, 6))) + assert(partitions(1).toSet === Set((1, 3), (3, 8), (3, 8))) } test("intersection") { diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 844eff4f4c701..aab8d0fa73671 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -151,4 +151,9 @@ public String toString() { .add("length", length) .toString(); } + + @Override + public boolean isDirect() { + return length >= conf.memoryMapBytes(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index a415db593a788..4bbc852e5c446 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -68,4 +68,10 @@ public abstract class ManagedBuffer { * Convert the buffer into an Netty object, used to write the data out. 
   */
   public abstract Object convertToNetty() throws IOException;
+
+  /**
+   * Tell whether or not this buffer is backed by direct (off-heap) memory.
+   * @return true if the underlying bytes are held in a direct buffer
+   */
+  public abstract boolean isDirect();
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
index c806bfa45bef3..909942c64f93b 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -73,4 +73,9 @@ public String toString() {
       .add("buf", buf)
       .toString();
   }
+
+  @Override
+  public boolean isDirect() {
+    return buf.isDirect();
+  }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
index f55b884bc45ce..4727abe0e4157 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -71,5 +71,10 @@ public String toString() {
       .add("buf", buf)
       .toString();
   }
+
+  @Override
+  public boolean isDirect() {
+    return buf.isDirect();
+  }
 }
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
index 38113a918f795..02a994a04a2a9 100644
--- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -101,4 +101,9 @@ public boolean equals(Object other) {
     }
     return false;
   }
+
+  @Override
+  public boolean isDirect() {
+    return underlying.isDirect();
+  }
 }
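
Not part of the patch: a minimal driver sketch showing how a job would exercise the two read paths that MixedShuffleReader dispatches between. The object name and the explicit config values are illustrative only (the two shuffle settings shown are simply the defaults assumed by this patch).

```scala
import org.apache.spark.{SparkConf, SparkContext}

object SortShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("sort-shuffle-read-demo")
      .setMaster("local[2]")
      // Route shuffles through SortShuffleManager so reads go via MixedShuffleReader.
      .set("spark.shuffle.manager", "sort")
      // Knobs read by this patch; the values here are just the assumed defaults.
      .set("spark.shuffle.maxMergeFactor", "100")
      .set("spark.shuffle.file.buffer.kb", "32")
    val sc = new SparkContext(conf)

    val pairs = sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b", 1 -> "d"), numSlices = 2)

    // sortByKey defines a key ordering, so the reduce side takes the SortShuffleReader path.
    val sorted = pairs.sortByKey().collect()

    // reduceByKey has no key ordering, so MixedShuffleReader falls back to HashShuffleReader.
    val reduced = pairs.reduceByKey(_ + _).collect()

    println(sorted.mkString(", "))
    println(reduced.mkString(", "))
    sc.stop()
  }
}
```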
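
The merge-width rule described in TieredDiskMerger's comments can be checked in isolation with a small helper. This is an illustrative sketch, not code from the patch; `nextMergeWidth` is a made-up name for the expression used in the disk-to-disk merge loop.

```scala
// With maxMergeFactor = 10 and 13 on-disk blocks remaining, the next (and last)
// disk-to-disk merge should take min(13 - 10 + 1, 10) = 4 blocks, leaving exactly
// 10 blocks for the final merge that feeds the reducer.
def nextMergeWidth(onDiskBlocks: Int, maxMergeFactor: Int): Int =
  math.min(onDiskBlocks - maxMergeFactor + 1, maxMergeFactor)

assert(nextMergeWidth(13, 10) == 4)   // example from the class comment
assert(nextMergeWidth(19, 10) == 10)  // at 19+ blocks a full-width merge is always worthwhile
```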