Skip to content

Commit 4189fdc

Browse files
Andrew OrCodingCat
authored andcommitted
[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation
This is the sister patch to apache#8011, but for aggregation. In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap` which acquires a page in the constructor as of this patch. This ensures that the aggregation operator is not starved since we reserve at least 1 page in advance. rxin yhuai Author: Andrew Or <[email protected]> Closes apache#8038 from andrewor14/unsafe-starve-memory-agg.
1 parent 48a2f5d commit 4189fdc

File tree

7 files changed

+201
-76
lines changed

7 files changed

+201
-76
lines changed

core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ public BytesToBytesMap(
193193
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
194194
}
195195
allocate(initialCapacity);
196+
197+
// Acquire a new page as soon as we construct the map to ensure that we have at least
198+
// one page to work with. Otherwise, other operators in the same task may starve this
199+
// map (SPARK-9747).
200+
acquireNewPage();
196201
}
197202

198203
public BytesToBytesMap(
@@ -574,16 +579,9 @@ public boolean putNewKey(
574579
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
575580
Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
576581
}
577-
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
578-
if (memoryGranted != pageSizeBytes) {
579-
shuffleMemoryManager.release(memoryGranted);
580-
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
582+
if (!acquireNewPage()) {
581583
return false;
582584
}
583-
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
584-
dataPages.add(newPage);
585-
pageCursor = 0;
586-
currentDataPage = newPage;
587585
dataPage = currentDataPage;
588586
dataPageBaseObject = currentDataPage.getBaseObject();
589587
dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -642,6 +640,24 @@ public boolean putNewKey(
642640
}
643641
}
644642

643+
/**
644+
* Acquire a new page from the {@link ShuffleMemoryManager}.
645+
* @return whether there is enough space to allocate the new page.
646+
*/
647+
private boolean acquireNewPage() {
648+
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
649+
if (memoryGranted != pageSizeBytes) {
650+
shuffleMemoryManager.release(memoryGranted);
651+
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
652+
return false;
653+
}
654+
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
655+
dataPages.add(newPage);
656+
pageCursor = 0;
657+
currentDataPage = newPage;
658+
return true;
659+
}
660+
645661
/**
646662
* Allocate new data structures for this map. When calling this outside of the constructor,
647663
* make sure to keep references to the old data structures so that you can free them.
@@ -748,7 +764,7 @@ public long getNumHashCollisions() {
748764
}
749765

750766
@VisibleForTesting
751-
int getNumDataPages() {
767+
public int getNumDataPages() {
752768
return dataPages.size();
753769
}
754770

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,15 @@ private UnsafeExternalSorter(
132132

133133
if (existingInMemorySorter == null) {
134134
initializeForWriting();
135+
// Acquire a new page as soon as we construct the sorter to ensure that we have at
136+
// least one page to work with. Otherwise, other operators in the same task may starve
137+
// this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
138+
acquireNewPage();
135139
} else {
136140
this.isInMemSorterExternal = true;
137141
this.inMemSorter = existingInMemorySorter;
138142
}
139143

140-
// Acquire a new page as soon as we construct the sorter to ensure that we have at
141-
// least one page to work with. Otherwise, other operators in the same task may starve
142-
// this sorter (SPARK-9709).
143-
acquireNewPage();
144-
145144
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
146145
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
147146
// does not fully consume the sorter's output (e.g. sort followed by limit).

core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ public void testPeakMemoryUsed() {
543543
Platform.LONG_ARRAY_OFFSET,
544544
8);
545545
newPeakMemory = map.getPeakMemoryUsedBytes();
546-
if (i % numRecordsPerPage == 0) {
546+
if (i % numRecordsPerPage == 0 && i > 0) {
547547
// We allocated a new page for this record, so peak memory should change
548548
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
549549
} else {
@@ -561,4 +561,13 @@ public void testPeakMemoryUsed() {
561561
map.free();
562562
}
563563
}
564+
565+
@Test
566+
public void testAcquirePageInConstructor() {
567+
final BytesToBytesMap map = new BytesToBytesMap(
568+
taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
569+
assertEquals(1, map.getNumDataPages());
570+
map.free();
571+
}
572+
564573
}

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.io.IOException;
2121

22+
import com.google.common.annotations.VisibleForTesting;
23+
2224
import org.apache.spark.SparkEnv;
2325
import org.apache.spark.shuffle.ShuffleMemoryManager;
2426
import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() {
220222
return map.getPeakMemoryUsedBytes();
221223
}
222224

225+
@VisibleForTesting
226+
public int getNumDataPages() {
227+
return map.getNumDataPages();
228+
}
229+
223230
/**
224231
* Free the memory associated with this map. This is idempotent and can be called multiple times.
225232
*/

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
package org.apache.spark.sql.execution.aggregate
1919

20-
import org.apache.spark.rdd.RDD
20+
import org.apache.spark.TaskContext
21+
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
2122
import org.apache.spark.sql.catalyst.InternalRow
2223
import org.apache.spark.sql.catalyst.errors._
2324
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
2425
import org.apache.spark.sql.catalyst.expressions._
25-
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
26+
import org.apache.spark.sql.catalyst.plans.physical._
2627
import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
2728
import org.apache.spark.sql.execution.metric.SQLMetrics
2829

@@ -68,35 +69,56 @@ case class TungstenAggregate(
6869
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
6970
val numInputRows = longMetric("numInputRows")
7071
val numOutputRows = longMetric("numOutputRows")
71-
child.execute().mapPartitions { iter =>
72-
val hasInput = iter.hasNext
73-
if (!hasInput && groupingExpressions.nonEmpty) {
74-
// This is a grouped aggregate and the input iterator is empty,
75-
// so return an empty iterator.
76-
Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
77-
} else {
78-
val aggregationIterator =
79-
new TungstenAggregationIterator(
80-
groupingExpressions,
81-
nonCompleteAggregateExpressions,
82-
completeAggregateExpressions,
83-
initialInputBufferOffset,
84-
resultExpressions,
85-
newMutableProjection,
86-
child.output,
87-
iter,
88-
testFallbackStartsAt,
89-
numInputRows,
90-
numOutputRows)
91-
92-
if (!hasInput && groupingExpressions.isEmpty) {
72+
73+
/**
74+
* Set up the underlying unsafe data structures used before computing the parent partition.
75+
* This makes sure our iterator is not starved by other operators in the same task.
76+
*/
77+
def preparePartition(): TungstenAggregationIterator = {
78+
new TungstenAggregationIterator(
79+
groupingExpressions,
80+
nonCompleteAggregateExpressions,
81+
completeAggregateExpressions,
82+
initialInputBufferOffset,
83+
resultExpressions,
84+
newMutableProjection,
85+
child.output,
86+
testFallbackStartsAt,
87+
numInputRows,
88+
numOutputRows)
89+
}
90+
91+
/** Compute a partition using the iterator already set up previously. */
92+
def executePartition(
93+
context: TaskContext,
94+
partitionIndex: Int,
95+
aggregationIterator: TungstenAggregationIterator,
96+
parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
97+
val hasInput = parentIterator.hasNext
98+
if (!hasInput) {
99+
// We're not using the underlying map, so we just can free it here
100+
aggregationIterator.free()
101+
if (groupingExpressions.isEmpty) {
93102
numOutputRows += 1
94103
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
95104
} else {
96-
aggregationIterator
105+
// This is a grouped aggregate and the input iterator is empty,
106+
// so return an empty iterator.
107+
Iterator[UnsafeRow]()
97108
}
109+
} else {
110+
aggregationIterator.start(parentIterator)
111+
aggregationIterator
98112
}
99113
}
114+
115+
// Note: we need to set up the iterator in each partition before computing the
116+
// parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
117+
val resultRdd = {
118+
new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
119+
child.execute(), preparePartition, executePartition, preservesPartitioning = true)
120+
}
121+
resultRdd.asInstanceOf[RDD[InternalRow]]
100122
}
101123

102124
override def simpleString: String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
7272
* the function used to create mutable projections.
7373
* @param originalInputAttributes
7474
* attributes of representing input rows from `inputIter`.
75-
* @param inputIter
76-
* the iterator containing input [[UnsafeRow]]s.
7775
*/
7876
class TungstenAggregationIterator(
7977
groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
8381
resultExpressions: Seq[NamedExpression],
8482
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
8583
originalInputAttributes: Seq[Attribute],
86-
inputIter: Iterator[InternalRow],
8784
testFallbackStartsAt: Option[Int],
8885
numInputRows: LongSQLMetric,
8986
numOutputRows: LongSQLMetric)
9087
extends Iterator[UnsafeRow] with Logging {
9188

89+
// The parent partition iterator, to be initialized later in `start`
90+
private[this] var inputIter: Iterator[InternalRow] = null
91+
9292
///////////////////////////////////////////////////////////////////////////
9393
// Part 1: Initializing aggregate functions.
9494
///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
348348
false // disable tracking of performance metrics
349349
)
350350

351+
// Exposed for testing
352+
private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
353+
351354
// The function used to read and process input rows. When processing input rows,
352355
// it first uses hash-based aggregation by putting groups and their buffers in
353356
// hashMap. If we could not allocate more memory for the map, we switch to
354357
// sort-based aggregation (by calling switchToSortBasedAggregation).
355358
private def processInputs(): Unit = {
359+
assert(inputIter != null, "attempted to process input when iterator was null")
356360
while (!sortBased && inputIter.hasNext) {
357361
val newInput = inputIter.next()
358362
numInputRows += 1
@@ -372,6 +376,7 @@ class TungstenAggregationIterator(
372376
// that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
373377
// been processed.
374378
private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
379+
assert(inputIter != null, "attempted to process input when iterator was null")
375380
var i = 0
376381
while (!sortBased && inputIter.hasNext) {
377382
val newInput = inputIter.next()
@@ -412,6 +417,7 @@ class TungstenAggregationIterator(
412417
* Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
413418
*/
414419
private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
420+
assert(inputIter != null, "attempted to process input when iterator was null")
415421
logInfo("falling back to sort based aggregation.")
416422
// Step 1: Get the ExternalSorter containing sorted entries of the map.
417423
externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@ class TungstenAggregationIterator(
431437
case _ => false
432438
}
433439

440+
// Note: Since we spill the sorter's contents immediately after creating it, we must insert
441+
// something into the sorter here to ensure that we acquire at least a page of memory.
442+
// This is done through `externalSorter.insertKV`, which will trigger the page allocation.
443+
// Otherwise, children operators may steal the window of opportunity and starve our sorter.
444+
434445
if (needsProcess) {
435446
// First, we create a buffer.
436447
val buffer = createNewAggregationBuffer()
@@ -588,27 +599,33 @@ class TungstenAggregationIterator(
588599
// have not switched to sort-based aggregation.
589600
///////////////////////////////////////////////////////////////////////////
590601

591-
// Starts to process input rows.
592-
testFallbackStartsAt match {
593-
case None =>
594-
processInputs()
595-
case Some(fallbackStartsAt) =>
596-
// This is the testing path. processInputsWithControlledFallback is same as processInputs
597-
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
598-
// have been processed.
599-
processInputsWithControlledFallback(fallbackStartsAt)
600-
}
602+
/**
603+
* Start processing input rows.
604+
* Only after this method is called will this iterator be non-empty.
605+
*/
606+
def start(parentIter: Iterator[InternalRow]): Unit = {
607+
inputIter = parentIter
608+
testFallbackStartsAt match {
609+
case None =>
610+
processInputs()
611+
case Some(fallbackStartsAt) =>
612+
// This is the testing path. processInputsWithControlledFallback is same as processInputs
613+
// except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
614+
// have been processed.
615+
processInputsWithControlledFallback(fallbackStartsAt)
616+
}
601617

602-
// If we did not switch to sort-based aggregation in processInputs,
603-
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
604-
if (!sortBased) {
605-
// First, set aggregationBufferMapIterator.
606-
aggregationBufferMapIterator = hashMap.iterator()
607-
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
608-
mapIteratorHasNext = aggregationBufferMapIterator.next()
609-
// If the map is empty, we just free it.
610-
if (!mapIteratorHasNext) {
611-
hashMap.free()
618+
// If we did not switch to sort-based aggregation in processInputs,
619+
// we pre-load the first key-value pair from the map (to make hasNext idempotent).
620+
if (!sortBased) {
621+
// First, set aggregationBufferMapIterator.
622+
aggregationBufferMapIterator = hashMap.iterator()
623+
// Pre-load the first key-value pair from the aggregationBufferMapIterator.
624+
mapIteratorHasNext = aggregationBufferMapIterator.next()
625+
// If the map is empty, we just free it.
626+
if (!mapIteratorHasNext) {
627+
hashMap.free()
628+
}
612629
}
613630
}
614631

@@ -673,21 +690,20 @@ class TungstenAggregationIterator(
673690
}
674691

675692
///////////////////////////////////////////////////////////////////////////
676-
// Part 8: A utility function used to generate a output row when there is no
677-
// input and there is no grouping expression.
693+
// Part 8: Utility functions
678694
///////////////////////////////////////////////////////////////////////////
679695

696+
/**
697+
* Generate a output row when there is no input and there is no grouping expression.
698+
*/
680699
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
681-
if (groupingExpressions.isEmpty) {
682-
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
683-
// We create a output row and copy it. So, we can free the map.
684-
val resultCopy =
685-
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
686-
hashMap.free()
687-
resultCopy
688-
} else {
689-
throw new IllegalStateException(
690-
"This method should not be called when groupingExpressions is not empty.")
691-
}
700+
assert(groupingExpressions.isEmpty)
701+
assert(inputIter == null)
702+
generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
703+
}
704+
705+
/** Free memory used in the underlying map. */
706+
def free(): Unit = {
707+
hashMap.free()
692708
}
693709
}

0 commit comments

Comments
 (0)