core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -45,7 +45,7 @@ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
/**
* Returns the size of used memory in bytes.
*/
- long getUsed() {
+ public long getUsed() {
return used;
}

core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -98,6 +98,10 @@ public class TaskMemoryManager {

private final long taskAttemptId;

+ public long getTaskAttemptId() {
+   return taskAttemptId;
+ }

/**
* Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
* without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -39,6 +39,9 @@ case class Aggregator[K, V, C] (
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
+ if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", true)) {
+   combiners.spill()
+ }
updateMetrics(context, combiners)
combiners.iterator
}
@@ -48,6 +51,9 @@ case class Aggregator[K, V, C] (
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
+ if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", true)) {
+   combiners.spill()
+ }
updateMetrics(context, combiners)
combiners.iterator
}
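Both insertion paths above check the same flag, so a job can opt out of the eager post-read spill through its configuration. A minimal sketch (the key "spark.shuffle.spillAfterRead" comes from this change; the app name is illustrative):

import org.apache.spark.SparkConf

// Disable the eager spill after shuffle-read; the guards above default it to true.
val conf = new SparkConf()
  .setAppName("spill-after-read-demo")
  .set("spark.shuffle.spillAfterRead", "false")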
core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -103,6 +103,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
+ if (SparkEnv.get.conf.getBoolean("spark.shuffle.spillAfterRead", true)) {
+   sorter.spill()
+ }
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
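The change above is the motivating case for the new forced spill: a stage that shuffle-reads and then immediately shuffle-writes. A hedged, illustrative job whose middle stage does exactly that (any aggregation followed by another shuffle would do):

import org.apache.spark.{SparkConf, SparkContext}

// The middle stage shuffle-reads groupByKey's output and then shuffle-writes for
// reduceByKey; spilling after the read frees execution memory for the write side.
val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
val result = sc.parallelize(1 to 1000000)
  .map(i => (i % 1000, 1))
  .groupByKey()                       // first shuffle: write, then read
  .mapValues(_.sum)
  .map { case (k, v) => (v % 10, k) }
  .reduceByKey(_ max _)               // second shuffle: the read-then-write stage
  .collect()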
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -153,13 +153,11 @@ class ExternalAppendOnlyMap[K, V, C](

while (entries.hasNext) {
curEntry = entries.next()
- val estimatedSize = currentMap.estimateSize()
+ val estimatedSize = estimateUsedMemory()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
- if (maybeSpill(currentMap, estimatedSize)) {
-   currentMap = new SizeTrackingAppendOnlyMap[K, C]
- }
+ maybeSpill(estimatedSize)
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
@@ -178,10 +176,18 @@
insertAll(entries.iterator)
}

+ def estimateUsedMemory(): Long = {
+   currentMap.estimateSize()
+ }

+ protected def resetAfterSpill(): Unit = {
+   currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ }

/**
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
- override protected[this] def spill(collection: SizeTracker): Unit = {
+ override protected[this] def spillCollection(): Unit = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -534,6 +540,13 @@ class ExternalAppendOnlyMap[K, V, C](

/** Convenience function to hash the given (K, C) pair by the key. */
private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)

+ /**
+  * Prevent debug code from printing out the contents of this map: the default toString
+  * would traverse the iterator, which is destructive, and so destroy the data.
+  */
+ override def toString(): String = {
+   getClass().getSimpleName + "@" + System.identityHashCode(this)
+ }
}

private[spark] object ExternalAppendOnlyMap {
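The toString override above matters because the map's iterator consumes the underlying data: any debug statement that interpolates the collection into a log message would otherwise traverse (and so destroy) its contents. A one-line illustration, assuming a Logging mixin is in scope:

logDebug(s"inserting records into $combiners")  // logs ExternalAppendOnlyMap@<id>, not the contents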
core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -129,6 +129,8 @@ private[spark] class ExternalSorter[K, V, C](
private var map = new PartitionedAppendOnlyMap[K, C]
private var buffer = new PartitionedPairBuffer[K, C]

+ private[this] val usingMap = aggregator.isDefined

// Total spilling statistics
private var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled
@@ -177,9 +179,8 @@

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
- val shouldCombine = aggregator.isDefined

- if (shouldCombine) {
+ if (usingMap) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
@@ -191,37 +192,44 @@
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
- maybeSpillCollection(usingMap = true)
+ maybeSpillCollection()
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
- maybeSpillCollection(usingMap = false)
+ maybeSpillCollection()
}
}
}

+ override protected[this] def resetAfterSpill(): Unit = {
+   if (usingMap) {
+     map = new PartitionedAppendOnlyMap[K, C]
+   } else {
+     buffer = new PartitionedPairBuffer[K, C]
+   }
+ }

+ override def estimateUsedMemory(): Long = {
+   if (usingMap) {
+     map.estimateSize()
+   } else {
+     buffer.estimateSize()
+   }
+ }

/**
* Spill the current in-memory collection to disk if needed.
 */
- private def maybeSpillCollection(usingMap: Boolean): Unit = {
-   var estimatedSize = 0L
-   if (usingMap) {
-     estimatedSize = map.estimateSize()
-     if (maybeSpill(map, estimatedSize)) {
-       map = new PartitionedAppendOnlyMap[K, C]
-     }
-   } else {
-     estimatedSize = buffer.estimateSize()
-     if (maybeSpill(buffer, estimatedSize)) {
-       buffer = new PartitionedPairBuffer[K, C]
-     }
-   }
+ private def maybeSpillCollection(): Unit = {
+   val estimatedSize = estimateUsedMemory()
+   maybeSpill(estimatedSize)

if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
@@ -234,7 +242,15 @@
 */
- override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+ override protected[this] def spillCollection(): Unit = {
+   if (usingMap) {
+     spillCollection(map)
+   } else {
+     spillCollection(buffer)
+   }
+ }

+ protected[this] def spillCollection(collection: WritablePartitionedPairCollection[K, C]): Unit = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
110 changes: 96 additions & 14 deletions core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util.collection

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
- import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
+ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}

/**
* Spills contents of an in-memory collection to disk when the memory threshold
@@ -28,10 +28,53 @@ import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
private[spark] trait Spillable[C] extends Logging {
/**
 * Spills the current in-memory collection to disk; the [[spill()]] wrapper below
 * releases the memory afterwards.
+ */
+ protected def spillCollection(): Unit

+ /**
+  * After a spill, reset any internal data structures so they are ready to accept more input data.
+  */
+ protected def resetAfterSpill(): Unit

+ /**
+  * Return an estimate of the current memory used by the collection.
 *
- * @param collection collection to spill to disk
+ * Note this is *not* the same as the memory requested from the memory manager, for two reasons:
+ * (1) If the collection is allowed some initial amount of untracked memory, that amount is
+ * still reported here (so this value can be larger than what the memory manager tracks).
+ * (2) If we've just requested a large increase from the memory manager but aren't actually
+ * *using* that memory yet, it is not reported here (so this value can be smaller than what
+ * the memory manager tracks).
 */
- protected def spill(collection: C): Unit
+ def estimateUsedMemory(): Long

+ /**
+  * Spills the in-memory collection, releases memory, and updates metrics. This can be used
+  * to force a spill even if the collection believes it still has spare memory, in order to
+  * free up memory for other operators. For example, during a stage that does both a
+  * shuffle-read and a shuffle-write, we can spill once the shuffle-read is finished to free
+  * up memory for the shuffle-write. Use [[maybeSpill]] instead when the collection should
+  * spill only if it doesn't have enough memory.
+  */
+ final def spill(): Unit = {
+   spill(estimateUsedMemory())
+ }

+ final def spill(currentMemory: Long): Unit = {
+   if (_elementsRead == 0) {
+     logDebug(s"Skipping spill since ${this} is empty")
+   } else {
+     _spillCount += 1
+     logSpillage(currentMemory)
+     spillCollection()
+     _elementsRead = 0
+     _memoryBytesSpilled += currentMemory
+     releaseMemory()
+     resetAfterSpill()
+   }
+ }

// Number of elements read from input since last spill
protected def elementsRead: Long = _elementsRead
@@ -57,6 +100,17 @@
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
private[this] var myMemoryThreshold = initialMemoryThreshold

+ /**
+  * The amount of memory that has been allocated to this Spillable by the memory manager.
+  *
+  * Note that this is *not* the same as [[estimateUsedMemory]] -- see the doc on that method
+  * for why the two differ.
+  */
+ def allocatedMemory: Long = {
+   // We never request initialMemoryThreshold from the memory manager
+   myMemoryThreshold - initialMemoryThreshold
+ }

// Number of elements read from input since last spill
private[this] var _elementsRead = 0L

@@ -66,21 +120,24 @@
// Number of spills
private[this] var _spillCount = 0

+ private[this] val memoryConsumer = new SpillableMemoryConsumer(this, taskMemoryManager)

/**
* Spills the current in-memory collection to disk if needed. Attempts to acquire more
- * memory before spilling.
+ * memory before spilling. If this does spill, it will call [[resetAfterSpill()]] to
+ * prepare the in-memory data structures to accept more data.
*
- * @param collection collection to spill to disk
 * @param currentMemory estimated size of the collection in bytes
 * @return true if the collection was spilled to disk; false otherwise
*/
- protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+ protected def maybeSpill(currentMemory: Long): Boolean = {
var shouldSpill = false
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted =
-   taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null)
+   taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP,
+     memoryConsumer)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
@@ -89,12 +146,7 @@
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
// Actually spill
if (shouldSpill) {
-   _spillCount += 1
-   logSpillage(currentMemory)
-   spill(collection)
-   _elementsRead = 0
-   _memoryBytesSpilled += currentMemory
-   releaseMemory()
+   spill(currentMemory)
}
shouldSpill
}
@@ -110,7 +162,7 @@
def releaseMemory(): Unit = {
// The amount we requested does not include the initial memory tracking threshold
taskMemoryManager.releaseExecutionMemory(
-   myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null)
+   myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, memoryConsumer)
myMemoryThreshold = initialMemoryThreshold
}

@@ -126,3 +178,33 @@
_spillCount, if (_spillCount > 1) "s" else ""))
}
}

+ /**
+  * A lightweight wrapper around a Spillable that implements MemoryConsumer, just so that
+  * Spillables can be tracked and logged by the TaskMemoryManager.
+  *
+  * Note that this does *not* provide cooperative memory management for Spillables; it only
+  * makes the debug logs on memory usage clearer.
+  */
+ class SpillableMemoryConsumer(val sp: Spillable[_], val taskMM: TaskMemoryManager)
+   extends MemoryConsumer(taskMM) with Logging {
+   def spill(size: Long, trigger: MemoryConsumer): Long = {
+     // If another memory consumer requests more memory, we can't easily spill here. The
+     // problem is that even if we did spill, an iterator may already be reading from the
+     // in-memory data structures, and it would keep a reference to them after the spill.
+     // So we wouldn't *actually* free any memory unless we also updated any in-flight
+     // iterators to switch over to the spilled data.
+     logDebug(s"Spill requested for ${sp} (TID ${taskMemoryManager.getTaskAttemptId}) by " +
+       s"${trigger}, but ${this} can't spill")
+     0L
+   }

+   override def toString(): String = {
+     s"SpillableConsumer($sp)"
+   }

+   override def getUsed(): Long = {
+     sp.allocatedMemory
+   }
+ }
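For orientation, here is a minimal sketch (not part of this PR) of the three hooks a concrete Spillable now implements. `ArrayBufferSpillable`, its fixed per-element size estimate, and the `writeToDisk` callback are all hypothetical, and it assumes the trait exposes `taskMemoryManager` as an overridable member, as the acquire/release calls above suggest:

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.memory.TaskMemoryManager

private[spark] class ArrayBufferSpillable[T](
    override protected val taskMemoryManager: TaskMemoryManager,  // assumed trait wiring
    writeToDisk: Seq[T] => Unit)
  extends Spillable[Seq[T]] {

  private var buffer = new ArrayBuffer[T]

  def insert(t: T): Unit = {
    buffer += t
    addElementsRead()
    // Spill only under memory pressure; calling spill() instead would force it.
    maybeSpill(estimateUsedMemory())
  }

  // Crude fixed-size estimate; the real collections use SizeTracker for this.
  override def estimateUsedMemory(): Long = buffer.length * 64L

  // Write the current contents out; spill() handles metrics and memory release.
  override protected[this] def spillCollection(): Unit = writeToDisk(buffer)

  // Fresh buffer so the collection can accept more input after a spill.
  override protected[this] def resetAfterSpill(): Unit = {
    buffer = new ArrayBuffer[T]
  }
}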