
Commit e541d64

Author: Andrew Or
Track aggregation memory for both sort and hash
The previous commit only tracked it for hash-based aggregation; we should track it in both cases.
1 parent 0be3a42 commit e541d64
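
To make the accounting concrete, here is a minimal, self-contained Scala sketch of the pattern this commit applies (PeakTracking, ToyMap, and the sizes below are illustrative stand-ins, not Spark classes): each structure keeps its own high-water mark, refreshes it whenever the peak is read, and snapshots it one last time just before it frees its memory, so the peak stays available after teardown. The aggregation iterator can then report the larger of the map's and the sorter's peaks, since the two structures never hold memory at the same time.

// Illustrative sketch only, not Spark code.
trait PeakTracking {
  private var peak = 0L
  // Current memory consumption, supplied by the concrete structure.
  def currentMemoryBytes: Long
  // Refresh the high-water mark; called on every read and before freeing memory.
  protected def updatePeak(): Unit = { peak = math.max(peak, currentMemoryBytes) }
  def peakMemoryUsedBytes: Long = { updatePeak(); peak }
}

// A toy map whose pages are dropped on free(), loosely mirroring BytesToBytesMap.
class ToyMap extends PeakTracking {
  private var pages = Vector.empty[Array[Byte]]
  def currentMemoryBytes: Long = pages.map(_.length.toLong).sum
  def insert(recordSize: Int): Unit = pages :+= new Array[Byte](recordSize)
  def free(): Unit = { updatePeak(); pages = Vector.empty } // snapshot before releasing
}

object PeakTrackingSketch extends App {
  val map = new ToyMap
  (1 to 4).foreach(_ => map.insert(1 << 20)) // grow to roughly 4 MB
  map.free()
  println(map.peakMemoryUsedBytes) // still reports ~4 MB after free()
}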

File tree (5 files changed: +66 -20 lines)

  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 18 additions & 1 deletion
@@ -166,6 +166,8 @@ public final class BytesToBytesMap {
 
   private long numHashCollisions = 0;
 
+  private long peakMemoryUsedBytes = 0L;
+
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
@@ -658,6 +660,7 @@ private void allocate(int capacity) {
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
+    updatePeakMemoryUsed();
     longArray = null;
     bitset = null;
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
@@ -684,7 +687,6 @@ public long getPageSizeBytes() {
 
   /**
    * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
-   * Note that this is also the peak memory used by this map, since the map is append-only.
    */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;
@@ -694,6 +696,21 @@ public long getTotalMemoryConsumption() {
     return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
   }
 
+  private void updatePeakMemoryUsed() {
+    long mem = getTotalMemoryConsumption();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
   /**
    * Returns the total amount of time spent resizing this map (in nanoseconds).
   */
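
Note that free() now calls updatePeakMemoryUsed() before it drops longArray, bitset, and the data pages. Snapshotting first is what keeps the high-water mark from being lost when the map releases its memory, and it is what lets the reworked test below assert that freeing the map leaves the reported peak unchanged.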

core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java

Lines changed: 13 additions & 7 deletions
@@ -525,7 +525,7 @@ public void resizingLargeMap() {
   }
 
   @Test
-  public void testTotalMemoryConsumption() {
+  public void testPeakMemoryUsed() {
     final long recordLengthBytes = 24;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
@@ -536,8 +536,8 @@ public void testTotalMemoryConsumption() {
     // monotonically increasing. More specifically, every time we allocate a new page it
     // should increase by exactly the size of the page. In this regard, the memory usage
     // at any given time is also the peak memory used.
-    long previousMemory = map.getTotalMemoryConsumption();
-    long newMemory;
+    long previousPeakMemory = map.getPeakMemoryUsedBytes();
+    long newPeakMemory;
     try {
       for (long i = 0; i < numRecordsPerPage * 10; i++) {
         final long[] value = new long[]{i};
@@ -548,15 +548,21 @@ public void testTotalMemoryConsumption() {
           value,
           PlatformDependent.LONG_ARRAY_OFFSET,
           8);
-        newMemory = map.getTotalMemoryConsumption();
+        newPeakMemory = map.getPeakMemoryUsedBytes();
         if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
-          assertEquals(previousMemory + pageSizeBytes, newMemory);
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
-          assertEquals(previousMemory, newMemory);
+          assertEquals(previousPeakMemory, newPeakMemory);
         }
-        previousMemory = newMemory;
+        previousPeakMemory = newPeakMemory;
       }
+
+      // Freeing the map should not change the peak memory
+      map.free();
+      newPeakMemory = map.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
     } finally {
       map.free();
     }

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 3 additions & 4 deletions
@@ -210,11 +210,10 @@ public void close() {
   }
 
   /**
-   * The memory used by this map's managed structures, in bytes.
-   * Note that this is also the peak memory used by this map, since the map is append-only.
+   * Return the peak memory used so far, in bytes.
    */
-  public long getMemoryUsage() {
-    return map.getTotalMemoryConsumption();
+  public long getPeakMemoryUsedBytes() {
+    return map.getPeakMemoryUsedBytes();
   }
 
   /**

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 7 additions & 0 deletions
@@ -159,6 +159,13 @@ public KVSorterIterator sortedIterator() throws IOException {
     }
   }
 
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    return sorter.getPeakMemoryUsedBytes();
+  }
+
   /**
    * Marks the current page as no-more-space-available, and as a result, either allocate a
    * new page or spill when we see the next record.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 25 additions & 8 deletions
@@ -397,14 +397,20 @@ class TungstenAggregationIterator(
   private[this] var mapIteratorHasNext: Boolean = false
 
   ///////////////////////////////////////////////////////////////////////////
-  // Part 4: The function used to switch this iterator from hash-based
-  // aggregation to sort-based aggregation.
+  // Part 3: Methods and fields used by sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
+  // This sorter is used for sort-based aggregation. It is initialized as soon as
+  // we switch from hash-based to sort-based aggregation. Otherwise, it is not used.
+  private[this] var externalSorter: UnsafeKVExternalSorter = null
+
+  /**
+   * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
+   */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
-    val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+    externalSorter = hashMap.destructAndCreateExternalSorter()
 
     // Step 2: Free the memory used by the map.
     hashMap.free()
@@ -601,7 +607,7 @@ class TungstenAggregationIterator(
   }
 
   ///////////////////////////////////////////////////////////////////////////
-  // Par 7: Iterator's public methods.
+  // Part 7: Iterator's public methods.
   ///////////////////////////////////////////////////////////////////////////
 
   override final def hasNext: Boolean = {
@@ -610,7 +616,7 @@ class TungstenAggregationIterator(
 
   override final def next(): UnsafeRow = {
     if (hasNext) {
-      if (sortBased) {
+      val res = if (sortBased) {
         // Process the current group.
         processCurrentSortedGroup()
         // Generate output row for the current group.
@@ -633,9 +639,6 @@ class TungstenAggregationIterator(
         if (!mapIteratorHasNext) {
           // If there is no input from aggregationBufferMapIterator, we copy current result.
           val resultCopy = result.copy()
-          // Report memory usage metrics.
-          TaskContext.get().internalMetricsToAccumulators(
-            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(hashMap.getMemoryUsage)
           // Then, we free the map.
           hashMap.free()
 
@@ -644,6 +647,19 @@ class TungstenAggregationIterator(
           result
         }
       }
+
+      // If this is the last record, update the task's peak memory usage. Since we destroy
+      // the map to create the sorter, their memory usages should not overlap, so it is safe
+      // to just use the max of the two.
+      if (!hasNext) {
+        val mapMemory = hashMap.getPeakMemoryUsedBytes
+        val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+        val peakMemory = Math.max(mapMemory, sorterMemory)
+        TaskContext.get().internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+      }
+
+      res
     } else {
       // no more result
       throw new NoSuchElementException
@@ -654,6 +670,7 @@ class TungstenAggregationIterator(
   // Part 8: A utility function used to generate a output row when there is no
   // input and there is no grouping expression.
   ///////////////////////////////////////////////////////////////////////////
+
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     if (groupingExpressions.isEmpty) {
       sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
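
A quick numeric check of the "max of the two" reasoning in the new block above (sizes made up for illustration): if the hash map peaked at 64 MB and was destroyed before the sorter was created, and the sorter later peaked at 80 MB, the task never held more than 80 MB of aggregation memory at once. Reporting max(64 MB, 80 MB) = 80 MB to PEAK_EXECUTION_MEMORY is accurate, whereas summing to 144 MB would double-count memory that was never held simultaneously.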
