Skip to content

Commit a158125

Browse files
committed
simplify fast hash map condition check
1 parent bb46788 commit a158125

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -477,52 +477,43 @@ case class HashAggregateExec(
477477
}
478478

479479
/**
480-
* Using the vectorized hash map in HashAggregate is currently supported for all primitive
481-
* data types during partial aggregation. However, we currently only enable the hash map for a
482-
* subset of cases that've been verified to show performance improvements on our benchmarks
483-
* subject to an internal conf that sets an upper limit on the maximum length of the aggregate
484-
* key/value schema.
485-
*
480+
* A required check for any fast hash map implementation. Currently the fast hash
481+
* map is supported only for primitive data types during partial aggregation.
486482
* This list of supported use-cases should be expanded over time.
487483
*/
488-
private def enableVectorizedHashMap(ctx: CodegenContext): Boolean = {
489-
val schemaLength = (groupingKeySchema ++ bufferSchema).length
484+
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
490485
val isSupported =
491486
(groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
492487
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
493488
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
494489

495-
// We do not support byte array based decimal type for aggregate values as
496-
// ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
497-
// updates. Due to this, appending the byte array in the vectorized hash map can turn out to be
498-
// quite inefficient and can potentially OOM the executor.
490+
// Being conservative, we do not support byte array based decimal types for aggregate values
491+
// for now.
499492
val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
500493
.forall(!DecimalType.isByteArrayDecimalType(_))
501-
isSupported && isNotByteArrayDecimalType &&
502-
schemaLength <= sqlContext.conf.vectorizedAggregateMapMaxColumns
494+
495+
isSupported && isNotByteArrayDecimalType
503496
}
504497

505498
/**
506-
* Using the row-based hash map in HashAggregate is currently supported for all primitive
507-
* data types during partial aggregation. However, we currently only enable the hash map for a
499+
* We currently only enable the vectorized hash map for a
508500
* subset of cases that've been verified to show performance improvements on our benchmarks
509501
* subject to an internal conf that sets an upper limit on the maximum length of the aggregate
510502
* key/value schema.
511503
*
512-
* This list of supported use-cases should be expanded over time.
513504
*/
514-
private def enableRowBasedHashMap(ctx: CodegenContext): Boolean = {
515-
val isSupported =
516-
(groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
517-
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
518-
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
519-
520-
// Acting conservative and do not support byte array based decimal type for aggregate values
521-
// for now.
522-
val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
523-
.forall(!DecimalType.isByteArrayDecimalType(_))
505+
private def enableVectorizedHashMap(ctx: CodegenContext): Boolean = {
506+
val schemaLength = (groupingKeySchema ++ bufferSchema).length
507+
checkIfFastHashMapSupported(ctx) &&
508+
schemaLength <= sqlContext.conf.vectorizedAggregateMapMaxColumns
509+
}
524510

525-
isSupported && isNotByteArrayDecimalType
511+
/**
512+
* We currently only enable the row based hash map when the requirement
513+
* check for a fast hash map implementation (checkIfFastHashMapSupported) passes.
514+
*/
515+
private def enableRowBasedHashMap(ctx: CodegenContext): Boolean = {
516+
checkIfFastHashMapSupported(ctx)
526517
}
527518

528519
private def setFastHashMapImpl(ctx: CodegenContext) = {

0 commit comments

Comments
 (0)