From 27f2e7f74385fa62592aa11fade4d9d237683825 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Fri, 7 Aug 2015 12:36:08 -0700
Subject: [PATCH 1/8] Reserve memory in advance in TungstenAggregate

---
 .../aggregate/TungstenAggregate.scala         | 65 +++++++++++-----
 .../TungstenAggregationIterator.scala         | 78 ++++++++++---------
 2 files changed, 86 insertions(+), 57 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 5a0b4d47d62f..5a9854ffd5f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
@@ -61,32 +62,54 @@ case class TungstenAggregate(
   }
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    child.execute().mapPartitions { iter =>
-      val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
-        // This is a grouped aggregate and the input iterator is empty,
-        // so return an empty iterator.
-        Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
-      } else {
-        val aggregationIterator =
-          new TungstenAggregationIterator(
-            groupingExpressions,
-            nonCompleteAggregateExpressions,
-            completeAggregateExpressions,
-            initialInputBufferOffset,
-            resultExpressions,
-            newMutableProjection,
-            child.output,
-            iter.asInstanceOf[Iterator[UnsafeRow]],
-            testFallbackStartsAt)
-        if (!hasInput && groupingExpressions.isEmpty) {
+    /**
+     * Set up the underlying unsafe data structures to be used before computing
+     * the parent partition. This makes sure our iterator is not starved by
+     * other operators in the same task.
+     */
+    def preparePartition(): TungstenAggregationIterator = {
+      new TungstenAggregationIterator(
+        groupingExpressions,
+        nonCompleteAggregateExpressions,
+        completeAggregateExpressions,
+        initialInputBufferOffset,
+        resultExpressions,
+        newMutableProjection,
+        child.output,
+        testFallbackStartsAt)
+    }
+
+    /** Compute a partition using the iterator set up previously. */
+    def executePartition(
+        context: TaskContext,
+        partitionIndex: Int,
+        aggregationIterator: TungstenAggregationIterator,
+        parentIterator: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
+      val hasInput = parentIterator.hasNext
+      if (!hasInput) {
+        // We're not using the underlying map, so we can just free it here
+        aggregationIterator.free()
+        if (groupingExpressions.isEmpty) {
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
           Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
-          aggregationIterator
+          Iterator[UnsafeRow]()
         }
+      } else {
+        aggregationIterator.start(parentIterator)
+        aggregationIterator
       }
     }
+
+    // Note: we need to set up the external sorter in each partition before computing
+    // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
+    val parentPartition = child.execute().asInstanceOf[RDD[UnsafeRow]]
+    val resultRdd = {
+      new MapPartitionsWithPreparationRDD[UnsafeRow, UnsafeRow, TungstenAggregationIterator](
+        parentPartition, preparePartition, executePartition, preservesPartitioning = true)
+    }
+    resultRdd.asInstanceOf[RDD[InternalRow]]
   }
 
   override def simpleString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4d5e98a3e90c..4bf9942edb6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -71,8 +71,6 @@ import org.apache.spark.sql.types.StructType
  *   the function used to create mutable projections.
  * @param originalInputAttributes
  *   attributes of representing input rows from `inputIter`.
- * @param inputIter
- *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
@@ -82,10 +80,12 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
-    inputIter: Iterator[UnsafeRow],
     testFallbackStartsAt: Option[Int])
   extends Iterator[UnsafeRow] with Logging {
 
+  // The parent partition iterator, to be initialized later in `start`
+  private[this] var inputIter: Iterator[UnsafeRow] = Iterator[UnsafeRow]()
+
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
@@ -576,27 +576,33 @@ class TungstenAggregationIterator(
   // have not switched to sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
-  // Starts to process input rows.
-  testFallbackStartsAt match {
-    case None =>
-      processInputs()
-    case Some(fallbackStartsAt) =>
-      // This is the testing path. processInputsWithControlledFallback is same as processInputs
-      // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
-      // have been processed.
-      processInputsWithControlledFallback(fallbackStartsAt)
-  }
+  /**
+   * Start processing input rows.
+   * Only after this method is called will this iterator be non-empty.
+   */
+  def start(parentIter: Iterator[UnsafeRow]): Unit = {
+    inputIter = parentIter
+    testFallbackStartsAt match {
+      case None =>
+        processInputs()
+      case Some(fallbackStartsAt) =>
+        // This is the testing path. processInputsWithControlledFallback is same as processInputs
+        // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+        // have been processed.
+        processInputsWithControlledFallback(fallbackStartsAt)
+    }
 
-  // If we did not switch to sort-based aggregation in processInputs,
-  // we pre-load the first key-value pair from the map (to make hasNext idempotent).
-  if (!sortBased) {
-    // First, set aggregationBufferMapIterator.
-    aggregationBufferMapIterator = hashMap.iterator()
-    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
-    mapIteratorHasNext = aggregationBufferMapIterator.next()
-    // If the map is empty, we just free it.
-    if (!mapIteratorHasNext) {
-      hashMap.free()
+    // If we did not switch to sort-based aggregation in processInputs,
+    // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+    if (!sortBased) {
+      // First, set aggregationBufferMapIterator.
+      aggregationBufferMapIterator = hashMap.iterator()
+      // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+      mapIteratorHasNext = aggregationBufferMapIterator.next()
+      // If the map is empty, we just free it.
+      if (!mapIteratorHasNext) {
+        hashMap.free()
+      }
     }
   }
@@ -648,20 +654,20 @@ class TungstenAggregationIterator(
   }
 
   ///////////////////////////////////////////////////////////////////////////
-  // Part 8: A utility function used to generate a output row when there is no
-  // input and there is no grouping expression.
+  // Part 8: Utility functions
   ///////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Generate an output row when there is no input and there is no grouping expression.
+   */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
-    if (groupingExpressions.isEmpty) {
-      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
-      // We create a output row and copy it. So, we can free the map.
-      val resultCopy =
-        generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
-      hashMap.free()
-      resultCopy
-    } else {
-      throw new IllegalStateException(
-        "This method should not be called when groupingExpressions is not empty.")
-    }
+    assert(groupingExpressions.isEmpty)
+    assert(!inputIter.hasNext)
+    generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
+  }
+
+  /** Free memory used in the underlying map. */
+  def free(): Unit = {
+    hashMap.free()
   }
 }
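Annotation: the core idea of patch 1 is a prepare-then-compute split, so that the
aggregation iterator reserves its memory before the parent partition starts producing
rows. The following is an illustrative sketch only -- plain Scala with hypothetical
names, not the internals of the real `MapPartitionsWithPreparationRDD`:

    // Sketch: run the preparation step BEFORE touching the parent iterator.
    // If preparation ran lazily inside executePartition, the operators feeding
    // `parent` could grab all available task memory first.
    object PreparationSketch {
      def mapPartitionsWithPreparation[T, U, P](
          parent: Iterator[T],
          preparePartition: () => P,
          executePartition: (P, Iterator[T]) => Iterator[U]): Iterator[U] = {
        val prepared: P = preparePartition()  // reservation happens here, eagerly
        executePartition(prepared, parent)    // parent is only consumed afterwards
      }

      def main(args: Array[String]): Unit = {
        val out = mapPartitionsWithPreparation[Int, Int, StringBuilder](
          Iterator(1, 2, 3),
          () => new StringBuilder("reserved"),  // stands in for page acquisition
          (buf, it) => it.map(_ + buf.length))
        println(out.toList)  // List(9, 10, 11)
      }
    }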
From 654965445a2fee8cde8c81c0feb3f63ec59abe64 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Fri, 7 Aug 2015 13:21:48 -0700
Subject: [PATCH 2/8] Actually request the memory in constructor + add tests

---
 .../spark/unsafe/map/BytesToBytesMap.java     | 34 ++++++++----
 .../map/AbstractBytesToBytesMapSuite.java     | 11 +++-
 .../UnsafeFixedWidthAggregationMap.java       |  7 +++
 .../TungstenAggregationIterator.scala         |  2 +-
 .../TungstenAggregationIteratorSuite.scala    | 52 +++++++++++++++++++
 5 files changed, 95 insertions(+), 11 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala

diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 0636ae7c8df1..bd5524eca478 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -191,6 +191,11 @@ public BytesToBytesMap(
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
+
+    // Acquire a new page as soon as we construct the map to ensure that we have at least
+    // one page to work with. Otherwise, other operators in the same task may starve this
+    // map (SPARK-9747).
+    acquireNewPage();
   }
 
   public BytesToBytesMap(
@@ -565,16 +570,9 @@ public boolean putNewKey(
         final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
         PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
       }
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-      if (memoryGranted != pageSizeBytes) {
-        shuffleMemoryManager.release(memoryGranted);
-        logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+      if (!acquireNewPage()) {
         return false;
       }
-      MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-      dataPages.add(newPage);
-      pageCursor = 0;
-      currentDataPage = newPage;
       dataPage = currentDataPage;
       dataPageBaseObject = currentDataPage.getBaseObject();
       dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -633,6 +631,24 @@ public boolean putNewKey(
     }
   }
 
+  /**
+   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   * @return whether there is enough space to allocate the new page.
+   */
+  private boolean acquireNewPage() {
+    final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+    if (memoryGranted != pageSizeBytes) {
+      shuffleMemoryManager.release(memoryGranted);
+      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+      return false;
+    }
+    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    dataPages.add(newPage);
+    pageCursor = 0;
+    currentDataPage = newPage;
+    return true;
+  }
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
@@ -722,7 +738,7 @@ public long getNumHashCollisions() {
   }
 
   @VisibleForTesting
-  int getNumDataPages() {
+  public int getNumDataPages() {
     return dataPages.size();
   }
 
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 0b11562980b8..e9443832a346 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -549,7 +549,7 @@ public void testTotalMemoryConsumption() {
           PlatformDependent.LONG_ARRAY_OFFSET,
           8);
         newMemory = map.getTotalMemoryConsumption();
-        if (i % numRecordsPerPage == 0) {
+        if (i % numRecordsPerPage == 0 && i > 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousMemory + pageSizeBytes, newMemory);
         } else {
@@ -561,4 +561,13 @@ public void testTotalMemoryConsumption() {
       map.free();
     }
   }
+
+  @Test
+  public void testAcquirePageInConstructor() {
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    assertEquals(1, map.getNumDataPages());
+    map.free();
+  }
+
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index efb33530dac8..4aaf7837396f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,6 +19,8 @@
 
 import java.io.IOException;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.SparkEnv;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -217,6 +219,11 @@ public long getMemoryUsage() {
     return map.getTotalMemoryConsumption();
   }
 
+  @VisibleForTesting
+  public int getNumDataPages() {
+    return map.getNumDataPages();
+  }
+
   /**
    * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 4bf9942edb6f..559a6b292329 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -335,7 +335,7 @@ class TungstenAggregationIterator(
   // This is the hash map used for hash-based aggregation. It is backed by an
   // UnsafeFixedWidthAggregationMap and it is used to store
   // all groups and their corresponding aggregation buffers for hash-based aggregation.
-  private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
+  private[aggregate] val hashMap = new UnsafeFixedWidthAggregationMap(
     initialAggregationBuffer,
     StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
new file mode 100644
index 000000000000..039e5e780492
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
+import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
+
+class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("memory acquired on construction") {
+    // Needed for various things in SparkEnv
+    sc = new SparkContext("local", "testing")
+    val taskMemoryManager = new TaskMemoryManager(sc.env.executorMemoryManager)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
+    TaskContext.setTaskContext(taskContext)
+
+    // Assert that a page is allocated before processing starts
+    var iter: TungstenAggregationIterator = null
+    try {
+      val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
+        () => new InterpretedMutableProjection(expr, schema)
+      }
+      iter = new TungstenAggregationIterator(
+        Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None)
+      val numPages = iter.hashMap.getNumDataPages
+      assert(numPages === 1)
+    } finally {
+      // Clean up
+      if (iter != null) {
+        iter.free()
+      }
+      TaskContext.unset()
+    }
+  }
+}
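Annotation: the key idiom in patch 2 is reserve-on-construction. `tryToAcquire` may
grant fewer bytes than requested, and a partial grant must be handed back before
reporting failure. A minimal self-contained sketch of that idiom, with a toy memory
manager standing in for the real ShuffleMemoryManager (all names hypothetical):

    // Toy memory manager: hands out bytes from a fixed budget, possibly partially.
    class SimpleMemoryManager(private var available: Long) {
      def tryToAcquire(numBytes: Long): Long = {
        val granted = math.min(numBytes, available)  // may be a partial grant
        available -= granted
        granted
      }
      def release(numBytes: Long): Unit = { available += numBytes }
    }

    class PageReservingMap(mm: SimpleMemoryManager, pageSizeBytes: Long) {
      private var dataPages: List[Long] = Nil

      // Reserve the first page eagerly so other operators in the same task
      // cannot starve this map later (mirrors the constructor change above).
      require(acquireNewPage(), s"could not reserve initial $pageSizeBytes bytes")

      private def acquireNewPage(): Boolean = {
        val granted = mm.tryToAcquire(pageSizeBytes)
        if (granted != pageSizeBytes) {
          mm.release(granted)  // never keep a partial grant
          false
        } else {
          dataPages = granted :: dataPages
          true
        }
      }
    }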
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.unsafe.memory.TaskMemoryManager +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection + +class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkContext { + + test("memory acquired on construction") { + // Needed for various things in SparkEnv + sc = new SparkContext("local", "testing") + val taskMemoryManager = new TaskMemoryManager(sc.env.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + iter = new TungstenAggregationIterator( + Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None) + val numPages = iter.hashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} From ca1b44c01fd02d02a56a8d25d6a1fd40cbc8343d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 7 Aug 2015 13:43:35 -0700 Subject: [PATCH 3/8] Minor: Update comment --- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 5a9854ffd5f6..e141950135ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -102,8 +102,8 @@ case class TungstenAggregate( } } - // Note: we need to set up the external sorter in each partition before computing - // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). val parentPartition = child.execute().asInstanceOf[RDD[UnsafeRow]] val resultRdd = { new MapPartitionsWithPreparationRDD[UnsafeRow, UnsafeRow, TungstenAggregationIterator]( From 4d416d07a6139f1c84e81db8b37a8c1fd9354856 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 10 Aug 2015 13:15:44 -0700 Subject: [PATCH 4/8] Do not acquire a page if creating sorter destructively In TungstenAggregate, we fall back to sort-based aggregation if the hash-based approach cannot request more memory. To do this, we create a new sorter from an existing unsafe map destructively. Because this is largely in place, we don't need to reserve a page in the sorter's constructor. 
---
 .../collection/unsafe/sort/UnsafeExternalSorter.java | 9 ++++-----
 .../aggregate/TungstenAggregationIterator.scala      | 4 ++++
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 5ebbf9b068fd..a16e699be481 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -132,16 +132,15 @@ private UnsafeExternalSorter(
 
     if (existingInMemorySorter == null) {
       initializeForWriting();
+      // Acquire a new page as soon as we construct the sorter to ensure that we have at
+      // least one page to work with. Otherwise, other operators in the same task may starve
+      // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
+      acquireNewPage();
     } else {
       this.isInMemSorterExternal = true;
      this.inMemSorter = existingInMemorySorter;
     }
 
-    // Acquire a new page as soon as we construct the sorter to ensure that we have at
-    // least one page to work with. Otherwise, other operators in the same task may starve
-    // this sorter (SPARK-9709).
-    acquireNewPage();
-
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 25acf573b5c2..cb36c85a60ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -426,6 +426,10 @@ class TungstenAggregationIterator(
       case _ => false
     }
 
+    // Note: we spill the sorter's contents immediately after creating it. Therefore, we must
+    // insert something into the sorter here to ensure that we acquire at least a page of memory.
+    // Otherwise, children operators may steal the window of opportunity and starve our sorter.
+
     if (needsProcess) {
       // First, we create a buffer.
       val buffer = createNewAggregationBuffer()
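Annotation: patch 4 distinguishes two construction paths for the external sorter.
The sketch below is illustrative only (hypothetical names): a fresh sorter reserves
its first page eagerly, while a sorter created destructively from an existing map
adopts that map's pages in place and must not reserve another one.

    object SorterConstructionSketch {
      // Returns (pages owned by the sorter, remaining memory budget).
      def construct(
          existingPages: Option[List[Array[Byte]]],
          budget: Long,
          pageSize: Int): (List[Array[Byte]], Long) = {
        existingPages match {
          case Some(pages) =>
            // Destructive creation: in-place takeover, no new reservation.
            (pages, budget)
          case None =>
            // Fresh creation: reserve the first page immediately, failing
            // fast if the budget cannot cover it.
            require(budget >= pageSize, "cannot reserve the initial page")
            (List(new Array[Byte](pageSize)), budget - pageSize)
        }
      }
    }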
From b4d3633b256de6d981ef7fd2e62afa5490323682 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Mon, 10 Aug 2015 14:31:51 -0700
Subject: [PATCH 5/8] Address comments

---
 .../sql/execution/aggregate/TungstenAggregate.scala  |  7 +++----
 .../aggregate/TungstenAggregationIterator.scala      | 17 ++++++++++++-----
 .../TungstenAggregationIteratorSuite.scala           |  2 +-
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 305d0b08488f..c79289b0012e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -90,10 +90,10 @@ case class TungstenAggregate(
         // We're not using the underlying map, so we can just free it here
         aggregationIterator.free()
         if (groupingExpressions.isEmpty) {
-          // This is a grouped aggregate and the input iterator is empty,
-          // so return an empty iterator.
           Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
           Iterator[UnsafeRow]()
         }
       } else {
@@ -104,10 +104,9 @@ case class TungstenAggregate(
 
     // Note: we need to set up the iterator in each partition before computing the
     // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
-    val parentPartition = child.execute()
     val resultRdd = {
       new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
-        parentPartition, preparePartition, executePartition, preservesPartitioning = true)
+        child.execute(), preparePartition, executePartition, preservesPartitioning = true)
     }
     resultRdd.asInstanceOf[RDD[InternalRow]]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index cb36c85a60ef..dd0191c6540b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -84,7 +84,7 @@ class TungstenAggregationIterator(
   extends Iterator[UnsafeRow] with Logging {
 
   // The parent partition iterator, to be initialized later in `start`
-  private[this] var inputIter: Iterator[InternalRow] = Iterator[InternalRow]()
+  private[this] var inputIter: Iterator[InternalRow] = null
 
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
@@ -334,7 +334,7 @@ class TungstenAggregationIterator(
   // This is the hash map used for hash-based aggregation. It is backed by an
   // UnsafeFixedWidthAggregationMap and it is used to store
   // all groups and their corresponding aggregation buffers for hash-based aggregation.
-  private[aggregate] val hashMap = new UnsafeFixedWidthAggregationMap(
+  private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
     initialAggregationBuffer,
     StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
@@ -345,11 +345,15 @@ class TungstenAggregationIterator(
     false // disable tracking of performance metrics
   )
 
+  // Exposed for testing
+  private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
   // The function used to read and process input rows. When processing input rows,
   // it first uses hash-based aggregation by putting groups and their buffers in
   // hashMap. If we could not allocate more memory for the map, we switch to
   // sort-based aggregation (by calling switchToSortBasedAggregation).
   private def processInputs(): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
       val groupingKey = groupProjection.apply(newInput)
@@ -368,6 +372,7 @@ class TungstenAggregationIterator(
   // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
   // been processed.
   private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     var i = 0
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
@@ -407,6 +412,7 @@ class TungstenAggregationIterator(
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was null")
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -426,8 +432,9 @@ class TungstenAggregationIterator(
       case _ => false
     }
 
-    // Note: we spill the sorter's contents immediately after creating it. Therefore, we must
-    // insert something into the sorter here to ensure that we acquire at least a page of memory.
+    // Note: Since we spill the sorter's contents immediately after creating it, we must insert
+    // something into the sorter here to ensure that we acquire at least a page of memory.
+    // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
     // Otherwise, children operators may steal the window of opportunity and starve our sorter.
 
     if (needsProcess) {
@@ -684,7 +691,7 @@ class TungstenAggregationIterator(
    */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     assert(groupingExpressions.isEmpty)
-    assert(!inputIter.hasNext)
+    assert(inputIter == null)
     generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 039e5e780492..9b38c2fa58a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkCont
       }
       iter = new TungstenAggregationIterator(
        Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None)
-      val numPages = iter.hashMap.getNumDataPages
+      val numPages = iter.getHashMap.getNumDataPages
       assert(numPages === 1)
     } finally {
       // Clean up
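Annotation: after patch 5 the iterator has a strict lifecycle: construction reserves
memory; afterwards exactly one of `free()` (no input) or `start(parent)` (input
present) is called, and the empty-grouping-key output is only legal before `start()`.
A minimal illustrative stub of that contract (hypothetical names, not the real class):

    class LifecycleStub[T](emptyResult: T) extends Iterator[T] {
      private var input: Iterator[T] = null  // mirrors inputIter above
      private var freed = false

      def start(parent: Iterator[T]): Unit = {
        assert(!freed && input == null, "start() must be the one and only transition")
        input = parent
      }
      def free(): Unit = { freed = true }  // give reserved memory back unused
      def outputForEmptyInput(): T = {
        assert(input == null, "only valid before start()")
        emptyResult
      }
      override def hasNext: Boolean = input != null && input.hasNext
      override def next(): T = input.next()
    }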
From b10a4f3904bd1a67c0c5ae07e11e21464c6a0268 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 11 Aug 2015 10:10:40 -0700
Subject: [PATCH 6/8] Fix test

---
 .../aggregate/TungstenAggregationIteratorSuite.scala | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 9b38c2fa58a9..8c297870d5f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -19,15 +19,17 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.unsafe.memory.TaskMemoryManager
 
-class TungstenAggregationIteratorSuite extends SparkFunSuite with LocalSparkContext {
+class TungstenAggregationIteratorSuite extends SparkFunSuite {
 
   test("memory acquired on construction") {
-    // Needed for various things in SparkEnv
-    sc = new SparkContext("local", "testing")
-    val taskMemoryManager = new TaskMemoryManager(sc.env.executorMemoryManager)
+    // set up environment
+    TestSQLContext
+
+    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
     val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
     TaskContext.setTaskContext(taskContext)

From d4dc9cae8608622f25948aac51a29953cd25b09f Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 11 Aug 2015 17:47:57 -0700
Subject: [PATCH 7/8] Fix tests

---
 .../aggregate/TungstenAggregationIteratorSuite.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 8c297870d5f8..1c190e92cdd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
+import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 
@@ -39,8 +40,9 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite {
       val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
         () => new InterpretedMutableProjection(expr, schema)
       }
-      iter = new TungstenAggregationIterator(
-        Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None)
+      val dummyAccum = new LongSQLMetric("dummy")
+      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
+        Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
       val numPages = iter.getHashMap.getNumDataPages
       assert(numPages === 1)
     } finally {

From 19f2e1b9e03a385478d8e91e41f93b1bc8e8c909 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 11 Aug 2015 23:12:46 -0700
Subject: [PATCH 8/8] Fix tests again

---
 .../aggregate/TungstenAggregationIteratorSuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 1c190e92cdd2..ac22c2f3c0a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 
@@ -28,7 +28,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite {
 
   test("memory acquired on construction") {
     // set up environment
-    TestSQLContext
+    val ctx = TestSQLContext
 
     val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
     val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
@@ -40,7 +40,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite {
       val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
         () => new InterpretedMutableProjection(expr, schema)
       }
-      val dummyAccum = new LongSQLMetric("dummy")
+      val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
       iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
         Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
       val numPages = iter.getHashMap.getNumDataPages