diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 808c8222d2938..4d14b5e7927fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
     Examples:
       > SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
         Alice  0  2  165.0
+        Bob  0  5  180.0
         Alice  1  2  165.0
         NULL  3  7  172.5
-        Bob  0  5  180.0
         Bob  1  5  180.0
         NULL  2  2  165.0
         NULL  2  5  180.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 6e23a2844d148..3c1304e9cdad8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._
 import scala.collection.mutable
 
 import org.apache.spark.TaskContext
-import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.memory.SparkOutOfMemoryError
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -435,8 +435,8 @@ case class HashAggregateExec(
     )
   }
 
-  def getTaskMemoryManager(): TaskMemoryManager = {
-    TaskContext.get().taskMemoryManager()
+  def getTaskContext(): TaskContext = {
+    TaskContext.get()
   }
 
   def getEmptyAggregationBuffer(): InternalRow = {
@@ -647,7 +647,7 @@ case class HashAggregateExec(
       (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
         f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
         f.dataType.isInstanceOf[CalendarIntervalType]) &&
-      bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
+      bufferSchema.nonEmpty
 
   // For vectorized hash map, We do not support byte array based decimal type for aggregate values
   // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -663,7 +663,7 @@ case class HashAggregateExec(
 
   private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
     if (!checkIfFastHashMapSupported(ctx)) {
-      if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
+      if (!Utils.isTesting) {
         logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" +
           " current version of codegened fast hashmap does not support this aggregate.")
       }
@@ -683,7 +683,18 @@ case class HashAggregateExec(
     } else if (sqlContext.conf.enableVectorizedHashMap) {
       logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
     }
-    val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+    val bitMaxCapacity = testFallbackStartsAt match {
+      case Some((fastMapCounter, _)) =>
+        // In testing, with fall back counter of fast hash map (`fastMapCounter`), set the max bit
+        // of map to be no more than log2(`fastMapCounter`). This helps control the number of keys
+        // in map to mimic fall back.
+        if (fastMapCounter <= 1) {
+          0
+        } else {
+          (math.log10(fastMapCounter) / math.log10(2)).floor.toInt
+        }
+      case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
+    }
 
     val thisPlan = ctx.addReferenceObj("plan", this)
 
@@ -717,11 +728,28 @@ case class HashAggregateExec(
           "org.apache.spark.unsafe.KVIterator",
           "fastHashMapIter", forceInline = true)
         val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
-          s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());"
+          s"$thisPlan.getTaskContext().taskMemoryManager(), " +
+          s"$thisPlan.getEmptyAggregationBuffer());"
         (iter, create)
       }
     } else ("", "")
 
+    // Generates the code to register a cleanup task with TaskContext to ensure that memory
+    // is guaranteed to be freed at the end of the task. This is necessary to avoid memory
+    // leaks when the downstream operator does not fully consume the aggregation map's
+    // output (e.g. aggregate followed by limit).
+    val addHookToCloseFastHashMap = if (isFastHashMapEnabled) {
+      s"""
+         |$thisPlan.getTaskContext().addTaskCompletionListener(
+         | new org.apache.spark.util.TaskCompletionListener() {
+         |   @Override
+         |   public void onTaskCompletion(org.apache.spark.TaskContext context) {
+         |     $fastHashMapTerm.close();
+         |   }
+         |});
+       """.stripMargin
+    } else ""
+
     // Create a name for the iterator from the regular hash map.
     // Inline mutable state since not many aggregation operations in a task
     val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
@@ -761,6 +789,8 @@ case class HashAggregateExec(
     val bufferTerm = ctx.freshName("aggBuffer")
     val outputFunc = generateResultFunction(ctx)
 
+    val limitNotReachedCondition = limitNotReachedCond
+
     def outputFromFastHashMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -773,7 +803,7 @@ case class HashAggregateExec(
 
     def outputFromRowBasedMap: String = {
       s"""
-         |while ($iterTermForFastHashMap.next()) {
+         |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -798,7 +828,7 @@ case class HashAggregateExec(
           BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
         })
       s"""
-         |while ($iterTermForFastHashMap.hasNext()) {
+         |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
          |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
@@ -813,7 +843,7 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-         |while ($limitNotReachedCond $iterTerm.next()) {
+         |while ($limitNotReachedCondition $iterTerm.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -832,6 +862,7 @@ case class HashAggregateExec(
       |if (!$initAgg) {
       |  $initAgg = true;
       |  $createFastHashMap
+      |  $addHookToCloseFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
@@ -866,13 +897,11 @@ case class HashAggregateExec(
       }
     }
 
-    val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
-      incCounter) = if (testFallbackStartsAt.isDefined) {
-      val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
-      (s"$countTerm < ${testFallbackStartsAt.get._1}",
-        s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
-    } else {
-      ("true", "true", "", "")
+    val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
+      case Some((_, regularMapCounter)) =>
+        val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
+        (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
+      case _ => ("true", "", "")
     }
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
@@ -912,12 +941,10 @@ case class HashAggregateExec(
       // If fast hash map is on, we first generate code to probe and update the fast hash map.
      // If the probe is successful the corresponding fast row buffer will hold the mutable row.
       s"""
-        |if ($checkFallbackForGeneratedHashMap) {
-        |  ${fastRowKeys.map(_.code).mkString("\n")}
-        |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
-        |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
-        |      ${fastRowKeys.map(_.value).mkString(", ")});
-        |  }
+        |${fastRowKeys.map(_.code).mkString("\n")}
+        |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+        |  $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+        |    ${fastRowKeys.map(_.value).mkString(", ")});
         |}
         |// Cannot find the key in fast hash map, try regular hash map.
        |if ($fastRowBuffer == null) {