From cb2acca73cbc11c5664058028d0f02f596f71135 Mon Sep 17 00:00:00 2001
From: Shipra Agrawal
Date: Tue, 8 Jun 2021 18:35:41 -0700
Subject: [PATCH 1/3] Skip partial aggregation when it barely reduces the data
 (from https://github.com/apache/spark/pull/28804/commits)

Co-authored-by: Karuppayya Rajendran
---
 .../apache/spark/sql/internal/SQLConf.scala   |  42 +++
 .../UnsafeFixedWidthAggregationMap.java       |  13 +
 .../aggregate/HashAggregateExec.scala         | 248 ++++++++++++++----
 .../aggregate/HashMapGenerator.scala          |  10 +
 .../execution/WholeStageCodegenSuite.scala    |  71 +++--
 .../execution/metric/SQLMetricsSuite.scala    |  92 +++----
 .../execution/AggregationQuerySuite.scala     |  59 +++--
 7 files changed, 403 insertions(+), 132 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cfd9704ba4525..e70be169909bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2615,6 +2615,42 @@ object SQLConf {
     .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
     .createWithDefault(16)
 
+
+  val SKIP_PARTIAL_AGGREGATE_MINROWS =
+    buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows")
+      .internal()
+      .doc("Number of records after which aggregate operator checks if " +
+        "partial aggregation phase can be avoided")
+      .version("3.1.0")
+      .longConf
+      .createWithDefault(100000)
+
+  val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO =
+    buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio")
+      .internal()
+      .doc("Ratio beyond which the partial aggregation is skipped." +
+        "This is computed by taking the ratio of number of records present" +
+        " in map of Aggregate operator to the total number of records processed" +
+        " by the Aggregate operator.")
+      .version("3.1.0")
+      .doubleConf
+      .checkValue(ratio => ratio > 0 && ratio < 1, "Invalid value for " +
+        "spark.sql.aggregate.skipPartialAggregate.aggregateRatio. Valid value needs" +
+        " to be between 0 and 1")
+      .createWithDefault(0.5)
+
+  val SKIP_PARTIAL_AGGREGATE_ENABLED =
+    buildConf("spark.sql.aggregate.skipPartialAggregate")
+      .internal()
+      .doc("When enabled, the partial aggregation is skipped when the following" +
+        "two conditions are met. 1. When the total number of records processed is greater" +
+        s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. When the ratio" +
+        "of record count in map to the total records is greater than the value defined by " +
+        s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
     .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
       "uncompressed, deflate, snappy, bzip2, xz and zstandard. 
Default codec is snappy.")
@@ -3605,6 +3641,12 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
+  def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)
+
+  def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_MINROWS)
+
+  def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO)
+
   def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)
 
   def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 117e98f33a0ec..3046423ad7658 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -63,6 +63,13 @@ public final class UnsafeFixedWidthAggregationMap {
    */
   private final UnsafeRow currentAggregationBuffer;
 
+  /**
+   * Number of rows that were added to the map.
+   * This includes the rows that were passed on to the sorter
+   * using {@link #destructAndCreateExternalSorter()}.
+   */
+  private long numRowsAdded = 0L;
+
   /**
    * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
    * schema, false otherwise.
@@ -147,6 +154,8 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
       );
       if (!putSucceeded) {
         return null;
+      } else {
+        numRowsAdded = numRowsAdded + 1;
       }
     }
 
@@ -249,4 +258,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti
         package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()),
       map);
   }
+
+  public long getNumRows() {
+    return numRowsAdded;
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 1192f02955a58..8e38683a8d16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -61,14 +61,39 @@ case class HashAggregateExec(
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
-    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
-    "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
-    "avgHashProbe" ->
-      SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"),
-    "numTasksFallBacked" -> SQLMetrics.createMetric(sparkContext, "number of sort fallback tasks"))
+  override lazy val metrics = {
+    val metrics = Map(
+      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+      "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
+      "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
+      "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
+      "avgHashProbe" ->
+        SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"),
"numTasksFallBacked" -> SQLMetrics.createMetric(sparkContext, "number of sort fallback tasks")) + if (skipPartialAggregateEnabled) { + metrics ++ Map("partialAggSkipped" -> SQLMetrics.createMetric(sparkContext, + "number of skipped records for partial aggregates")) + } else { + metrics + } + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override protected def outputExpressions: Seq[NamedExpression] = resultExpressions + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash // map and/or the sort-based aggregation once it has processed a given number of input rows. @@ -412,6 +437,14 @@ case class HashAggregateExec( private var fastHashMapTerm: String = _ private var isFastHashMapEnabled: Boolean = false + private var avoidSpillInPartialAggregateTerm: String = _ + private val skipPartialAggregateEnabled = { + conf.skipPartialAggregate && + modes.nonEmpty && modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty + } + private var rowCountTerm: String = _ + private var outputFunc: String = _ + // whether a vectorized hashmap is used instead // we have decided to always use the row-based hashmap, // but the vectorized hashmap can still be switched on for testing and benchmarking purposes. @@ -685,6 +718,20 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") + if (conf.enableTwoLevelAggMap) { + + var childrenConsumed: String = null + if (skipPartialAggregateEnabled) { + avoidSpillInPartialAggregateTerm = ctx. + addMutableState(CodeGenerator.JAVA_BOOLEAN, + "avoidPartialAggregate", + term => s"$term = ${Utils.isTesting};") + rowCountTerm = ctx. + addMutableState(CodeGenerator.JAVA_LONG, "rowCount") + childrenConsumed = ctx. 
+ addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") + } + if (conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else if (conf.enableVectorizedHashMap) { @@ -784,10 +831,15 @@ case class HashAggregateExec( finishRegularHashMap } + outputFunc = generateResultFunction(ctx) + val genChildrenConsumedCode = if (skipPartialAggregateEnabled) { + s"${childrenConsumed} = true;" + } else "" val doAggFuncName = ctx.addNewFunction(doAgg, s""" |private void $doAgg() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $genChildrenConsumedCode | $finishHashMap |} """.stripMargin) @@ -795,7 +847,6 @@ case class HashAggregateExec( // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - val outputFunc = generateResultFunction(ctx) val limitNotReachedCondition = limitNotReachedCond @@ -866,6 +917,15 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") + val genCodePostInitCode = + if (skipPartialAggregateEnabled) { + s""" + |if (!$childrenConsumed) { + | $doAggFuncName(); + | if (shouldStop()) return; + |} + """.stripMargin + } else "" s""" |if (!$initAgg) { | $initAgg = true; @@ -875,13 +935,17 @@ case class HashAggregateExec( | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + | if (shouldStop()) return; |} + |$genCodePostInitCode |// output the result |$outputFromFastHashMap |$outputFromRegularHashMap """.stripMargin } + override def needStopCheck: Boolean = skipPartialAggregateEnabled + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( @@ -912,48 +976,107 @@ case class HashAggregateExec( case _ => ("true", "", "") } + val skipPartialAggregateThreshold = conf.skipPartialAggregateThreshold + val skipPartialAggRatio = conf.skipPartialAggregateRatio + + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val oomeClassName = classOf[SparkOutOfMemoryError].getName + val findOrInsertRegularHashMap: String = { + def getAggBufferFromMap = { + s""" + |// generate grouping key + |${unsafeRowKeyCode.code} + |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); + |if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); + |} + """.stripMargin + } - val findOrInsertRegularHashMap: String = - s""" - |// generate grouping key - |${unsafeRowKeyCode.code} - |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); - |if ($checkFallbackForBytesToBytesMap) { - | // try to get the buffer from hash map - | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); - |} - |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - |// aggregation after processing all input rows. - |if ($unsafeRowBuffer == null) { - | if ($sorterTerm == null) { - | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - | } else { - | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - | } - | $resetCounter - | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. 
- | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, $unsafeRowKeyHash); - | if ($unsafeRowBuffer == null) { - | // failed to allocate the first page - | throw new $oomeClassName("No enough memory for aggregation"); - | } - |} + def addToSorter: String = { + s""" + |if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + |} else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + |} + |$resetCounter + |// the hash map had be spilled, it should have enough memory now, + |// try to allocate buffer again. + |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( + | $unsafeRowKeys, $unsafeRowKeyHash); + |if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new $oomeClassName("No enough memory for aggregation"); + |}""".stripMargin + } + + if (skipPartialAggregateEnabled) { + val checkIfPartialAggSkipped = + s""" + |!($rowCountTerm < $skipPartialAggregateThreshold) && + | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio; + |""".stripMargin + s""" + |if (!$avoidSpillInPartialAggregateTerm) { + | $getAggBufferFromMap + | // Can't allocate buffer from the hash map. + | // Check if we can avoid partial aggregation. + | // Otherwise, Spill the map and fallback to sort-based + | // aggregation after processing all input rows. + | if ($unsafeRowBuffer == null) { + | $countTerm = $countTerm + $hashMapTerm.getNumRows(); + | boolean skipPartAgg = $checkIfPartialAggSkipped + | if (skipPartAgg) { + | // Aggregation buffer is created later + | $avoidSpillInPartialAggregateTerm = true; + | } else { + | $addToSorter + | } + | } + |} + """.stripMargin + } else { + s""" + |$getAggBufferFromMap + |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + |// aggregation after processing all input rows. + |if ($unsafeRowBuffer == null) { + | $addToSorter + |} """.stripMargin + } + } val findOrInsertHashMap: String = { - if (isFastHashMapEnabled) { + val insertCode = if (isFastHashMapEnabled) { + def findOrInsertIntoFastHashMap = { + s""" + |${fastRowKeys.map(_.code).mkString("\n")} + |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); + |} + |""".stripMargin + } + val insertFastMap = if (skipPartialAggregateEnabled) { + s""" + |if (!$avoidSpillInPartialAggregateTerm) { + | $findOrInsertIntoFastHashMap + |} + |$countTerm = $fastHashMapTerm.getNumRows(); + |""".stripMargin + } else { + s""" + |$findOrInsertIntoFastHashMap + |""".stripMargin + } // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" - |${fastRowKeys.map(_.code).mkString("\n")} - |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - |} + |$insertFastMap |// Cannot find the key in fast hash map, try regular hash map. 
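          |// If skipping is already in effect, the regular-map lookup below is
          |// bypassed as well; $unsafeRowBuffer then stays null and an empty
          |// aggregation buffer is created for the current row further below.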
|if ($fastRowBuffer == null) { | $findOrInsertRegularHashMap @@ -962,6 +1085,27 @@ case class HashAggregateExec( } else { findOrInsertRegularHashMap } + def createEmptyAggBufferAndUpdateMetrics: String = { + if (skipPartialAggregateEnabled) { + val numAggSkippedRows = metricTerm(ctx, "partialAggSkipped") + val initExpr = declFunctions.flatMap(f => f.initialValues) + val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) + s""" + |// Create an empty aggregation buffer + |if ($avoidSpillInPartialAggregateTerm) { + | ${unsafeRowKeyCode.code} + | ${emptyBufferKeyCode.code} + | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; + | $numAggSkippedRows.add(1); + |} + |""".stripMargin + } else "" + } + + s""" + |$insertCode + |$createEmptyAggBufferAndUpdateMetrics + |""".stripMargin } val inputAttrs = aggregateBufferAttributes ++ inputAttributes @@ -1028,7 +1172,7 @@ case class HashAggregateExec( } val updateRowInHashMap: String = { - if (isFastHashMapEnabled) { + val updateRowInMap = if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => @@ -1093,6 +1237,22 @@ case class HashAggregateExec( } else { updateRowInRegularHashMap } + + def outputRow: String = { + if (skipPartialAggregateEnabled) { + s""" + |if ($avoidSpillInPartialAggregateTerm) { + | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); + |} + |$rowCountTerm = $rowCountTerm + 1; + |""".stripMargin + } else "" + } + + s""" + |$updateRowInMap + |$outputRow + |""".stripMargin } val declareRowBuffer: String = if (isFastHashMapEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index b3f5e341f66f5..45c4ad74ff04b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -75,6 +75,8 @@ abstract class HashMapGenerator( | |${generateRowIterator()} | + |${generateNumRows()} + | |${generateClose()} |} """.stripMargin @@ -136,6 +138,14 @@ abstract class HashMapGenerator( """.stripMargin } + protected final def generateNumRows(): String = { + s""" + |public int getNumRows() { + | return batch.numRows(); + |} + """.stripMargin + } + protected final def genComputeHash( ctx: CodegenContext, input: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6cc6e33dd688a..57c76470f322a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -278,7 +279,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val twoJoinsDF = df1.join(df2, $"k1" < $"k2").crossJoin(df3) hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) checkAnswer(twoJoinsDF, @@ -317,7 +318,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(df3, $"k1" <= $"k3", "left_outer") hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) checkAnswer(twoJoinsDF, @@ -386,7 +387,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } @@ -404,7 +405,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) assert(ds.collect() === Array(0, 6)) } @@ -417,7 +418,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val planInt = dsIntFilter.queryExecution.executedPlan assert(planInt.collect { case WholeStageCodegenExec(FilterExec(_, - ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => () + ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => () }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) @@ -555,7 +556,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", - SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") val df = spark.read.parquet(path).selectExpr(projection: _*) @@ -641,18 +642,18 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(baseTable, "idx") assert(distinctWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate // expression val groupByWithId = - baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) + baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) .join(baseTable, "idx") 
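    // The Final mode HashAggregateExec child should still be fused into the
    // whole-stage-codegen'd join stage: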
assert(groupByWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) } @@ -681,7 +682,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.find { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true case _ => false }.isDefined, "LocalTableScanExec should be within a WholeStageCodegen domain.") @@ -689,9 +690,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Give up splitting aggregate code if a parameter length goes over the limit") { withSQLConf( - SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", - SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.CodeGenerator.validParamLength" -> "0") { + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", + "spark.sql.CodeGenerator.validParamLength" -> "0") { withTable("t") { val expectedErrMsg = "Failed to split aggregate code into small functions" Seq( @@ -710,9 +711,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Give up splitting subexpression code if a parameter length goes over the limit") { withSQLConf( - SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", - SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.CodeGenerator.validParamLength" -> "0") { + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", + "spark.sql.CodeGenerator.validParamLength" -> "0") { withTable("t") { val expectedErrMsg = "Failed to split subexpression code into small functions" Seq( @@ -730,4 +731,40 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } } + + test("Avoid spill in partial aggregation" ) { + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true"), + (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key -> "2")) { + // Create Dataframes + val data = Seq(("James", 1), ("James", 1), ("Phil", 1)) + val aggDF = data.toDF("name", "values").groupBy("name").sum("values") + val partAggNode = aggDF.queryExecution.executedPlan.find { + case h: HashAggregateExec => + val modes = h.aggregateExpressions.map(_.mode) + modes.nonEmpty && modes.forall(_ == Partial) + case _ => false + } + + checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1))) + assert(partAggNode.isDefined, + "No HashAggregate node with partial aggregate expression found") + assert(partAggNode.get.metrics("partialAggSkipped").value == data.size, + "Partial aggregation got triggered in partial hash aggregate node") + } + } + + test(s"Distinct: Partial aggregation should happen for " + + "HashAggregate nodes performing partial Aggregate operations " ) { + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true")) { + val aggDF = testData2.select(sumDistinct($"a"), sum($"b")) + val aggNodes = aggDF.queryExecution.executedPlan.collect { + case h: HashAggregateExec => h + } + val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec]) + checkAnswer(aggDF, Row(6, 9)) + assert(baseNodes.size == 1 ) + assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count()) + 
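      // Only partial-mode aggregate nodes should expose the partialAggSkipped metric: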
assert(other.forall(!_.metrics.contains("partialAggSkipped")))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 922e7b89dc01c..299cee95244d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -148,48 +148,52 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
   }
 
   test("Aggregate metrics: track avg probe") {
-    // The executed plan looks like:
-    // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L])
-    // +- Exchange hashpartitioning(a#61, 5)
-    // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L])
-    // +- Exchange RoundRobinPartitioning(1)
-    // +- LocalTableScan [a#61]
-    //
-    // Assume the execution plan with node id is:
-    // Wholestage disabled:
-    // HashAggregate(nodeId = 0)
-    // Exchange(nodeId = 1)
-    // HashAggregate(nodeId = 2)
-    // Exchange (nodeId = 3)
-    // LocalTableScan(nodeId = 4)
-    //
-    // Wholestage enabled:
-    // WholeStageCodegen(nodeId = 0)
-    // HashAggregate(nodeId = 1)
-    // Exchange(nodeId = 2)
-    // WholeStageCodegen(nodeId = 3)
-    // HashAggregate(nodeId = 4)
-    // Exchange(nodeId = 5)
-    // LocalTableScan(nodeId = 6)
-    Seq(true, false).foreach { enableWholeStage =>
-      val df = generateRandomBytesDF().repartition(2).groupBy('a).count()
-      val nodeIds = if (enableWholeStage) {
-        Set(4L, 1L)
-      } else {
-        Set(2L, 0L)
-      }
-      val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
-      nodeIds.foreach { nodeId =>
-        val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString
-        if (!probes.contains("\n")) {
-          // It's a single metrics value
-          assert(probes.toDouble > 1.0)
+    if (spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED)) {
+      logInfo("Skipping the check, since partial aggregation may be skipped")
+    } else {
+      // The executed plan looks like:
+      // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L])
+      // +- Exchange hashpartitioning(a#61, 5)
+      // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L])
+      // +- Exchange RoundRobinPartitioning(1)
+      // +- LocalTableScan [a#61]
+      //
+      // Assume the execution plan with node id is:
+      // Wholestage disabled:
+      // HashAggregate(nodeId = 0)
+      // Exchange(nodeId = 1)
+      // HashAggregate(nodeId = 2)
+      // Exchange (nodeId = 3)
+      // LocalTableScan(nodeId = 4)
+      //
+      // Wholestage enabled:
+      // WholeStageCodegen(nodeId = 0)
+      // HashAggregate(nodeId = 1)
+      // Exchange(nodeId = 2)
+      // WholeStageCodegen(nodeId = 3)
+      // HashAggregate(nodeId = 4)
+      // Exchange(nodeId = 5)
+      // LocalTableScan(nodeId = 6)
+      Seq(true, false).foreach { enableWholeStage =>
+        val df = generateRandomBytesDF().repartition(2).groupBy('a).count()
+        val nodeIds = if (enableWholeStage) {
+          Set(4L, 1L)
         } else {
-          val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
-          // Extract min, med, max from the string and strip off everything else.
- val index = mainValue.indexOf(" (", 0) - mainValue.slice(0, index).split(", ").foreach { - probe => assert(probe.toDouble > 1.0) + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + if (!probes.contains("\n")) { + // It's a single metrics value + assert(probes.toDouble > 1.0) + } else { + val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") + // Extract min, med, max from the string and strip off everything else. + val index = mainValue.indexOf(" (", 0) + mainValue.slice(0, index).split(", ").foreach { + probe => assert(probe.toDouble > 1.0) + } } } } @@ -653,9 +657,9 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { def checkFilterAndRangeMetrics( - df: DataFrame, - filterNumOutputs: Int, - rangeNumOutputs: Int): Unit = { + df: DataFrame, + filterNumOutputs: Int, + rangeNumOutputs: Int): Unit = { val plan = df.queryExecution.executedPlan val filters = collectNodeWithinWholeStage[FilterExec](plan) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 069bc7372b038..710881198b1b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1060,38 +1060,43 @@ class HashAggregationQuerySuite extends AggregationQuerySuite class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> - enableTwoLevelMaps) { - Seq(4, 8).foreach { uaoSize => - UnsafeAlignedOffset.setUaoSize(uaoSize) - (1 to 3).foreach { fallbackStartsAt => - withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> - s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = Dataset.ofRows(spark, actual.logicalPlan) - - QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using HashAggregate with - |controlled fallback (it falls back to bytes to bytes map once it has - |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation - |once it has processed $fallbackStartsAt input rows). - |The query is ${actual.queryExecution} - |$errorMessage + // The HashAggregationQueryWithControlledFallbackSuite is dependent on ordering and also + // assumes partial aggregation to have happened. 
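+  // (skipping would let a partial stage emit duplicate, un-merged keys), so we are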
+ // disabling the flag that skips partial aggregation + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "false")) { + Seq("true", "false").foreach { enableTwoLevelMaps => + withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> + enableTwoLevelMaps) { + Seq(4, 8).foreach { uaoSize => + UnsafeAlignedOffset.setUaoSize(uaoSize) + (1 to 3).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> + s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = Dataset.ofRows(spark, actual.logicalPlan) + + QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using HashAggregate with + |controlled fallback (it falls back to bytes to bytes map once it has + |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation + |once it has processed $fallbackStartsAt input rows). + |The query is ${actual.queryExecution} + |$errorMessage """.stripMargin - fail(newErrorMessage) - case None => // Success + fail(newErrorMessage) + case None => // Success + } } } + // reset static uaoSize to avoid affect other tests + UnsafeAlignedOffset.setUaoSize(0) } - // reset static uaoSize to avoid affect other tests - UnsafeAlignedOffset.setUaoSize(0) } } } From 362c5eab243553753f3f290f16efa88a3a72a7c7 Mon Sep 17 00:00:00 2001 From: Shipra Agrawal Date: Tue, 22 Jun 2021 14:31:42 -0700 Subject: [PATCH 2/3] changes on top of 28804 --- .../apache/spark/sql/internal/SQLConf.scala | 44 +++---- .../aggregate/HashAggregateExec.scala | 122 ++++++++++-------- .../execution/WholeStageCodegenSuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 2 + 4 files changed, 94 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e70be169909bf..c43760c44290b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2615,39 +2615,39 @@ object SQLConf { .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].") .createWithDefault(16) - - val SKIP_PARTIAL_AGGREGATE_MINROWS = - buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows") + val SKIP_PARTIAL_AGGREGATE_MIN_ROWS = + buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows") .internal() - .doc("Number of records after which aggregate operator checks if " + - "partial aggregation phase can be avoided") - .version("3.1.0") + .doc("The minimal number of input rows processed before hash aggregate checks if it can be" + + " skipped. Only applies to partial hash aggregate.") + .version("3.1.2") .longConf + .checkValue(minNumRows => minNumRows > 0, "Invalid value for " + + "spark.sql.aggregate.skipPartialAggregate.minNumRows. Valid value needs" + + " to be greater than 0" ) .createWithDefault(100000) - val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO = - buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio") + val SKIP_PARTIAL_AGGREGATE_MIN_RATIO = + buildConf("spark.sql.aggregate.skipPartialAggregate.minRatio") .internal() - .doc("Ratio beyond which the partial aggregation is skipped." 
+ - "This is computed by taking the ratio of number of records present" + - " in map of Aggregate operator to the total number of records processed" + - " by the Aggregate operator.") - .version("3.1.0") + .doc("The minimal ratio between input and output rows for partial hash aggregate allows it" + + " to be skipped") + .version("3.1.2") .doubleConf .checkValue(ratio => ratio > 0 && ratio < 1, "Invalid value for " + - "spark.sql.aggregate.skipPartialAggregate.aggregateRatio. Valid value needs" + + "spark.sql.aggregate.skipPartialAggregate.minRatio. Valid value needs" + " to be between 0 and 1" ) .createWithDefault(0.5) val SKIP_PARTIAL_AGGREGATE_ENABLED = - buildConf("spark.sql.aggregate.skipPartialAggregate") + buildConf("spark.sql.aggregate.skipPartialAggregate.enabled") .internal() - .doc("When enabled, the partial aggregation is skipped when the following" + - "two conditions are met. 1. When the total number of records processed is greater" + - s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. When the ratio" + + .doc("When enabled, the partial aggregation is skipped when the following " + + "two conditions are met. 1. When the total number of records processed is greater " + + s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MIN_ROWS.key} 2. When the ratio " + "of record count in map to the total records is less that value defined by " + - s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}") - .version("3.1.0") + s"${SKIP_PARTIAL_AGGREGATE_MIN_RATIO.key}") + .version("3.1.2") .booleanConf .createWithDefault(true) @@ -3643,9 +3643,9 @@ class SQLConf extends Serializable with Logging { def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED) - def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_MINROWS) + def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_MIN_ROWS) - def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO) + def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_MIN_RATIO) def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8e38683a8d16c..121c4d780a636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -437,7 +437,7 @@ case class HashAggregateExec( private var fastHashMapTerm: String = _ private var isFastHashMapEnabled: Boolean = false - private var avoidSpillInPartialAggregateTerm: String = _ + private var skipPartialAggTerm: String = _ private val skipPartialAggregateEnabled = { conf.skipPartialAggregate && modes.nonEmpty && modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty @@ -499,7 +499,8 @@ case class HashAggregateExec( peakMemory: SQLMetric, spillSize: SQLMetric, avgHashProbe: SQLMetric, - numTasksFallBacked: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = { + numTasksFallBacked: SQLMetric, + skipPartialAggTerm: Boolean): KVIterator[UnsafeRow, UnsafeRow] = { // update peak execution memory val mapMemory = hashMap.getPeakMemoryUsedBytes @@ -521,6 +522,11 @@ case class HashAggregateExec( numTasksFallBacked += 1 sorter.merge(hashMap.destructAndCreateExternalSorter()) hashMap.free() + if (!skipPartialAggTerm) { + 
sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + } + val sortedIter = sorter.sortedIterator() // Create a KVIterator based on the sorted iterator. @@ -722,12 +728,12 @@ case class HashAggregateExec( var childrenConsumed: String = null if (skipPartialAggregateEnabled) { - avoidSpillInPartialAggregateTerm = ctx. + skipPartialAggTerm = ctx. addMutableState(CodeGenerator.JAVA_BOOLEAN, - "avoidPartialAggregate", + "skipPartialAgg", term => s"$term = ${Utils.isTesting};") rowCountTerm = ctx. - addMutableState(CodeGenerator.JAVA_LONG, "rowCount") + addMutableState(CodeGenerator.JAVA_LONG, "inputRowCount") childrenConsumed = ctx. addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") } @@ -820,8 +826,14 @@ case class HashAggregateExec( val avgHashProbe = metricTerm(ctx, "avgHashProbe") val numTasksFallBacked = metricTerm(ctx, "numTasksFallBacked") + val mapCleared = if (skipPartialAggregateEnabled) { + s"$skipPartialAggTerm" + } else { + s"false" + } + val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + - s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe, $numTasksFallBacked);" + s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe, $numTasksFallBacked, $mapCleared);" val finishHashMap = if (isFastHashMapEnabled) { s""" |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); @@ -956,6 +968,7 @@ case class HashAggregateExec( val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") + val skipPartialAgg = ctx.freshName("skipPartialAgg") // To individually generate code for each aggregate function, an element in `updateExprs` holds // all the expressions for the buffer of an aggregation function. @@ -969,7 +982,7 @@ case class HashAggregateExec( } } - val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match { + var (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match { case Some((_, regularMapCounter)) => val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;") @@ -979,7 +992,7 @@ case class HashAggregateExec( val skipPartialAggregateThreshold = conf.skipPartialAggregateThreshold val skipPartialAggRatio = conf.skipPartialAggregateRatio - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") + val outputCountTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "outputCount") val oomeClassName = classOf[SparkOutOfMemoryError].getName val findOrInsertRegularHashMap: String = { def getAggBufferFromMap = { @@ -1017,21 +1030,28 @@ case class HashAggregateExec( val checkIfPartialAggSkipped = s""" |!($rowCountTerm < $skipPartialAggregateThreshold) && - | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio; + | ((float)$outputCountTerm/$rowCountTerm) > $skipPartialAggRatio; |""".stripMargin s""" - |if (!$avoidSpillInPartialAggregateTerm) { + |if (!$skipPartialAggTerm) { | $getAggBufferFromMap | // Can't allocate buffer from the hash map. | // Check if we can avoid partial aggregation. | // Otherwise, Spill the map and fallback to sort-based | // aggregation after processing all input rows. 
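          | // The skip check passes only when at least skipPartialAggregate.minNumRows
          | // input rows have been seen and the map-size-to-input-rows ratio
          | // exceeds skipPartialAggregate.minRatio.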
| if ($unsafeRowBuffer == null) { - | $countTerm = $countTerm + $hashMapTerm.getNumRows(); - | boolean skipPartAgg = $checkIfPartialAggSkipped - | if (skipPartAgg) { + | $outputCountTerm = $outputCountTerm + $hashMapTerm.getNumRows(); + | boolean $skipPartialAgg = $checkIfPartialAggSkipped + | if ($skipPartialAgg) { | // Aggregation buffer is created later - | $avoidSpillInPartialAggregateTerm = true; + | $skipPartialAggTerm = true; + | // this is done to clear up off-heap memory sooner and prevent OOMs + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $hashMapTerm.free(); | } else { | $addToSorter | } @@ -1050,6 +1070,23 @@ case class HashAggregateExec( } } + val createEmptyAggBufferWhenPartialAggSkipped: String = { + if (skipPartialAggregateEnabled) { + val numAggSkippedRows = metricTerm(ctx, "partialAggSkipped") + val initExpr = declFunctions.flatMap(f => f.initialValues) + val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) + s""" + |// Create an empty aggregation buffer + |if ($skipPartialAggTerm) { + | ${unsafeRowKeyCode.code} + | ${emptyBufferKeyCode.code} + | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; + | $numAggSkippedRows.add(1); + |} + |""".stripMargin + } else "" + } + val findOrInsertHashMap: String = { val insertCode = if (isFastHashMapEnabled) { def findOrInsertIntoFastHashMap = { @@ -1061,18 +1098,14 @@ case class HashAggregateExec( |} |""".stripMargin } - val insertFastMap = if (skipPartialAggregateEnabled) { - s""" - |if (!$avoidSpillInPartialAggregateTerm) { - | $findOrInsertIntoFastHashMap - |} - |$countTerm = $fastHashMapTerm.getNumRows(); - |""".stripMargin - } else { + + val insertFastMap = s""" - |$findOrInsertIntoFastHashMap - |""".stripMargin - } + |$findOrInsertIntoFastHashMap + |if ($skipPartialAggregateEnabled) { + | $outputCountTerm = $fastHashMapTerm.getNumRows(); + |} + |""".stripMargin // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. 
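      // In insertFastMap above, $skipPartialAggregateEnabled interpolates to a
      // compile-time Boolean, so the generated Java guard is a constant
      // if (true) / if (false).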
s""" @@ -1085,26 +1118,10 @@ case class HashAggregateExec( } else { findOrInsertRegularHashMap } - def createEmptyAggBufferAndUpdateMetrics: String = { - if (skipPartialAggregateEnabled) { - val numAggSkippedRows = metricTerm(ctx, "partialAggSkipped") - val initExpr = declFunctions.flatMap(f => f.initialValues) - val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) - s""" - |// Create an empty aggregation buffer - |if ($avoidSpillInPartialAggregateTerm) { - | ${unsafeRowKeyCode.code} - | ${emptyBufferKeyCode.code} - | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; - | $numAggSkippedRows.add(1); - |} - |""".stripMargin - } else "" - } s""" |$insertCode - |$createEmptyAggBufferAndUpdateMetrics + |$createEmptyAggBufferWhenPartialAggSkipped |""".stripMargin } @@ -1172,6 +1189,16 @@ case class HashAggregateExec( } val updateRowInHashMap: String = { + val outputRow: String = { + if (skipPartialAggregateEnabled) { + s""" + |if ($skipPartialAggTerm) { + | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); + |} + |$rowCountTerm = $rowCountTerm + 1; + |""".stripMargin + } else "" + } val updateRowInMap = if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer @@ -1238,17 +1265,6 @@ case class HashAggregateExec( updateRowInRegularHashMap } - def outputRow: String = { - if (skipPartialAggregateEnabled) { - s""" - |if ($avoidSpillInPartialAggregateTerm) { - | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); - |} - |$rowCountTerm = $rowCountTerm + 1; - |""".stripMargin - } else "" - } - s""" |$updateRowInMap |$outputRow diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 57c76470f322a..86456bb54926d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -734,7 +734,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Avoid spill in partial aggregation" ) { withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true"), - (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key -> "2")) { + (SQLConf.SKIP_PARTIAL_AGGREGATE_MIN_ROWS.key -> "2")) { // Create Dataframes val data = Seq(("James", 1), ("James", 1), ("Phil", 1)) val aggDF = data.toDF("name", "values").groupBy("name").sum("values") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 710881198b1b5..92158f2f43dff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -25,6 +25,8 @@ import test.org.apache.spark.sql.MyDoubleSum import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton From a7e97ddf3c46641d9c8c298d09c4c4e4ac6024ae Mon Sep 17 00:00:00 2001 From: Shipra Agrawal Date: Wed, 30 Jun 2021 17:13:09 -0700 Subject: [PATCH 
3/3] remove unintended changes --- .../aggregate/HashAggregateExec.scala | 27 +++------------ .../execution/WholeStageCodegenSuite.scala | 34 +++++++++---------- .../execution/metric/SQLMetricsSuite.scala | 6 ++-- 3 files changed, 24 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 121c4d780a636..1c58b885e3f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -69,7 +69,8 @@ case class HashAggregateExec( "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"), "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"), - "numTasksFallBacked" -> SQLMetrics.createMetric(sparkContext, "number of sort fallback tasks")) + "numTasksFallBacked" -> SQLMetrics.createMetric(sparkContext, + "number of sort fallback tasks")) if (skipPartialAggregateEnabled) { metrics ++ Map("partialAggSkipped" -> SQLMetrics.createMetric(sparkContext, "number of skipped records for partial aggregates")) @@ -78,23 +79,6 @@ case class HashAggregateExec( } } - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override protected def outputExpressions: Seq[NamedExpression] = resultExpressions - - override def producedAttributes: AttributeSet = - AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.isEmpty => AllTuples :: Nil - case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash // map and/or the sort-based aggregation once it has processed a given number of input rows. private val testFallbackStartsAt: Option[(Int, Int)] = { @@ -520,8 +504,6 @@ case class HashAggregateExec( // merge the final hashMap into sorter numTasksFallBacked += 1 - sorter.merge(hashMap.destructAndCreateExternalSorter()) - hashMap.free() if (!skipPartialAggTerm) { sorter.merge(hashMap.destructAndCreateExternalSorter()) hashMap.free() @@ -724,8 +706,6 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") - if (conf.enableTwoLevelAggMap) { - var childrenConsumed: String = null if (skipPartialAggregateEnabled) { skipPartialAggTerm = ctx. 
@@ -833,7 +813,8 @@ case class HashAggregateExec( } val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" + - s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe, $numTasksFallBacked, $mapCleared);" + s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe, $numTasksFallBacked," + + s" $mapCleared);" val finishHashMap = if (isFastHashMapEnabled) { s""" |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 86456bb54926d..813d21fb02372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite -import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -279,7 +279,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val twoJoinsDF = df1.join(df2, $"k1" < $"k2").crossJoin(df3) hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) checkAnswer(twoJoinsDF, @@ -318,7 +318,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(df3, $"k1" <= $"k3", "left_outer") hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect { case WholeStageCodegenExec(BroadcastNestedLoopJoinExec( - _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true + _: BroadcastNestedLoopJoinExec, _, _, _, _)) => true }.size === 1 assert(hasJoinInCodegen == codegenEnabled) checkAnswer(twoJoinsDF, @@ -387,7 +387,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } @@ -405,7 +405,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) assert(ds.collect() === Array(0, 6)) } @@ -418,7 +418,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val planInt = dsIntFilter.queryExecution.executedPlan assert(planInt.collect { case 
WholeStageCodegenExec(FilterExec(_, - ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => () + ColumnarToRowExec(InputAdapter(_: InMemoryTableScanExec)))) => () }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) @@ -556,7 +556,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", - SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") val df = spark.read.parquet(path).selectExpr(projection: _*) @@ -642,18 +642,18 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .join(baseTable, "idx") assert(distinctWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate // expression val groupByWithId = - baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) + baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) .join(baseTable, "idx") assert(groupByWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) } @@ -682,7 +682,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.find { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true case _ => false }.isDefined, "LocalTableScanExec should be within a WholeStageCodegen domain.") @@ -690,9 +690,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Give up splitting aggregate code if a parameter length goes over the limit") { withSQLConf( - SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", - SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.CodeGenerator.validParamLength" -> "0") { + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", + "spark.sql.CodeGenerator.validParamLength" -> "0") { withTable("t") { val expectedErrMsg = "Failed to split aggregate code into small functions" Seq( @@ -711,9 +711,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Give up splitting subexpression code if a parameter length goes over the limit") { withSQLConf( - SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", - SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", - "spark.sql.CodeGenerator.validParamLength" -> "0") { + SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "false", + SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1", + "spark.sql.CodeGenerator.validParamLength" -> "0") { withTable("t") { val expectedErrMsg = "Failed to split subexpression code into small functions" Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 299cee95244d6..a1e2587c0b9ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -657,9 +657,9 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SPARK-25602: SparkPlan.getByteArrayRdd should not consume the input when not necessary") { def checkFilterAndRangeMetrics( - df: DataFrame, - filterNumOutputs: Int, - rangeNumOutputs: Int): Unit = { + df: DataFrame, + filterNumOutputs: Int, + rangeNumOutputs: Int): Unit = { val plan = df.queryExecution.executedPlan val filters = collectNodeWithinWholeStage[FilterExec](plan)
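
The heuristic that the generated code above implements can be summarized in
plain Scala. A minimal sketch with illustrative names (this is not the
generated code, and shouldSkipPartialAgg is a hypothetical helper):

  // inputRows: rows consumed so far by the partial aggregate (rowCountTerm).
  // mapRows: distinct keys inserted into the hash map so far, including keys
  // already handed to the external sorter (hashMap.getNumRows()).
  def shouldSkipPartialAgg(
      inputRows: Long,
      mapRows: Long,
      minRows: Long = 100000L,  // skipPartialAggregate.minNumRows default
      minRatio: Double = 0.5    // skipPartialAggregate.minRatio default
  ): Boolean = {
    // Skip only after a warm-up of minRows input rows, and only when
    // aggregation is barely reducing the data (most rows carry a distinct
    // grouping key).
    inputRows >= minRows && mapRows.toDouble / inputRows > minRatio
  }

Once the check passes, the operator stops populating the map, emits one
pre-initialized aggregation buffer per input row (counted by the new
partialAggSkipped metric), and leaves all merging to the final aggregate.
With the patch-2 key names, the feature could be tuned like this (assuming a
running SparkSession named spark):

  spark.conf.set("spark.sql.aggregate.skipPartialAggregate.enabled", "true")
  spark.conf.set("spark.sql.aggregate.skipPartialAggregate.minNumRows", "100000")
  spark.conf.set("spark.sql.aggregate.skipPartialAggregate.minRatio", "0.5")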