@@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
Examples:
> SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height);
Alice 0 2 165.0
Bob 0 5 180.0
Contributor Author:
This change is needed because the hash aggregation output order changed, which caused the ExpressionInfoSuite test "check outputs of expression examples" to fail in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true .

Alice 1 2 165.0
NULL 3 7 172.5
Bob 0 5 180.0
Bob 1 5 180.0
NULL 2 2 165.0
NULL 2 5 180.0
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._
import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.memory.SparkOutOfMemoryError
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -435,8 +435,8 @@ case class HashAggregateExec(
)
}

def getTaskMemoryManager(): TaskMemoryManager = {
TaskContext.get().taskMemoryManager()
def getTaskContext(): TaskContext = {
TaskContext.get()
}

def getEmptyAggregationBuffer(): InternalRow = {
@@ -647,7 +647,7 @@ case class HashAggregateExec(
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
f.dataType.isInstanceOf[CalendarIntervalType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
bufferSchema.nonEmpty

// For vectorized hash map, We do not support byte array based decimal type for aggregate values
// as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -663,7 +663,7 @@ case class HashAggregateExec(

private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
if (!checkIfFastHashMapSupported(ctx)) {
if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
@cloud-fan (Contributor), Apr 22, 2021:
last question: can we search the commit history and figure out why we didn't enable the fast hash map in the final aggregate? It seems we did it on purpose.

Contributor Author replied:
@cloud-fan - I was wondering about this myself before making this PR. The decision to only support partial aggregate was made when the first-level hash map was introduced (#12345 and #14176) and was never changed afterwards. I checked with @sameeragarwal before making this PR; he told me there is no fundamental reason not to support final aggregate.

Just for documentation: I also asked him why we don't support nested types (array/map/struct) as key types for the fast hash map. He told me the reason was that such keys might be too large for long arrays/maps/structs, so the fast hash map might no longer fit in cache and would lose its benefit.

if (!Utils.isTesting) {
logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but"
+ " current version of codegened fast hashmap does not support this aggregate.")
}
@@ -683,7 +683,18 @@ case class HashAggregateExec(
} else if (sqlContext.conf.enableVectorizedHashMap) {
logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
}
val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit
val bitMaxCapacity = testFallbackStartsAt match {
case Some((fastMapCounter, _)) =>
// In testing, with fall back counter of fast hash map (`fastMapCounter`), set the max bit
// of map to be no more than log2(`fastMapCounter`). This helps control the number of keys
// in map to mimic fall back.
if (fastMapCounter <= 1) {
0
} else {
(math.log10(fastMapCounter) / math.log10(2)).floor.toInt
}
case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit
}
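For illustration, here is a standalone sketch (the helper name is ours, not part of the patch) of the counter-to-capacity mapping above: a fast-map fallback counter of N caps the map at 2^floor(log2(N)) entries, so the map fills up and falls back after roughly N inserted keys.

// Sketch only: mirrors the capacity-bit computation in the patch above.
def maxCapacityBit(fastMapCounter: Int): Int =
  if (fastMapCounter <= 1) 0
  else (math.log10(fastMapCounter) / math.log10(2)).floor.toInt

assert(maxCapacityBit(1) == 0) // map capacity 2^0 = 1
assert(maxCapacityBit(3) == 1) // map capacity 2^1 = 2
assert(maxCapacityBit(5) == 2) // map capacity 2^2 = 4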

val thisPlan = ctx.addReferenceObj("plan", this)

@@ -717,11 +728,28 @@
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter", forceInline = true)
val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());"
s"$thisPlan.getTaskContext().taskMemoryManager(), " +
s"$thisPlan.getEmptyAggregationBuffer());"
(iter, create)
}
} else ("", "")

// Generates the code to register a cleanup task with TaskContext to ensure that memory
// is guaranteed to be freed at the end of the task. This is necessary to avoid memory
// leaks when the downstream operator does not fully consume the aggregation map's
// output (e.g. aggregate followed by limit).
val addHookToCloseFastHashMap = if (isFastHashMapEnabled) {
s"""
|$thisPlan.getTaskContext().addTaskCompletionListener(
| new org.apache.spark.util.TaskCompletionListener() {
| @Override
| public void onTaskCompletion(org.apache.spark.TaskContext context) {
| $fastHashMapTerm.close();
| }
|});
""".stripMargin
} else ""
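For reference, a minimal Scala sketch of the cleanup pattern that the generated Java above relies on; the registerCleanup helper and its function argument are illustrative and not part of the patch.

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Sketch only: free the fast hash map's memory at task completion, even if its
// output iterator is abandoned early (e.g. an aggregate followed by a limit).
def registerCleanup(closeMap: () => Unit): Unit = {
  TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = closeMap()
  })
}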

// Create a name for the iterator from the regular hash map.
// Inline mutable state since not many aggregation operations in a task
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
@@ -761,6 +789,8 @@ case class HashAggregateExec(
val bufferTerm = ctx.freshName("aggBuffer")
val outputFunc = generateResultFunction(ctx)

val limitNotReachedCondition = limitNotReachedCond
Contributor Author:
Adding the limit early termination for the first-level map as well. This is needed to fix the SQLMetricsSuite test failure "SPARK-25497: LIMIT within whole stage codegen should not consume all the inputs" in https://github.com/c21/spark/runs/2386397792?check_suite_focus=true . And this is good to have anyway.
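As an illustration of the scenario this addresses (the query shape is assumed, not copied from the test): an aggregate followed by a limit should stop producing rows from the hash map once the limit is reached instead of draining the whole map.

// Assumed example; `spark` is a SparkSession.
val limited = spark.range(0, 1000)
  .selectExpr("id % 100 AS k", "id AS v")
  .groupBy("k")
  .sum("v")
  .limit(5)
limited.collect() // the aggregation output loop can stop early once the per-partition limit is reached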


def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
@@ -773,7 +803,7 @@

def outputFromRowBasedMap: String = {
s"""
|while ($iterTermForFastHashMap.next()) {
|while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -798,7 +828,7 @@
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
|while ($iterTermForFastHashMap.hasNext()) {
|while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
@@ -813,7 +843,7 @@

def outputFromRegularHashMap: String = {
s"""
|while ($limitNotReachedCond $iterTerm.next()) {
|while ($limitNotReachedCondition $iterTerm.next()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -832,6 +862,7 @@
|if (!$initAgg) {
| $initAgg = true;
| $createFastHashMap
| $addHookToCloseFastHashMap
| $hashMapTerm = $thisPlan.createHashMap();
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
@@ -866,13 +897,11 @@ case class HashAggregateExec(
}
}

val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "true", "", "")
val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match {
case Some((_, regularMapCounter)) =>
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;")
case _ => ("true", "", "")
}

val oomeClassName = classOf[SparkOutOfMemoryError].getName
@@ -912,12 +941,10 @@ case class HashAggregateExec(
// If fast hash map is on, we first generate code to probe and update the fast hash map.
// If the probe is successful the corresponding fast row buffer will hold the mutable row.
s"""
|if ($checkFallbackForGeneratedHashMap) {
| ${fastRowKeys.map(_.code).mkString("\n")}
| if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $fastRowBuffer = $fastHashMapTerm.findOrInsert(
| ${fastRowKeys.map(_.value).mkString(", ")});
| }
|${fastRowKeys.map(_.code).mkString("\n")}
|if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $fastRowBuffer = $fastHashMapTerm.findOrInsert(
| ${fastRowKeys.map(_.value).mkString(", ")});
|}
|// Cannot find the key in fast hash map, try regular hash map.
|if ($fastRowBuffer == null) {