[SPARK-35141][SQL] Support two level of hash maps for final hash aggregation #32242
Changes from all commits
539720d
c9d09be
917e7bb
67d4cd7
965a35c
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._
 import scala.collection.mutable

 import org.apache.spark.TaskContext
-import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.memory.SparkOutOfMemoryError
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -435,8 +435,8 @@ case class HashAggregateExec(
     )
   }

-  def getTaskMemoryManager(): TaskMemoryManager = {
-    TaskContext.get().taskMemoryManager()
+  def getTaskContext(): TaskContext = {
+    TaskContext.get()
   }

   def getEmptyAggregationBuffer(): InternalRow = {
@@ -647,7 +647,7 @@ case class HashAggregateExec(
     (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
       f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
       f.dataType.isInstanceOf[CalendarIntervalType]) &&
-      bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
+      bufferSchema.nonEmpty

     // For vectorized hash map, We do not support byte array based decimal type for aggregate values
     // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
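Note (illustration, not part of the diff): the removed `modes.forall(...)` clause is what previously restricted the fast hash map to partial-mode aggregation. A plain group-by plans into a partial and a final `HashAggregateExec`, and before this PR only the partial node could use the fast hash map. A sketch, assuming an active SparkSession `spark`; the plan text in the comment is only indicative and varies by version:

```scala
// Typical two-stage aggregation plan (shape only):
//   HashAggregate(keys=[k], functions=[count(1)], mode=Final)
//   +- Exchange hashpartitioning(k, 200)
//      +- HashAggregate(keys=[k], functions=[partial_count(1)], mode=Partial)
spark.range(1000000)
  .selectExpr("id % 100 AS k")
  .groupBy("k")
  .count()
  .explain()
```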
@@ -663,7 +663,7 @@ case class HashAggregateExec(
   private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
     if (!checkIfFastHashMapSupported(ctx)) {
-      if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
Contributor (cloud-fan):
Last question: can we search the commit history and figure out why we didn't enable the fast hash map in the final aggregate? It seems we did it on purpose.
Contributor (PR author):
@cloud-fan - I was wondering about this before making this PR as well. The decision to only support partial aggregate was made when the first-level hash map was introduced (#12345 and #14176), and it was never changed afterwards. I checked with @sameeragarwal before making this PR, and he told me there is no fundamental reason not to support final aggregate. Just for documentation: I also asked him why we don't support nested types (array/map/struct) as key types for the fast hash map. The reason is that such keys can be very large for long arrays/maps/structs, so the fast hash map may no longer fit in cache and would lose its benefit.
+      if (!Utils.isTesting) {
         logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but"
           + " current version of codegened fast hashmap does not support this aggregate.")
       }
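Note (context, not part of the diff): the two-level map is still gated by `SQLConf.ENABLE_TWOLEVEL_AGG_MAP`; the branch above only logs when that gate is on but the aggregate shape is unsupported. A minimal sketch of toggling the gate, assuming an active SparkSession `spark`:

```scala
import org.apache.spark.sql.internal.SQLConf

// Explicitly enable the two-level (generated fast map + BytesToBytesMap) aggregation map.
// It is enabled by default in recent releases; shown here only for illustration.
spark.conf.set(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "true")
```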
@@ -683,7 +683,18 @@ case class HashAggregateExec(
     } else if (sqlContext.conf.enableVectorizedHashMap) {
       logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
     }
-    val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+    val bitMaxCapacity = testFallbackStartsAt match {
+      case Some((fastMapCounter, _)) =>
+        // In testing, with fall back counter of fast hash map (`fastMapCounter`), set the max bit
+        // of map to be no more than log2(`fastMapCounter`). This helps control the number of keys
+        // in map to mimic fall back.
+        if (fastMapCounter <= 1) {
+          0
+        } else {
+          (math.log10(fastMapCounter) / math.log10(2)).floor.toInt
+        }
+      case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+    }

     val thisPlan = ctx.addReferenceObj("plan", this)
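Note (illustration, not part of the diff): the fast map holds at most 2^bitMaxCapacity entries, so deriving the bit count as floor(log2(fastMapCounter)) keeps the fast map small enough that additional keys overflow into the regular map, which is how the old per-row fallback counter is now mimicked in tests. A minimal sketch of the same computation; the helper name is hypothetical:

```scala
// floor(log2(n)) computed via log10, mirroring the match arm above
def bitMaxCapacityForTest(fastMapCounter: Int): Int =
  if (fastMapCounter <= 1) 0
  else (math.log10(fastMapCounter) / math.log10(2)).floor.toInt

// fastMapCounter = 2   -> 1 bit  (2 slots)
// fastMapCounter = 5   -> 2 bits (4 slots)
// fastMapCounter = 100 -> 6 bits (64 slots)
```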
@@ -717,11 +728,28 @@ case class HashAggregateExec(
         "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
         "fastHashMapIter", forceInline = true)
       val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
-        s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());"
+        s"$thisPlan.getTaskContext().taskMemoryManager(), " +
+        s"$thisPlan.getEmptyAggregationBuffer());"
       (iter, create)
     }
   } else ("", "")

+  // Generates the code to register a cleanup task with TaskContext to ensure that memory
+  // is guaranteed to be freed at the end of the task. This is necessary to avoid memory
+  // leaks when the downstream operator does not fully consume the aggregation map's
+  // output (e.g. aggregate followed by limit).
+  val addHookToCloseFastHashMap = if (isFastHashMapEnabled) {
+    s"""
+       |$thisPlan.getTaskContext().addTaskCompletionListener(
+       |  new org.apache.spark.util.TaskCompletionListener() {
+       |    @Override
+       |    public void onTaskCompletion(org.apache.spark.TaskContext context) {
+       |      $fastHashMapTerm.close();
+       |    }
+       |});
+     """.stripMargin
+  } else ""
+
   // Create a name for the iterator from the regular hash map.
   // Inline mutable state since not many aggregation operations in a task
   val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
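Note (illustration, not part of the diff): the generated Java above registers an anonymous `TaskCompletionListener` so the fast map's memory is freed even if the output loops below exit early. A minimal Scala sketch of the same pattern; `FastMapLike` and `openFastMap` are hypothetical stand-ins for the generated fast hash map class:

```scala
import org.apache.spark.TaskContext

class FastMapLike {
  def close(): Unit = println("fast hash map memory freed")
}

// Must be called inside a running task, where TaskContext.get() is non-null.
def openFastMap(): FastMapLike = {
  val map = new FastMapLike
  // Same idea as the generated hook: release the map when the task completes,
  // even if a downstream LIMIT stops consuming rows before the map is drained.
  TaskContext.get().addTaskCompletionListener[Unit](_ => map.close())
  map
}
```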
@@ -761,6 +789,8 @@ case class HashAggregateExec(
     val bufferTerm = ctx.freshName("aggBuffer")
     val outputFunc = generateResultFunction(ctx)

+    val limitNotReachedCondition = limitNotReachedCond
Contributor (PR author):
Adding the limit early termination for the first-level map as well. This is needed to fix a test failure.

     def outputFromFastHashMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -773,7 +803,7 @@ case class HashAggregateExec(
     def outputFromRowBasedMap: String = {
       s"""
-         |while ($iterTermForFastHashMap.next()) {
+         |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -798,7 +828,7 @@ case class HashAggregateExec(
           BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
         })
       s"""
-         |while ($iterTermForFastHashMap.hasNext()) {
+         |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
          |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
@@ -813,7 +843,7 @@ case class HashAggregateExec(
     def outputFromRegularHashMap: String = {
       s"""
-         |while ($limitNotReachedCond $iterTerm.next()) {
+         |while ($limitNotReachedCondition $iterTerm.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
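Note (context, not part of the diff): `limitNotReachedCond` is the existing `CodegenSupport` helper; it renders to an empty string when no LIMIT sits above the aggregate in the same whole-stage-codegen subtree, and otherwise to a counter check (roughly `count_x < n && `), so all three output loops above stop draining their maps once the limit is reached. A sketch of a query shape that benefits, assuming an active SparkSession `spark`:

```scala
// The final aggregate feeds a LIMIT, so the generated
// "while (<limit check> && iter.next())" loops can terminate early
// instead of materializing every group.
spark.range(1000000)
  .selectExpr("id % 100000 AS k")
  .groupBy("k")
  .count()
  .limit(5)
  .show()
```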
@@ -832,6 +862,7 @@ case class HashAggregateExec(
        |if (!$initAgg) {
        |  $initAgg = true;
        |  $createFastHashMap
+       |  $addHookToCloseFastHashMap
        |  $hashMapTerm = $thisPlan.createHashMap();
        |  long $beforeAgg = System.nanoTime();
        |  $doAggFuncName();
@@ -866,13 +897,11 @@ case class HashAggregateExec(
       }
     }

-    val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
-      incCounter) = if (testFallbackStartsAt.isDefined) {
-      val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
-      (s"$countTerm < ${testFallbackStartsAt.get._1}",
-        s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
-    } else {
-      ("true", "true", "", "")
+    val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
+      case Some((_, regularMapCounter)) =>
+        val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
+        (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
+      case _ => ("true", "", "")
     }

     val oomeClassName = classOf[SparkOutOfMemoryError].getName
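Note (context, not part of the diff): `testFallbackStartsAt` is the pair of counters used only in tests; after this change the first counter caps the fast hash map's capacity (see `bitMaxCapacity` above) while the second still drives the `BytesToBytesMap` fallback to sort-based aggregation. A hedged sketch of how a test might exercise both paths; the config key is the one used by the existing controlled-fallback aggregation suites, and table `t` is hypothetical:

```scala
// Cap the fast hash map at ~2 keys and fall back from the regular
// BytesToBytesMap after 3 keys, exercising all aggregation paths in one query.
spark.conf.set("spark.sql.TungstenAggregate.testFallbackStartsAt", "2, 3")
spark.sql("SELECT k, count(*) FROM t GROUP BY k").collect()
```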
@@ -912,12 +941,10 @@ case class HashAggregateExec(
       // If fast hash map is on, we first generate code to probe and update the fast hash map.
       // If the probe is successful the corresponding fast row buffer will hold the mutable row.
       s"""
-         |if ($checkFallbackForGeneratedHashMap) {
-         |  ${fastRowKeys.map(_.code).mkString("\n")}
-         |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
-         |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
-         |      ${fastRowKeys.map(_.value).mkString(", ")});
-         |  }
+         |${fastRowKeys.map(_.code).mkString("\n")}
+         |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+         |  $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+         |    ${fastRowKeys.map(_.value).mkString(", ")});
+         |}
          |// Cannot find the key in fast hash map, try regular hash map.
          |if ($fastRowBuffer == null) {
Contributor (PR author):
This change is needed because the hash aggregation output order changes, which caused the ExpressionInfoSuite test "check outputs of expression examples" to fail in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true.