diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 36138cc9a297c..15d0c38e4ee10 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -45,7 +45,7 @@ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
   /**
    * Returns the size of used memory in bytes.
    */
-  long getUsed() {
+  public long getUsed() {
     return used;
   }
 
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 9044bb4f4a44b..92c753b2498b8 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -98,6 +98,10 @@ public class TaskMemoryManager {
 
   private final long taskAttemptId;
 
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
   /**
    * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
    * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index e493d9a3cf9cc..3144320ecdf30 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -39,6 +39,9 @@ case class Aggregator[K, V, C] (
       context: TaskContext): Iterator[(K, C)] = {
     val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
     combiners.insertAll(iter)
+    if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", false)) {
+      combiners.spill()
+    }
     updateMetrics(context, combiners)
     combiners.iterator
   }
@@ -48,6 +51,9 @@ case class Aggregator[K, V, C] (
       context: TaskContext): Iterator[(K, C)] = {
     val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
     combiners.insertAll(iter)
+    if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", false)) {
+      combiners.spill()
+    }
     updateMetrics(context, combiners)
     combiners.iterator
   }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 637b2dfc193b8..19713ef9ee4d3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -103,6 +103,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
         val sorter =
           new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
         sorter.insertAll(aggregatedIter)
+        if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", false)) {
+          sorter.spill()
+        }
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
         context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 95351e98261d7..eded684d5e862 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -153,13 +153,11 @@ class ExternalAppendOnlyMap[K, V, C](
 
     while (entries.hasNext) {
       curEntry = entries.next()
-      val estimatedSize = currentMap.estimateSize()
+      val estimatedSize = estimateUsedMemory
       if (estimatedSize > _peakMemoryUsedBytes) {
        _peakMemoryUsedBytes = estimatedSize
       }
-      if (maybeSpill(currentMap, estimatedSize)) {
-        currentMap = new SizeTrackingAppendOnlyMap[K, C]
-      }
+      maybeSpill(estimatedSize)
       currentMap.changeValue(curEntry._1, update)
       addElementsRead()
     }
@@ -178,10 +176,18 @@ class ExternalAppendOnlyMap[K, V, C](
     insertAll(entries.iterator)
   }
 
+  def estimateUsedMemory(): Long = {
+    currentMap.estimateSize()
+  }
+
+  protected def resetAfterSpill(): Unit = {
+    currentMap = new SizeTrackingAppendOnlyMap[K, C]
+  }
+
   /**
    * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
    */
-  override protected[this] def spill(collection: SizeTracker): Unit = {
+  override protected[this] def spillCollection(): Unit = {
     val (blockId, file) = diskBlockManager.createTempLocalBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -534,6 +540,13 @@ class ExternalAppendOnlyMap[K, V, C](
 
   /** Convenience function to hash the given (K, C) pair by the key. */
   private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)
+
+  /**
+   * To prevent debug code from printing out the contents of the iterator and destroying the data.
+   */
+  override def toString(): String = {
+    getClass().getSimpleName + "@" + System.identityHashCode(this)
+  }
 }
 
 private[spark] object ExternalAppendOnlyMap {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 561ba22df557f..aa380338a66e2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -129,6 +129,8 @@ private[spark] class ExternalSorter[K, V, C](
   private var map = new PartitionedAppendOnlyMap[K, C]
   private var buffer = new PartitionedPairBuffer[K, C]
 
+  private[this] val usingMap = aggregator.isDefined
+
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
   def diskBytesSpilled: Long = _diskBytesSpilled
@@ -177,9 +179,8 @@ private[spark] class ExternalSorter[K, V, C](
 
   def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
-    val shouldCombine = aggregator.isDefined
 
-    if (shouldCombine) {
+    if (usingMap) {
       // Combine values in-memory first using our AppendOnlyMap
       val mergeValue = aggregator.get.mergeValue
       val createCombiner = aggregator.get.createCombiner
@@ -191,7 +192,7 @@ private[spark] class ExternalSorter[K, V, C](
         addElementsRead()
         kv = records.next()
         map.changeValue((getPartition(kv._1), kv._1), update)
-        maybeSpillCollection(usingMap = true)
+        maybeSpillCollection()
       }
     } else {
       // Stick values into our buffer
@@ -199,29 +200,36 @@ private[spark] class ExternalSorter[K, V, C](
         addElementsRead()
         val kv = records.next()
         buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
-        maybeSpillCollection(usingMap = false)
+        maybeSpillCollection()
       }
     }
   }
 
+  override protected[this] def resetAfterSpill(): Unit = {
+    if (usingMap) {
+      map = new PartitionedAppendOnlyMap[K, C]
+    } else {
+      buffer = new PartitionedPairBuffer[K, C]
+    }
+  }
+
+  override def estimateUsedMemory: Long = {
+    if (usingMap) {
+      map.estimateSize()
+    } else {
+      buffer.estimateSize()
+    }
+  }
+
+
   /**
    * Spill the current in-memory collection to disk if needed.
    *
    * @param usingMap whether we're using a map or buffer as our current in-memory collection
    */
-  private def maybeSpillCollection(usingMap: Boolean): Unit = {
-    var estimatedSize = 0L
-    if (usingMap) {
-      estimatedSize = map.estimateSize()
-      if (maybeSpill(map, estimatedSize)) {
-        map = new PartitionedAppendOnlyMap[K, C]
-      }
-    } else {
-      estimatedSize = buffer.estimateSize()
-      if (maybeSpill(buffer, estimatedSize)) {
-        buffer = new PartitionedPairBuffer[K, C]
-      }
-    }
+  private def maybeSpillCollection(): Unit = {
+    val estimatedSize = estimateUsedMemory
+    maybeSpill(estimatedSize)
 
     if (estimatedSize > _peakMemoryUsedBytes) {
       _peakMemoryUsedBytes = estimatedSize
@@ -234,7 +242,15 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param collection whichever collection we're using (map or buffer)
    */
-  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+  override protected[this] def spillCollection(): Unit = {
+    if (usingMap) {
+      spillCollection(map)
+    } else {
+      spillCollection(buffer)
+    }
+  }
+
+  protected[this] def spillCollection(collection: WritablePartitionedPairCollection[K, C]): Unit = {
     // Because these files may be read during shuffle, their compression must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
     // createTempShuffleBlock here; see SPARK-3426 for more context.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 25ca2037bbac6..5ba54f413af22 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
 
 /**
  * Spills contents of an in-memory collection to disk when the memory threshold
@@ -28,10 +28,53 @@ import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 private[spark] trait Spillable[C] extends Logging {
   /**
    * Spills the current in-memory collection to disk, and releases the memory.
+   */
+  protected def spillCollection(): Unit
+
+  /**
+   * After a spill, reset any internal data structures so they are ready to accept more input data.
+   */
+  protected def resetAfterSpill(): Unit
+
+  /**
+   * Return an estimate of the current memory used by the collection.
    *
-   * @param collection collection to spill to disk
+   * Note this is *not* the same as the memory requested from the memory manager, for two reasons:
+   * (1) If we allow the collection to use some initial amount of memory that is untracked, that
+   * should still be reported here (which would lead to this amount being larger than what is
+   * tracked by the memory manager).
+   * (2) If we've just requested a large increase in memory from the memory manager, but aren't
+   * actually *using* that memory yet, we will not report it here (which would lead to this amount
+   * being smaller than what is tracked by the memory manager).
    */
-  protected def spill(collection: C): Unit
+  def estimateUsedMemory(): Long
+
+  /**
+   * Spills the in-memory collection, releases memory, and updates metrics. This can be
+   * used to force a spill, even if this collection believes it still has extra memory, to
+   * free up memory for other operators. For example, during a stage which does a shuffle-read
+   * and a shuffle-write, after the shuffle-read is finished, we can spill to free up memory
+   * for the shuffle-write.
+   * [[maybeSpill]] can be used when the collection should only spill
+   * if it doesn't have enough memory.
+   */
+  final def spill(): Unit = {
+    spill(estimateUsedMemory())
+  }
+
+  final def spill(currentMemory: Long): Unit = {
+    if (_elementsRead == 0) {
+      logDebug(s"Skipping spill since ${this} is empty")
+    } else {
+      _spillCount += 1
+      logSpillage(currentMemory)
+      spillCollection()
+      _elementsRead = 0
+      _memoryBytesSpilled += currentMemory
+      releaseMemory()
+      resetAfterSpill()
+    }
+  }
 
   // Number of elements read from input since last spill
   protected def elementsRead: Long = _elementsRead
@@ -57,6 +100,17 @@ private[spark] trait Spillable[C] extends Logging {
   // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
   private[this] var myMemoryThreshold = initialMemoryThreshold
 
+  /**
+   * The amount of memory that has been allocated to this Spillable by the memory manager.
+   *
+   * Note that this is *not* the same as [[estimateUsedMemory]] -- see the doc on that method
+   * for why these differ.
+   */
+  def allocatedMemory: Long = {
+    // we don't ever request initialMemoryThreshold from the memory manager
+    myMemoryThreshold - initialMemoryThreshold
+  }
+
   // Number of elements read from input since last spill
   private[this] var _elementsRead = 0L
 
@@ -66,21 +120,24 @@ private[spark] trait Spillable[C] extends Logging {
   // Number of spills
   private[this] var _spillCount = 0
 
+  private[this] val memoryConsumer = new SpillableMemoryConsumer(this, taskMemoryManager)
+
   /**
    * Spills the current in-memory collection to disk if needed. Attempts to acquire more
-   * memory before spilling.
+   * memory before spilling. If this does spill, it will call [[resetAfterSpill()]] to
+   * prepare the in-memory data structures to accept more data.
    *
-   * @param collection collection to spill to disk
    * @param currentMemory estimated size of the collection in bytes
    * @return true if `collection` was spilled to disk; false otherwise
    */
-  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+  protected def maybeSpill(currentMemory: Long): Boolean = {
     var shouldSpill = false
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
       val granted =
-        taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null)
+        taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP,
+          memoryConsumer)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -89,12 +146,7 @@ private[spark] trait Spillable[C] extends Logging {
     shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
     // Actually spill
     if (shouldSpill) {
-      _spillCount += 1
-      logSpillage(currentMemory)
-      spill(collection)
-      _elementsRead = 0
-      _memoryBytesSpilled += currentMemory
-      releaseMemory()
+      spill(currentMemory)
     }
     shouldSpill
   }
@@ -110,7 +162,7 @@ private[spark] trait Spillable[C] extends Logging {
   def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
     taskMemoryManager.releaseExecutionMemory(
-      myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null)
+      myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, memoryConsumer)
     myMemoryThreshold = initialMemoryThreshold
   }
 
@@ -126,3 +178,33 @@ private[spark] trait Spillable[C] extends Logging {
       _spillCount, if (_spillCount > 1) "s" else ""))
   }
 }
+
+/**
+ * A light wrapper around Spillables to implement MemoryConsumer, just so that
+ * they can be tracked and logged in TaskMemoryManager.
+ *
+ * Note that this does *not* give cooperative memory management for Spillables; it's just to
+ * make the debug logs on memory usage clearer.
+ */
+class SpillableMemoryConsumer(val sp: Spillable[_], val taskMM: TaskMemoryManager)
+  extends MemoryConsumer(taskMM) with Logging {
+  def spill(size: Long, trigger: MemoryConsumer): Long = {
+    // If another memory consumer requests more memory, we can't easily spill here. The
+    // problem is that even if we do spill, there may be an iterator that is already
+    // reading from the in-memory data structures, which would hold a reference to that
+    // object even after the spill. So even if we spilled, we aren't *actually* freeing memory
+    // unless we update any in-flight iterators to switch to the spilled data.
+    logDebug(s"Spill requested for ${sp} (TID ${taskMemoryManager.getTaskAttemptId}) by " +
+      s"${trigger}, but ${this} can't spill")
+    0L
+  }
+
+  override def toString(): String = {
+    s"SpillableConsumer($sp)"
+  }
+
+  override def getUsed(): Long = {
+    sp.allocatedMemory
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 00f3f15c4596c..a5afb41bad79b 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -380,6 +380,50 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     manager.unregisterShuffle(0)
   }
+
+  test("SPARK-14560 -- UnsafeShuffleWriter") {
+    val myConf = conf.clone()
+      .set("spark.shuffle.memoryFraction", "0.01")
+      .set("spark.memory.useLegacyMode", "true")
+      .set("spark.testing.memory", "500000000") // ~500MB
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      // for relocation, so we can use ShuffleExternalSorter
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .set("spark.shuffle.spillAfterRead", "true")
+    sc = new SparkContext("local", "test", myConf)
+    val N = 2e6.toInt
+    val p = new org.apache.spark.HashPartitioner(10)
+    val d = sc.parallelize(1 to N, 10).map { x => (x % 10000) -> x.toLong }
+    // now we use an aggregator, but that still produces enough data that we need to spill on the
+    // read-side (this one is ridiculous, we shouldn't aggregate at all, but it's just an easy
+    // way to trigger lots of memory use on the shuffle-read side)
+    val d2: RDD[(Int, Seq[Long])] = d.aggregateByKey(Seq[Long](), 5) (
+      { case (list, next) => list :+ next },
+      { case (listA, listB) => listA ++ listB }
+    )
+    val d3 = d2.repartitionAndSortWithinPartitions(p)
+    d3.count()
+  }
+
+  test("SPARK-14560 -- SortShuffleWriters") {
+    val myConf = conf.clone()
+      .set("spark.shuffle.memoryFraction", "0.01")
+      .set("spark.memory.useLegacyMode", "true")
+      .set("spark.testing.memory", "500000000") // ~500MB
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      // pretty small, but otherwise it's too easy for the structures to claim they are using 0
+      // memory in these small tests
+      .set("spark.shuffle.spill.initialMemoryThreshold", "5000")
+      .set("spark.shuffle.spillAfterRead", "true")
+    sc = new SparkContext("local", "test", myConf)
+    val N = 2e6.toInt
+    val p = new org.apache.spark.HashPartitioner(10)
+    val d = sc.parallelize(1 to N, 10).map { x => (x % 10000) -> x.toLong }
+    val p2 = new org.apache.spark.HashPartitioner(5)
+    val d2 = d.repartitionAndSortWithinPartitions(p2)
+    val d3 = d2.repartitionAndSortWithinPartitions(p)
+    d3.count()
+  }
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index dc3185a6d505a..c620144c07e8b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -418,4 +418,24 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("SPARK-14560 -- force spill an empty collection") {
+    // You should be able to force-spill a collection any time -- even if there is nothing
+    // to spill (e.g., nothing has been added, or it just spilled)
+    val conf = createSparkConf(loadDefaults = true)
+    sc = new SparkContext("local", "test", conf)
+    val map = createExternalMap[Int]
+    map.spill()
+    assert(map.iterator.toIndexedSeq == IndexedSeq())
+
+    val map2 = createExternalMap[Int]
+    val elements = IndexedSeq((1, 1), (2, 2), (5, 5))
+    val expected = elements.map { case (k, v) => k -> ArrayBuffer(v)}
+    map2.insertAll(elements.iterator)
+    // say the first spill is natural due to the collection being full
+    map2.spill()
+    // and then we spill again, even though we haven't added anything else, because something
+    // external requests us to free memory
+    map2.spill()
+    assert(map2.iterator.toIndexedSeq == expected)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index a1a7ac97d924b..ecad77122f481 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -608,4 +608,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       }
     }
   }
+
+  test("SPARK-14560 -- force spill an empty collection") {
+    // You should be able to force-spill a collection any time -- even if there is nothing
+    // to spill (e.g., nothing has been added, or it just spilled)
+    val conf = createSparkConf(loadDefaults = true, kryo = true)
+    sc = new SparkContext("local", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    val sorter =
+      new ExternalSorter[Int, Int, Int](context, None, None, None)
+    sorter.spill()
+    assert(sorter.iterator.toIndexedSeq == IndexedSeq())
+
+
+    val sorter2 =
+      new ExternalSorter[Int, Int, Int](context, None, None, None)
+    val elements = IndexedSeq((1, 1), (2, 2), (5, 5))
+    sorter2.insertAll(elements.iterator)
+    // say the first spill is natural due to the collection being full
+    sorter2.spill()
+    // and then we spill again, even though we haven't added anything else, because something
+    // external requests us to free memory
+    sorter2.spill()
+    assert(sorter2.iterator.toIndexedSeq == elements)
+  }
 }
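
Reviewer sketch (not part of the patch): the heart of this change is splitting the old spill(collection) hook into three template methods -- spillCollection(), resetAfterSpill(), and estimateUsedMemory() -- so that Spillable owns the spill bookkeeping and can expose a public spill() that callers (for example Aggregator and BlockStoreShuffleReader under spark.shuffle.spillAfterRead) may force at any time. The standalone Scala sketch below mirrors that control flow outside of Spark; ToySpillable, ToyMap, and ForceSpillDemo are invented names for illustration only, and the memory-manager accounting, logging, and metrics of the real org.apache.spark.util.collection.Spillable are deliberately omitted.

import scala.collection.mutable.ArrayBuffer

// Minimal stand-in for the Spillable contract introduced by this patch (illustration only).
trait ToySpillable {
  private var elementsRead = 0L
  private val memoryThreshold = 5L   // stands in for myMemoryThreshold
  var memoryBytesSpilled = 0L
  var spillCount = 0

  // The three template methods subclasses implement instead of the old spill(collection):
  protected def spillCollection(): Unit   // write the current contents out
  protected def resetAfterSpill(): Unit   // re-create the in-memory structure
  def estimateUsedMemory(): Long          // size estimate of the in-memory structure

  protected def addElementsRead(): Unit = { elementsRead += 1 }

  // Force a spill regardless of thresholds; skipped when nothing has been added.
  final def spill(): Unit = spill(estimateUsedMemory())

  final def spill(currentMemory: Long): Unit = {
    if (elementsRead > 0) {
      spillCount += 1
      spillCollection()
      elementsRead = 0
      memoryBytesSpilled += currentMemory
      resetAfterSpill()
    }
  }

  // Spill only if the collection has outgrown its threshold (the "natural" spill path).
  protected def maybeSpill(currentMemory: Long): Boolean = {
    val shouldSpill = currentMemory >= memoryThreshold
    if (shouldSpill) {
      spill(currentMemory)
    }
    shouldSpill
  }
}

// A toy "external map" that spills to an in-memory buffer instead of disk.
class ToyMap extends ToySpillable {
  private var current = ArrayBuffer.empty[(Int, Int)]
  private val spilled = ArrayBuffer.empty[ArrayBuffer[(Int, Int)]]

  def insert(kv: (Int, Int)): Unit = {
    current += kv
    addElementsRead()
    maybeSpill(estimateUsedMemory())   // same pattern as insertAll() in the patch
  }

  override protected def spillCollection(): Unit = { spilled += current }
  override protected def resetAfterSpill(): Unit = { current = ArrayBuffer.empty }
  override def estimateUsedMemory(): Long = current.size.toLong

  def iterator: Iterator[(Int, Int)] = (spilled.flatten ++ current).iterator
}

object ForceSpillDemo {
  def main(args: Array[String]): Unit = {
    val m = new ToyMap
    m.spill()                          // force-spill while empty: a no-op
    (1 to 12).foreach(i => m.insert(i -> i))
    m.spill()                          // external force-spill, e.g. after a shuffle-read finishes
    m.spill()                          // spilling again immediately is still safe
    assert(m.iterator.size == 12)      // nothing is lost across natural and forced spills
    println(s"spills=${m.spillCount}, bytesSpilled=${m.memoryBytesSpilled}")
  }
}

In the real patch, spill(currentMemory) additionally calls logSpillage() and releaseMemory(), and maybeSpill() first tries to grow myMemoryThreshold through the TaskMemoryManager (now passing a SpillableMemoryConsumer instead of null) before deciding whether to spill.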