
Commit 5b00a27

Andrew Or authored and CodingCat committed
[SPARK-9674] Re-enable ignored test in SQLQuerySuite
The original code that this test tests was removed in apache@9270bd0. The test was ignored shortly before that, so we never caught the breakage. This patch re-enables the test and adds the code necessary to make it pass.

JoshRosen yhuai

Author: Andrew Or <[email protected]>

Closes apache#8015 from andrewor14/SPARK-9674 and squashes the following commits:

225eac2 [Andrew Or] Merge branch 'master' of github.com:apache/spark into SPARK-9674
8c24209 [Andrew Or] Fix NPE
e541d64 [Andrew Or] Track aggregation memory for both sort and hash
0be3a42 [Andrew Or] Fix test
1 parent 38958d0 commit 5b00a27

File tree

6 files changed: +85, -26 lines

core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 33 additions & 4 deletions
@@ -109,7 +109,7 @@ public final class BytesToBytesMap {
    * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
    * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode.
    */
-  private LongArray longArray;
+  @Nullable private LongArray longArray;
   // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
   // and exploit word-alignment to use fewer bits to hold the address. This might let us store
   // only one long per map entry, increasing the chance that this array will fit in cache at the
@@ -124,7 +124,7 @@ public final class BytesToBytesMap {
    * A {@link BitSet} used to track location of the map where the key is set.
    * Size of the bitset should be half of the size of the long array.
    */
-  private BitSet bitset;
+  @Nullable private BitSet bitset;
 
   private final double loadFactor;
 
@@ -166,6 +166,8 @@ public final class BytesToBytesMap {
 
   private long numHashCollisions = 0;
 
+  private long peakMemoryUsedBytes = 0L;
+
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
@@ -321,6 +323,9 @@ public Location lookup(
       Object keyBaseObject,
       long keyBaseOffset,
       int keyRowLengthBytes) {
+    assert(bitset != null);
+    assert(longArray != null);
+
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
@@ -410,6 +415,7 @@ private void updateAddressesAndSizes(final Object page, final long offsetInPage)
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
+      assert(longArray != null);
       this.pos = pos;
       this.isDefined = isDefined;
       this.keyHashcode = keyHashcode;
@@ -525,6 +531,9 @@ public boolean putNewKey(
     assert (!isDefined) : "Can only set value once for a key";
     assert (keyLengthBytes % 8 == 0);
     assert (valueLengthBytes % 8 == 0);
+    assert(bitset != null);
+    assert(longArray != null);
+
     if (numElements == MAX_CAPACITY) {
       throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
     }
@@ -658,6 +667,7 @@ private void allocate(int capacity) {
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
+    updatePeakMemoryUsed();
     longArray = null;
     bitset = null;
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
@@ -684,14 +694,30 @@ public long getPageSizeBytes() {
 
   /**
    * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
-   * Note that this is also the peak memory used by this map, since the map is append-only.
    */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;
     for (MemoryBlock dataPage : dataPages) {
       totalDataPagesSize += dataPage.size();
     }
-    return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+    return totalDataPagesSize +
+      ((bitset != null) ? bitset.memoryBlock().size() : 0L) +
+      ((longArray != null) ? longArray.memoryBlock().size() : 0L);
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getTotalMemoryConsumption();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   /**
@@ -731,6 +757,9 @@ int getNumDataPages() {
    */
   @VisibleForTesting
   void growAndRehash() {
+    assert(bitset != null);
+    assert(longArray != null);
+
     long resizeStartTime = -1;
     if (enablePerfMetrics) {
       resizeStartTime = System.nanoTime();
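
The accounting pattern these hunks introduce is easier to see in isolation. Below is a minimal, standalone Java sketch, not Spark code: PeakTrackingMap and every other name in it are illustrative. The idea is that the peak is folded in whenever it is read, free() snapshots it before nulling the backing structures, and getTotalMemoryConsumption() null-checks those structures so it stays safe after free(), which is the NPE the squashed commits mention.

import java.util.ArrayList;
import java.util.List;

public final class PeakTrackingMap {
  private List<long[]> dataPages = new ArrayList<>();  // stands in for the map's data pages
  private long peakMemoryUsedBytes = 0L;

  void allocatePage(int numWords) {
    dataPages.add(new long[numWords]);
  }

  // Current footprint in bytes; null-safe so it can be called after free().
  long getTotalMemoryConsumption() {
    if (dataPages == null) {
      return 0L;
    }
    long total = 0L;
    for (long[] page : dataPages) {
      total += page.length * 8L;
    }
    return total;
  }

  private void updatePeakMemoryUsed() {
    long mem = getTotalMemoryConsumption();
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem;
    }
  }

  // Peak bytes used so far; monotone, and stable across free().
  long getPeakMemoryUsedBytes() {
    updatePeakMemoryUsed();  // fold in the current footprint before reporting
    return peakMemoryUsedBytes;
  }

  void free() {
    updatePeakMemoryUsed();  // snapshot first; consumption drops to zero afterwards
    dataPages = null;
  }

  public static void main(String[] args) {
    PeakTrackingMap map = new PeakTrackingMap();
    map.allocatePage(32);  // 32 words = 256 bytes
    long before = map.getPeakMemoryUsedBytes();
    map.free();
    long after = map.getPeakMemoryUsedBytes();
    System.out.println(before + " " + after);  // prints "256 256": freeing does not lower the peak
  }
}

Running main prints the same peak before and after free(), which is exactly what the re-worked test in the next file asserts.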

core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java

Lines changed: 13 additions & 7 deletions
@@ -525,7 +525,7 @@ public void resizingLargeMap() {
   }
 
   @Test
-  public void testTotalMemoryConsumption() {
+  public void testPeakMemoryUsed() {
     final long recordLengthBytes = 24;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
@@ -536,8 +536,8 @@ public void testTotalMemoryConsumption() {
     // monotonically increasing. More specifically, every time we allocate a new page it
     // should increase by exactly the size of the page. In this regard, the memory usage
     // at any given time is also the peak memory used.
-    long previousMemory = map.getTotalMemoryConsumption();
-    long newMemory;
+    long previousPeakMemory = map.getPeakMemoryUsedBytes();
+    long newPeakMemory;
     try {
       for (long i = 0; i < numRecordsPerPage * 10; i++) {
         final long[] value = new long[]{i};
@@ -548,15 +548,21 @@
           value,
           PlatformDependent.LONG_ARRAY_OFFSET,
           8);
-        newMemory = map.getTotalMemoryConsumption();
+        newPeakMemory = map.getPeakMemoryUsedBytes();
         if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
-          assertEquals(previousMemory + pageSizeBytes, newMemory);
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
-          assertEquals(previousMemory, newMemory);
+          assertEquals(previousPeakMemory, newPeakMemory);
         }
-        previousMemory = newMemory;
+        previousPeakMemory = newPeakMemory;
       }
+
+      // Freeing the map should not change the peak memory
+      map.free();
+      newPeakMemory = map.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
     } finally {
       map.free();
     }
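
As a quick sanity check of the arithmetic this test relies on, here is a standalone sketch with a hypothetical class name: each record occupies 24 bytes and a page holds 256 usable bytes, so integer division yields 10 records per page, and the peak should jump by exactly pageSizeBytes on records 0, 10, 20, and so on.

public class PageBoundaryMath {
  public static void main(String[] args) {
    final long recordLengthBytes = 24;
    final long pageSizeBytes = 256 + 8;  // 8 bytes for end-of-page marker
    // Integer division: 256 / 24 = 10 records fit before a new page is needed.
    final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
    System.out.println(numRecordsPerPage);  // prints 10
  }
}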

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 3 additions & 4 deletions
@@ -210,11 +210,10 @@ public void close() {
   }
 
   /**
-   * The memory used by this map's managed structures, in bytes.
-   * Note that this is also the peak memory used by this map, since the map is append-only.
+   * Return the peak memory used so far, in bytes.
    */
-  public long getMemoryUsage() {
-    return map.getTotalMemoryConsumption();
+  public long getPeakMemoryUsedBytes() {
+    return map.getPeakMemoryUsedBytes();
   }
 
   /**

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 7 additions & 0 deletions
@@ -159,6 +159,13 @@ public KVSorterIterator sortedIterator() throws IOException {
     }
   }
 
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    return sorter.getPeakMemoryUsedBytes();
+  }
+
   /**
    * Marks the current page as no-more-space-available, and as a result, either allocate a
    * new page or spill when we see the next record.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 26 additions & 6 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
@@ -397,14 +397,20 @@ class TungstenAggregationIterator(
   private[this] var mapIteratorHasNext: Boolean = false
 
   ///////////////////////////////////////////////////////////////////////////
-  // Part 4: The function used to switch this iterator from hash-based
-  // aggregation to sort-based aggregation.
+  // Part 3: Methods and fields used by sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
+  // This sorter is used for sort-based aggregation. It is initialized as soon as
+  // we switch from hash-based to sort-based aggregation. Otherwise, it is not used.
+  private[this] var externalSorter: UnsafeKVExternalSorter = null
+
+  /**
+   * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
+   */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
-    val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+    externalSorter = hashMap.destructAndCreateExternalSorter()
 
     // Step 2: Free the memory used by the map.
     hashMap.free()
@@ -601,7 +607,7 @@ class TungstenAggregationIterator(
   }
 
   ///////////////////////////////////////////////////////////////////////////
-  // Par 7: Iterator's public methods.
+  // Part 7: Iterator's public methods.
   ///////////////////////////////////////////////////////////////////////////
 
   override final def hasNext: Boolean = {
@@ -610,7 +616,7 @@ class TungstenAggregationIterator(
 
   override final def next(): UnsafeRow = {
     if (hasNext) {
-      if (sortBased) {
+      val res = if (sortBased) {
        // Process the current group.
        processCurrentSortedGroup()
        // Generate output row for the current group.
@@ -641,6 +647,19 @@ class TungstenAggregationIterator(
           result
         }
       }
+
+      // If this is the last record, update the task's peak memory usage. Since we destroy
+      // the map to create the sorter, their memory usages should not overlap, so it is safe
+      // to just use the max of the two.
+      if (!hasNext) {
+        val mapMemory = hashMap.getPeakMemoryUsedBytes
+        val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+        val peakMemory = Math.max(mapMemory, sorterMemory)
+        TaskContext.get().internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+      }
+
+      res
     } else {
       // no more result
       throw new NoSuchElementException
@@ -651,6 +670,7 @@ class TungstenAggregationIterator(
   // Part 8: A utility function used to generate a output row when there is no
   // input and there is no grouping expression.
   ///////////////////////////////////////////////////////////////////////////
+
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     if (groupingExpressions.isEmpty) {
       sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
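
The comment in the last hunk carries the key design point: hashMap.free() runs before the sorter is created, so the two structures never hold memory at the same time, and summing their peaks would double-count. A standalone illustration with hypothetical values (plain Java, not the Spark accumulator API):

public class PeakAccounting {
  public static void main(String[] args) {
    long mapMemory = 64L * 1024 * 1024;     // peak of the hash map while it was alive
    long sorterMemory = 48L * 1024 * 1024;  // peak of the sorter, created after the map was freed
    // The task's true peak is whichever phase was larger, not the sum of both.
    long peakMemory = Math.max(mapMemory, sorterMemory);
    System.out.println(peakMemory);  // prints 67108864 (64 MB)
  }
}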

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 5 deletions
@@ -267,7 +267,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       if (!hasGeneratedAgg) {
         fail(
           s"""
-             |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+             |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan.
              |${df.queryExecution.simpleString}
            """.stripMargin)
       }
@@ -1602,10 +1602,8 @@
       Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
   }
 
-  ignore("aggregation with codegen updates peak execution memory") {
-    withSQLConf(
-      (SQLConf.CODEGEN_ENABLED.key, "true"),
-      (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+  test("aggregation with codegen updates peak execution memory") {
+    withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) {
       val sc = sqlContext.sparkContext
       AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
         testCodeGen(
