From 539720dda50c8fb277c9574867f9df7b4e7937d9 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Mon, 19 Apr 2021 16:24:04 -0700 Subject: [PATCH 1/5] Support two level hash map for final hash aggregation --- .../aggregate/HashAggregateExec.scala | 51 +++++++++++++++++-- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6e23a2844d14..648142e2bf0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -128,6 +128,16 @@ case class HashAggregateExec( // all the mode of aggregate expressions private val modes = aggregateExpressions.map(_.mode).distinct + // This is for testing final aggregate with number-of-rows-based fall back as specified in + // `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and + // regular hash map. So the aggregation buffers from both maps need to be merged together + // to avoid correctness issue. + // + // This scenario only happens in unit test with number-of-rows-based fall back. + // There should not be same keys in both maps with size-based fall back in production. + private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined && + (modes.contains(Final) || modes.contains(Complete)) + override def usedInputs: AttributeSet = inputSet override def supportCodegen: Boolean = { @@ -537,6 +547,34 @@ case class HashAggregateExec( } } + /** + * Called by generated Java class to finish merge the fast hash map into regular map. + * This is used for testing final aggregate only. 
+ */ + def mergeFastHashMapForTest( + fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow], + regularHashMap: UnsafeFixedWidthAggregationMap): Unit = { + + // Create a MutableProjection to merge the buffers of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = MutableProjection.create( + mergeExpr, + aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes)) + val joinedRow = new JoinedRow() + + while (fastHashMapRowIter.next()) { + val key = fastHashMapRowIter.getKey + val fastMapBuffer = fastHashMapRowIter.getValue + val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key) + + // Merge the aggregation buffer of fast hash map, into the buffer with same key of + // regular map + mergeProjection.target(regularMapBuffer) + mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer)) + } + fastHashMapRowIter.close() + } + /** * Generate the code for output. * @return function name for the result code. 
@@ -647,7 +685,7 @@ case class HashAggregateExec( (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] || f.dataType.isInstanceOf[CalendarIntervalType]) && - bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) + bufferSchema.nonEmpty // For vectorized hash map, We do not support byte array based decimal type for aggregate values // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place @@ -663,7 +701,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { if (!checkIfFastHashMapSupported(ctx)) { - if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { + if (!Utils.isTesting) { logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } @@ -740,8 +778,13 @@ case class HashAggregateExec( val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);" val finishHashMap = if (isFastHashMapEnabled) { + val finishFastHashMap = if (isTestFinalAggregateWithFallback) { + s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);" + } else { + s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();" + } s""" - |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); + |$finishFastHashMap |$finishRegularHashMap """.stripMargin } else { @@ -762,7 +805,7 @@ case class HashAggregateExec( val outputFunc = generateResultFunction(ctx) def outputFromFastHashMap: String = { - if (isFastHashMapEnabled) { + if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) { if (isVectorizedHashMapEnabled) { outputFromVectorizedMap } else { From c9d09be9ad22ddfe764195ce5e52a92dbda7df31 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: 
Tue, 20 Apr 2021 00:10:16 -0700 Subject: [PATCH 2/5] Add limit early stop, listener to clean up resource, and fix doc --- .../sql/catalyst/expressions/grouping.scala | 2 +- .../execution/aggregate/HashAggregateExec.scala | 16 +++++++++------- .../aggregate/RowBasedHashMapGenerator.scala | 16 ++++++++++++++-- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 808c8222d293..4d14b5e7927f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable Examples: > SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height); Alice 0 2 165.0 + Bob 0 5 180.0 Alice 1 2 165.0 NULL 3 7 172.5 - Bob 0 5 180.0 Bob 1 5 180.0 NULL 2 2 165.0 NULL 2 5 180.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 648142e2bf0b..b1905dffdfda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._ import scala.collection.mutable import org.apache.spark.TaskContext -import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} +import org.apache.spark.memory.SparkOutOfMemoryError import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -445,8 +445,8 @@ case class 
HashAggregateExec( ) } - def getTaskMemoryManager(): TaskMemoryManager = { - TaskContext.get().taskMemoryManager() + def getTaskContext(): TaskContext = { + TaskContext.get() } def getEmptyAggregationBuffer(): InternalRow = { @@ -755,7 +755,7 @@ case class HashAggregateExec( "org.apache.spark.unsafe.KVIterator", "fastHashMapIter", forceInline = true) val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());" + s"$thisPlan.getTaskContext(), $thisPlan.getEmptyAggregationBuffer());" (iter, create) } } else ("", "") @@ -804,6 +804,8 @@ case class HashAggregateExec( val bufferTerm = ctx.freshName("aggBuffer") val outputFunc = generateResultFunction(ctx) + val limitNotReachedCondition = limitNotReachedCond + def outputFromFastHashMap: String = { if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) { if (isVectorizedHashMapEnabled) { @@ -816,7 +818,7 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" - |while ($iterTermForFastHashMap.next()) { + |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); | $outputFunc($keyTerm, $bufferTerm); @@ -841,7 +843,7 @@ case class HashAggregateExec( BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) s""" - |while ($iterTermForFastHashMap.hasNext()) { + |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) { | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} @@ -856,7 +858,7 @@ case class HashAggregateExec( def outputFromRegularHashMap: String = { s""" - |while ($limitNotReachedCond $iterTerm.next()) { + |while ($limitNotReachedCondition $iterTerm.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) 
$iterTerm.getValue(); | $outputFunc($keyTerm, $bufferTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 44d19ad60d49..58fcbc7c3d0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -70,10 +70,10 @@ class RowBasedHashMapGenerator( | | | public $generatedClassName( - | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, + | org.apache.spark.TaskContext taskContext, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); + | .allocate($keySchema, $valueSchema, taskContext.taskMemoryManager(), capacity); | | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); @@ -87,6 +87,18 @@ class RowBasedHashMapGenerator( | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); + | + | // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be + | // freed at the end of the task. This is necessary to avoid memory leaks in when the + | // downstream operator does not fully consume the aggregation map's output + | // (e.g. aggregate followed by limit). 
+ | taskContext.addTaskCompletionListener( + | new org.apache.spark.util.TaskCompletionListener() { + | @Override + | public void onTaskCompletion(org.apache.spark.TaskContext context) { + | close(); + | } + | }); | } """.stripMargin } From 917e7bb8c958fbcf0cbc908429a365ae8eb6aa94 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 20 Apr 2021 17:43:13 -0700 Subject: [PATCH 3/5] Move fast hash map cleanup logic to HashAggregateExec --- .../aggregate/HashAggregateExec.scala | 88 ++++++++++++------- .../aggregate/RowBasedHashMapGenerator.scala | 16 +--- 2 files changed, 56 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b1905dffdfda..fa57bc134835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -725,40 +725,59 @@ case class HashAggregateExec( val thisPlan = ctx.addReferenceObj("plan", this) - // Create a name for the iterator from the fast hash map, and the code to create fast hash map. - val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) { - // Generates the fast hash map class and creates the fast hash map term. 
- val fastHashMapClassName = ctx.freshName("FastHashMap") - if (isVectorizedHashMapEnabled) { - val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "java.util.Iterator", - "vectorizedFastHashMapIter", - forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName();" - (iter, create) - } else { - val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "fastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "org.apache.spark.unsafe.KVIterator", - "fastHashMapIter", forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskContext(), $thisPlan.getEmptyAggregationBuffer());" - (iter, create) - } - } else ("", "") + // Create a name for the iterator from the fast hash map, the code to create + // and add hook to close fast hash map. + val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) = + if (isFastHashMapEnabled) { + // Generates the fast hash map class and creates the fast hash map term. 
+ val fastHashMapClassName = ctx.freshName("FastHashMap") + val (iter, create) = if (isVectorizedHashMapEnabled) { + val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "java.util.Iterator", + "vectorizedFastHashMapIter", + forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName();" + (iter, create) + } else { + val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "fastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "org.apache.spark.unsafe.KVIterator", + "fastHashMapIter", forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + + s"$thisPlan.getTaskContext().taskMemoryManager(), " + + s"$thisPlan.getEmptyAggregationBuffer());" + (iter, create) + } + + // Generates the code to register a cleanup task with TaskContext to ensure that memory + // is guaranteed to be freed at the end of the task. This is necessary to avoid memory + // leaks in when the downstream operator does not fully consume the aggregation map's + // output (e.g. aggregate followed by limit). 
+ val hookToCloseFastHashMap = + s""" + |$thisPlan.getTaskContext().addTaskCompletionListener( + | new org.apache.spark.util.TaskCompletionListener() { + | @Override + | public void onTaskCompletion(org.apache.spark.TaskContext context) { + | $fastHashMapTerm.close(); + | } + |}); + """.stripMargin + (iter, create, hookToCloseFastHashMap) + } else ("", "", "") // Create a name for the iterator from the regular hash map. // Inline mutable state since not many aggregation operations in a task @@ -877,6 +896,7 @@ case class HashAggregateExec( |if (!$initAgg) { | $initAgg = true; | $createFastHashMap + | $addHookToCloseFastHashMap | $hashMapTerm = $thisPlan.createHashMap(); | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 58fcbc7c3d0b..44d19ad60d49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -70,10 +70,10 @@ class RowBasedHashMapGenerator( | | | public $generatedClassName( - | org.apache.spark.TaskContext taskContext, + | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | .allocate($keySchema, $valueSchema, taskContext.taskMemoryManager(), capacity); + | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); | | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); @@ -87,18 +87,6 @@ class RowBasedHashMapGenerator( | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); - | - | // Register a cleanup task with 
TaskContext to ensure that memory is guaranteed to be - | // freed at the end of the task. This is necessary to avoid memory leaks in when the - | // downstream operator does not fully consume the aggregation map's output - | // (e.g. aggregate followed by limit). - | taskContext.addTaskCompletionListener( - | new org.apache.spark.util.TaskCompletionListener() { - | @Override - | public void onTaskCompletion(org.apache.spark.TaskContext context) { - | close(); - | } - | }); | } """.stripMargin } From 67d4cd76c4de27b69acb8edc1ea37972a2de67aa Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Apr 2021 22:19:40 -0700 Subject: [PATCH 4/5] Change number-of-rows fallback behavior to restrict fast map capacity --- .../aggregate/HashAggregateExec.scala | 82 ++++++------------- 1 file changed, 23 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index fa57bc134835..de014aada620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -128,16 +128,6 @@ case class HashAggregateExec( // all the mode of aggregate expressions private val modes = aggregateExpressions.map(_.mode).distinct - // This is for testing final aggregate with number-of-rows-based fall back as specified in - // `testFallbackStartsAt`. In this scenario, there might be same keys exist in both fast and - // regular hash map. So the aggregation buffers from both maps need to be merged together - // to avoid correctness issue. - // - // This scenario only happens in unit test with number-of-rows-based fall back. - // There should not be same keys in both maps with size-based fall back in production. 
- private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined && - (modes.contains(Final) || modes.contains(Complete)) - override def usedInputs: AttributeSet = inputSet override def supportCodegen: Boolean = { @@ -547,34 +537,6 @@ case class HashAggregateExec( } } - /** - * Called by generated Java class to finish merge the fast hash map into regular map. - * This is used for testing final aggregate only. - */ - def mergeFastHashMapForTest( - fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow], - regularHashMap: UnsafeFixedWidthAggregationMap): Unit = { - - // Create a MutableProjection to merge the buffers of same key together - val mergeExpr = declFunctions.flatMap(_.mergeExpressions) - val mergeProjection = MutableProjection.create( - mergeExpr, - aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes)) - val joinedRow = new JoinedRow() - - while (fastHashMapRowIter.next()) { - val key = fastHashMapRowIter.getKey - val fastMapBuffer = fastHashMapRowIter.getValue - val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key) - - // Merge the aggregation buffer of fast hash map, into the buffer with same key of - // regular map - mergeProjection.target(regularMapBuffer) - mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer)) - } - fastHashMapRowIter.close() - } - /** * Generate the code for output. * @return function name for the result code. @@ -721,7 +683,18 @@ case class HashAggregateExec( } else if (sqlContext.conf.enableVectorizedHashMap) { logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") } - val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit + val bitMaxCapacity = testFallbackStartsAt match { + case Some((fastMapCounter, _)) => + // In testing, with fall back counter of fast hash map (`fastMapCounter`), set the max bit + // of map to be no more than log2(`fastMapCounter`). 
This helps control the number of keys + // in map to mimic fall back. + if (fastMapCounter <= 1) { + 0 + } else { + (math.log10(fastMapCounter) / math.log10(2)).floor.toInt + } + case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit + } val thisPlan = ctx.addReferenceObj("plan", this) @@ -797,13 +770,8 @@ case class HashAggregateExec( val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);" val finishHashMap = if (isFastHashMapEnabled) { - val finishFastHashMap = if (isTestFinalAggregateWithFallback) { - s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);" - } else { - s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();" - } s""" - |$finishFastHashMap + |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); |$finishRegularHashMap """.stripMargin } else { @@ -826,7 +794,7 @@ case class HashAggregateExec( val limitNotReachedCondition = limitNotReachedCond def outputFromFastHashMap: String = { - if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) { + if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { outputFromVectorizedMap } else { @@ -931,13 +899,11 @@ case class HashAggregateExec( } } - val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, - incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") - (s"$countTerm < ${testFallbackStartsAt.get._1}", - s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") - } else { - ("true", "true", "", "") + val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match { + case Some((_, regularMapCounter)) => + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") + (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;") + case _ => ("true", "", "") } val 
oomeClassName = classOf[SparkOutOfMemoryError].getName @@ -977,12 +943,10 @@ case class HashAggregateExec( // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" - |if ($checkFallbackForGeneratedHashMap) { - | ${fastRowKeys.map(_.code).mkString("\n")} - | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - | } + |${fastRowKeys.map(_.code).mkString("\n")} + |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); |} |// Cannot find the key in fast hash map, try regular hash map. |if ($fastRowBuffer == null) { From 965a35c7329bacad6f46010811aee4693bd04d95 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 22 Apr 2021 00:31:27 -0700 Subject: [PATCH 5/5] Address comments to separate out addHookToCloseFastHashMap --- .../aggregate/HashAggregateExec.scala | 102 +++++++++--------- 1 file changed, 50 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index de014aada620..3c1304e9cdad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -698,59 +698,57 @@ case class HashAggregateExec( val thisPlan = ctx.addReferenceObj("plan", this) - // Create a name for the iterator from the fast hash map, the code to create - // and add hook to close fast hash map. 
- val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) = - if (isFastHashMapEnabled) { - // Generates the fast hash map class and creates the fast hash map term. - val fastHashMapClassName = ctx.freshName("FastHashMap") - val (iter, create) = if (isVectorizedHashMapEnabled) { - val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "java.util.Iterator", - "vectorizedFastHashMapIter", - forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName();" - (iter, create) - } else { - val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "fastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "org.apache.spark.unsafe.KVIterator", - "fastHashMapIter", forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskContext().taskMemoryManager(), " + - s"$thisPlan.getEmptyAggregationBuffer());" - (iter, create) - } + // Create a name for the iterator from the fast hash map, and the code to create fast hash map. + val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) { + // Generates the fast hash map class and creates the fast hash map term. 
+ val fastHashMapClassName = ctx.freshName("FastHashMap") + if (isVectorizedHashMapEnabled) { + val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "java.util.Iterator", + "vectorizedFastHashMapIter", + forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName();" + (iter, create) + } else { + val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "fastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "org.apache.spark.unsafe.KVIterator", + "fastHashMapIter", forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + + s"$thisPlan.getTaskContext().taskMemoryManager(), " + + s"$thisPlan.getEmptyAggregationBuffer());" + (iter, create) + } + } else ("", "") - // Generates the code to register a cleanup task with TaskContext to ensure that memory - // is guaranteed to be freed at the end of the task. This is necessary to avoid memory - // leaks in when the downstream operator does not fully consume the aggregation map's - // output (e.g. aggregate followed by limit). 
- val hookToCloseFastHashMap = - s""" - |$thisPlan.getTaskContext().addTaskCompletionListener( - | new org.apache.spark.util.TaskCompletionListener() { - | @Override - | public void onTaskCompletion(org.apache.spark.TaskContext context) { - | $fastHashMapTerm.close(); - | } - |}); - """.stripMargin - (iter, create, hookToCloseFastHashMap) - } else ("", "", "") + // Generates the code to register a cleanup task with TaskContext to ensure that memory + // is guaranteed to be freed at the end of the task. This is necessary to avoid memory + // leaks when the downstream operator does not fully consume the aggregation map's + // output (e.g. aggregate followed by limit). + val addHookToCloseFastHashMap = if (isFastHashMapEnabled) { + s""" + |$thisPlan.getTaskContext().addTaskCompletionListener( + | new org.apache.spark.util.TaskCompletionListener() { + | @Override + | public void onTaskCompletion(org.apache.spark.TaskContext context) { + | $fastHashMapTerm.close(); + | } + |}); + """.stripMargin + } else "" // Create a name for the iterator from the regular hash map. // Inline mutable state since not many aggregation operations in a task