[File: BytesToBytesMap.java]

@@ -193,6 +193,11 @@ public BytesToBytesMap(
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
+
+    // Acquire a new page as soon as we construct the map to ensure that we have at least
+    // one page to work with. Otherwise, other operators in the same task may starve this
+    // map (SPARK-9747).
+    acquireNewPage();
   }
 
   public BytesToBytesMap(
@@ -574,16 +579,9 @@ public boolean putNewKey(
         final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
         Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
       }
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-      if (memoryGranted != pageSizeBytes) {
-        shuffleMemoryManager.release(memoryGranted);
-        logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+      if (!acquireNewPage()) {
        return false;
      }
-      MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-      dataPages.add(newPage);
-      pageCursor = 0;
-      currentDataPage = newPage;
      dataPage = currentDataPage;
      dataPageBaseObject = currentDataPage.getBaseObject();
      dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -642,6 +640,24 @@ public boolean putNewKey(
     }
   }
 
+  /**
+   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   * @return whether there is enough space to allocate the new page.
+   */
+  private boolean acquireNewPage() {
+    final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+    if (memoryGranted != pageSizeBytes) {
+      shuffleMemoryManager.release(memoryGranted);
+      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+      return false;
+    }
+    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    dataPages.add(newPage);
+    pageCursor = 0;
+    currentDataPage = newPage;
+    return true;
+  }
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
Expand Down Expand Up @@ -748,7 +764,7 @@ public long getNumHashCollisions() {
}

@VisibleForTesting
int getNumDataPages() {
public int getNumDataPages() {
return dataPages.size();
}

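The comment added above (SPARK-9747) describes the hazard this PR closes: operators in the same task share one ShuffleMemoryManager budget, and an operator that only asks for its first page when the first record arrives may find that a sibling has already claimed everything. The same pattern is applied to UnsafeExternalSorter below (SPARK-9709). A minimal sketch of the idea follows; ToyMemoryArbiter and ToyOperator are invented stand-ins, not Spark's ShuffleMemoryManager or BytesToBytesMap, and the byte counts are arbitrary:

    // Toy arbiter in the spirit of ShuffleMemoryManager: it grants at most what is
    // left, so callers must check the grant and hand back anything they cannot use.
    class ToyMemoryArbiter(total: Long) {
      private var used = 0L
      def tryToAcquire(numBytes: Long): Long = synchronized {
        val granted = math.min(numBytes, total - used)
        used += granted
        granted
      }
      def release(numBytes: Long): Unit = synchronized { used -= numBytes }
    }

    // An operator that reserves its first page eagerly, mirroring the constructor
    // change above: acquire on construction instead of on first insert.
    class ToyOperator(arbiter: ToyMemoryArbiter, pageSizeBytes: Long) {
      private var pages = 0
      require(acquireNewPage(), "could not reserve an initial page")

      private def acquireNewPage(): Boolean = {
        val granted = arbiter.tryToAcquire(pageSizeBytes)
        if (granted != pageSizeBytes) {
          arbiter.release(granted) // give back a partial grant we cannot use
          false
        } else {
          pages += 1
          true
        }
      }
      def numPages: Int = pages
    }

    object StarvationDemo extends App {
      val arbiter = new ToyMemoryArbiter(total = 3L << 20)       // 3 MB shared budget
      val a = new ToyOperator(arbiter, pageSizeBytes = 1L << 20) // reserves 1 MB up front
      val b = new ToyOperator(arbiter, pageSizeBytes = 1L << 20) // still succeeds: 2 MB left
      println(s"pages: a=${a.numPages}, b=${b.numPages}")
    }

The tryToAcquire/release pair mirrors the idiom in the extracted acquireNewPage() method: the arbiter may grant less than requested, and a partial grant must be returned before reporting failure. If ToyOperator instead waited for its first record, an operator constructed in between could consume the whole budget first; acquiring at construction closes that window.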
[File: UnsafeExternalSorter.java]

@@ -132,16 +132,15 @@ private UnsafeExternalSorter(
 
     if (existingInMemorySorter == null) {
       initializeForWriting();
+      // Acquire a new page as soon as we construct the sorter to ensure that we have at
+      // least one page to work with. Otherwise, other operators in the same task may starve
+      // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
+      acquireNewPage();
     } else {
       this.isInMemSorterExternal = true;
       this.inMemSorter = existingInMemorySorter;
     }
 
-    // Acquire a new page as soon as we construct the sorter to ensure that we have at
-    // least one page to work with. Otherwise, other operators in the same task may starve
-    // this sorter (SPARK-9709).
-    acquireNewPage();
-
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
[File: BytesToBytesMap test suite]

@@ -543,7 +543,7 @@ public void testPeakMemoryUsed() {
         Platform.LONG_ARRAY_OFFSET,
         8);
       newPeakMemory = map.getPeakMemoryUsedBytes();
-      if (i % numRecordsPerPage == 0) {
+      if (i % numRecordsPerPage == 0 && i > 0) {
        // We allocated a new page for this record, so peak memory should change
        assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
      } else {
@@ -561,4 +561,13 @@ public void testPeakMemoryUsed() {
       map.free();
     }
   }
+
+  @Test
+  public void testAcquirePageInConstructor() {
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    assertEquals(1, map.getNumDataPages());
+    map.free();
+  }
+
 }
[File: UnsafeFixedWidthAggregationMap.java]

@@ -19,6 +19,8 @@
 
 import java.io.IOException;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.SparkEnv;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() {
     return map.getPeakMemoryUsedBytes();
   }
 
+  @VisibleForTesting
+  public int getNumDataPages() {
+    return map.getNumDataPages();
+  }
+
   /**
    * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */
[File: TungstenAggregate.scala]

@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -68,35 +69,56 @@ case class TungstenAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { iter =>
-      val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
-        // This is a grouped aggregate and the input iterator is empty,
-        // so return an empty iterator.
-        Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
-      } else {
-        val aggregationIterator =
-          new TungstenAggregationIterator(
-            groupingExpressions,
-            nonCompleteAggregateExpressions,
-            completeAggregateExpressions,
-            initialInputBufferOffset,
-            resultExpressions,
-            newMutableProjection,
-            child.output,
-            iter,
-            testFallbackStartsAt,
-            numInputRows,
-            numOutputRows)
-
-        if (!hasInput && groupingExpressions.isEmpty) {
+
+    /**
+     * Set up the underlying unsafe data structures used before computing the parent partition.
+     * This makes sure our iterator is not starved by other operators in the same task.
+     */
+    def preparePartition(): TungstenAggregationIterator = {
+      new TungstenAggregationIterator(
+        groupingExpressions,
+        nonCompleteAggregateExpressions,
+        completeAggregateExpressions,
+        initialInputBufferOffset,
+        resultExpressions,
+        newMutableProjection,
+        child.output,
+        testFallbackStartsAt,
+        numInputRows,
+        numOutputRows)
+    }
+
+    /** Compute a partition using the iterator already set up previously. */
+    def executePartition(
+        context: TaskContext,
+        partitionIndex: Int,
+        aggregationIterator: TungstenAggregationIterator,
+        parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
+      val hasInput = parentIterator.hasNext
+      if (!hasInput) {
+        // We're not using the underlying map, so we just can free it here
+        aggregationIterator.free()
+        if (groupingExpressions.isEmpty) {
           numOutputRows += 1
           Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
-          aggregationIterator
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
+          Iterator[UnsafeRow]()
         }
+      } else {
+        aggregationIterator.start(parentIterator)
+        aggregationIterator
       }
     }
+
+    // Note: we need to set up the iterator in each partition before computing the
+    // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
+    val resultRdd = {
+      new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
+        child.execute(), preparePartition, executePartition, preservesPartitioning = true)
+    }
+    resultRdd.asInstanceOf[RDD[InternalRow]]
   }
 
   override def simpleString: String = {
Review comment (Contributor), on `resultRdd.asInstanceOf[RDD[InternalRow]]`:
Should we just return resultRdd? Seems we do not need to cast?

Reply (Author):
Actually result RDD is of type RDD[UnsafeRow]. Since RDDs are not covariant I think we do need the cast.
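The exchange above comes down to variance: RDD[T] is invariant in its type parameter, so an RDD[UnsafeRow] is not usable as an RDD[InternalRow] even though UnsafeRow is an InternalRow, and the explicit cast is needed. A small self-contained illustration (Box, Animal, and Dog are hypothetical stand-ins, not Spark types):

    // Box is invariant in T, just as RDD[T] is invariant.
    class Box[T](val value: T)

    class Animal
    class Dog extends Animal

    object InvarianceDemo extends App {
      val dogs: Box[Dog] = new Box(new Dog)

      // Does not compile: Box[Dog] is not a Box[Animal] because T is invariant.
      // val animals: Box[Animal] = dogs

      // An explicit cast works and is safe at runtime (type arguments are erased),
      // which is why doExecute ends with resultRdd.asInstanceOf[RDD[InternalRow]].
      val animals: Box[Animal] = dogs.asInstanceOf[Box[Animal]]
      println(animals.value.getClass.getSimpleName) // prints: Dog
    }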
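MapPartitionsWithPreparationRDD itself is not shown in this diff, but the note in the hunk above states the contract it must provide: run a preparation step for this operator before the parent partition's iterator is computed, so the preparation can reserve memory first. A simplified, Spark-free sketch of that ordering (all names here are illustrative, not the actual class's API):

    object PreparationDemo extends App {
      // Run `preparePartition` for this operator *before* forcing the parent partition.
      // With plain mapPartitions, the parent iterator may already have been set up
      // (and may have grabbed memory) by the time this operator allocates anything.
      def computeWithPreparation[T, M, U](
          parentPartition: () => Iterator[T], // forcing this may allocate upstream
          preparePartition: () => M,
          executePartition: (M, Iterator[T]) => Iterator[U]): Iterator[U] = {
        val prepared = preparePartition()  // step 1: reserve this operator's resources
        val parentIter = parentPartition() // step 2: only now compute the parent
        executePartition(prepared, parentIter)
      }

      val out = computeWithPreparation[Int, StringBuilder, String](
        parentPartition = () => { println("parent forced"); Iterator(1, 2, 3) },
        preparePartition = () => { println("prepared first"); new StringBuilder("row ") },
        executePartition = (prefix, iter) => iter.map(prefix.toString + _))
      out.foreach(println) // prints: prepared first, parent forced, row 1, row 2, row 3
    }

In the actual change, preparePartition constructs the TungstenAggregationIterator (which eagerly reserves its first page, per the BytesToBytesMap change above), and executePartition feeds it the parent iterator via start().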
[File: TungstenAggregationIterator.scala]

@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
  *   the function used to create mutable projections.
  * @param originalInputAttributes
  *   attributes of representing input rows from `inputIter`.
- * @param inputIter
- *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[Int],
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric)
   extends Iterator[UnsafeRow] with Logging {
 
+  // The parent partition iterator, to be initialized later in `start`
+  private[this] var inputIter: Iterator[InternalRow] = null
+
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
       false // disable tracking of performance metrics
     )
 
+  // Exposed for testing
+  private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
   // The function used to read and process input rows. When processing input rows,
   // it first uses hash-based aggregation by putting groups and their buffers in
   // hashMap. If we could not allocate more memory for the map, we switch to
   // sort-based aggregation (by calling switchToSortBasedAggregation).
   private def processInputs(): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
       numInputRows += 1
@@ -372,6 +376,7 @@ class TungstenAggregationIterator(
   // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
   // been processed.
   private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     var i = 0
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
@@ -412,6 +417,7 @@ class TungstenAggregationIterator(
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@ class TungstenAggregationIterator(
       case _ => false
     }
 
+    // Note: Since we spill the sorter's contents immediately after creating it, we must insert
+    // something into the sorter here to ensure that we acquire at least a page of memory.
+    // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
+    // Otherwise, children operators may steal the window of opportunity and starve our sorter.
+
     if (needsProcess) {
       // First, we create a buffer.
       val buffer = createNewAggregationBuffer()
Review comment (Contributor), on the note above:
How about we explicitly say that externalSorter.insertKV(firstKey, buffer) will trigger the page allocation?

Reply (Author):
done
@@ -588,27 +599,33 @@ class TungstenAggregationIterator(
   // have not switched to sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
-  // Starts to process input rows.
-  testFallbackStartsAt match {
-    case None =>
-      processInputs()
-    case Some(fallbackStartsAt) =>
-      // This is the testing path. processInputsWithControlledFallback is same as processInputs
-      // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
-      // have been processed.
-      processInputsWithControlledFallback(fallbackStartsAt)
-  }
+  /**
+   * Start processing input rows.
+   * Only after this method is called will this iterator be non-empty.
+   */
+  def start(parentIter: Iterator[InternalRow]): Unit = {
+    inputIter = parentIter
+    testFallbackStartsAt match {
+      case None =>
+        processInputs()
+      case Some(fallbackStartsAt) =>
+        // This is the testing path. processInputsWithControlledFallback is same as processInputs
+        // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+        // have been processed.
+        processInputsWithControlledFallback(fallbackStartsAt)
+    }
 
-  // If we did not switch to sort-based aggregation in processInputs,
-  // we pre-load the first key-value pair from the map (to make hasNext idempotent).
-  if (!sortBased) {
-    // First, set aggregationBufferMapIterator.
-    aggregationBufferMapIterator = hashMap.iterator()
-    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
-    mapIteratorHasNext = aggregationBufferMapIterator.next()
-    // If the map is empty, we just free it.
-    if (!mapIteratorHasNext) {
-      hashMap.free()
+    // If we did not switch to sort-based aggregation in processInputs,
+    // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+    if (!sortBased) {
+      // First, set aggregationBufferMapIterator.
+      aggregationBufferMapIterator = hashMap.iterator()
+      // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+      mapIteratorHasNext = aggregationBufferMapIterator.next()
+      // If the map is empty, we just free it.
+      if (!mapIteratorHasNext) {
+        hashMap.free()
+      }
     }
   }
 

Expand Down Expand Up @@ -673,21 +690,20 @@ class TungstenAggregationIterator(
}

///////////////////////////////////////////////////////////////////////////
// Part 8: A utility function used to generate a output row when there is no
// input and there is no grouping expression.
// Part 8: Utility functions
///////////////////////////////////////////////////////////////////////////

/**
* Generate a output row when there is no input and there is no grouping expression.
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
if (groupingExpressions.isEmpty) {
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
// We create a output row and copy it. So, we can free the map.
val resultCopy =
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
hashMap.free()
resultCopy
} else {
throw new IllegalStateException(
"This method should not be called when groupingExpressions is not empty.")
}
assert(groupingExpressions.isEmpty)
assert(inputIter == null)
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
}

/** Free memory used in the underlying map. */
def free(): Unit = {
hashMap.free()
}
}
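The net effect of the changes in this file is that TungstenAggregationIterator no longer consumes its input in its constructor; it is constructed (and reserves memory) during preparation, then fed the parent iterator via start(). A condensed sketch of that shape, including the pre-loading that keeps hasNext idempotent (ToyStartableIterator is an invented stand-in, not Spark code):

    // An iterator that is empty until start() is called. Construction can reserve
    // resources up front; start() later consumes the parent iterator eagerly and
    // pre-loads the first result so that repeated hasNext calls stay idempotent.
    class ToyStartableIterator[A](process: Iterator[A] => List[A]) extends Iterator[A] {
      private var results: Iterator[A] = Iterator.empty
      private var nextResult: Option[A] = None

      def start(parentIter: Iterator[A]): Unit = {
        results = process(parentIter).iterator
        nextResult = if (results.hasNext) Some(results.next()) else None // pre-load
      }

      override def hasNext: Boolean = nextResult.isDefined
      override def next(): A = {
        val out = nextResult.getOrElse(throw new NoSuchElementException)
        nextResult = if (results.hasNext) Some(results.next()) else None
        out
      }
    }

    object StartableDemo extends App {
      val it = new ToyStartableIterator[Int](_.toList.sorted)
      println(it.hasNext) // false: not started yet
      it.start(Iterator(3, 1, 2))
      println(it.toList)  // List(1, 2, 3)
    }

The asserts added to processInputs, processInputsWithControlledFallback, and switchToSortBasedAggregation above guard exactly this two-phase protocol: none of them may run before start() has installed inputIter.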