
Commit ececd57

fix tests
1 parent fc6b8cb commit ececd57

File tree (4 files changed, +51 -36 lines):

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala


sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala

Lines changed: 2 additions & 1 deletion
@@ -191,10 +191,11 @@ class ColumnarAggMapCodeGenerator(
        | ${groupingKeys.zipWithIndex.map(k =>
            s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
        | ${bufferValues.zipWithIndex.map(k =>
-           s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);")
+           s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);")
            .mkString("\n")}
        | buckets[idx] = numRows++;
        | batch.setNumRows(numRows);
+       | aggregateBufferBatch.setNumRows(numRows);
        | return aggregateBufferBatch.getRow(buckets[idx]);
        | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
        | return aggregateBufferBatch.getRow(buckets[idx]);
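
Two things change here: newly inserted aggregation-buffer columns are initialized with putNull rather than putLong(..., 0), and aggregateBufferBatch, which exposes only the buffer columns of batch, now has its row count bumped in lockstep, so the getRow call on the freshly inserted bucket stays in bounds. For orientation, a minimal self-contained Scala sketch of the find-or-insert shape this template emits; the real generator emits Java against ColumnarBatch, and everything here besides buckets/numRows and the put semantics is illustrative:

    object ColumnarMapSketch {
      val capacity = 1 << 16
      val keys = new Array[Long](capacity)              // key column of `batch`
      val buffers = new Array[java.lang.Long](capacity) // buffer column; null models putNull
      val buckets = Array.fill(capacity)(-1)            // hash bucket -> row id; -1 = empty
      var numRows = 0                                   // shared by batch and aggregateBufferBatch

      def findOrInsert(key: Long): Int = {
        val idx = java.lang.Long.hashCode(key) & (capacity - 1)
        if (buckets(idx) == -1) {
          keys(numRows) = key       // batch.column(...).putLong(numRows, key)
          buffers(numRows) = null   // putNull(numRows): buffer starts as null, not 0
          buckets(idx) = numRows
          numRows += 1              // batch.setNumRows AND aggregateBufferBatch.setNumRows
          buckets(idx)              // caller reads aggregateBufferBatch.getRow(buckets(idx))
        } else if (keys(buckets(idx)) == key) {
          buckets(idx)              // key already present: reuse its buffer row
        } else {
          -1                        // bucket collision: caller falls back to the regular map
        }
      }
    }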

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 19 additions & 10 deletions
@@ -70,12 +70,14 @@ case class TungstenAggregate(
     }
   }
 
-  // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
-  // aggregation once it has processed a given number of input rows.
-  private val testFallbackStartsAt: Option[Int] = {
+  // This is for testing. We force TungstenAggregationIterator to fall back to the bytes to bytes
+  // map and the sort-based aggregation once it has processed a given number of input rows.
+  private val testFallbackStartsAt: Option[(Int, Int)] = {
     sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
       case null | "" => None
-      case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+      case fallbackStartsAt =>
+        val splits = fallbackStartsAt.split(",").map(_.trim)
+        Some((splits.head.toInt, splits.last.toInt))
     }
   }
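
The spark.sql.TungstenAggregate.testFallbackStartsAt conf now carries two comma-separated thresholds: the row count at which the generated hash map is abandoned, and the row count at which the bytes-to-bytes map is abandoned. A stand-alone restatement of the parsing above (the helper name is illustrative); note that a single-value string still parses, since splits.head and splits.last are then the same element:

    def parseFallbackStartsAt(conf: String): Option[(Int, Int)] = conf match {
      case null | "" => None
      case s =>
        val splits = s.split(",").map(_.trim)
        Some((splits.head.toInt, splits.last.toInt))
    }

    // parseFallbackStartsAt("2, 3") == Some((2, 3))
    // parseFallbackStartsAt("5")    == Some((5, 5))   (head and last coincide)
    // parseFallbackStartsAt("")     == None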

@@ -593,20 +595,27 @@ case class TungstenAggregate(
       ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
     }
 
-    val (checkFallback, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) {
+    val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
+      incCounter) = if (testFallbackStartsAt.isDefined) {
       val countTerm = ctx.freshName("fallbackCounter")
       ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
-      (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
+      (s"$countTerm < ${testFallbackStartsAt.get._1}",
+        s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
     } else {
-      ("true", "", "")
+      ("true", "true", "", "")
    }
 
     val findOrInsertInGeneratedHashMap: Option[String] = {
       if (isAggregateHashMapEnabled) {
         Option(
           s"""
-             | $aggregateRow =
-             |   $aggregateHashMapTerm.findOrInsert(${groupByKeys.map(_.value).mkString(", ")});
+             |if ($checkFallbackForGeneratedHashMap) {
+             | ${groupByKeys.map(_.code).mkString("\n")}
+             | if (${groupByKeys.map("!" + _.isNull).mkString(" && ")}) {
+             |   $aggregateRow =
+             |     $aggregateHashMapTerm.findOrInsert(${groupByKeys.map(_.value).mkString(", ")});
+             | }
+             |}
           """.stripMargin)
       } else {
         None
@@ -619,7 +628,7 @@
       | // generate grouping key
       | ${keyCode.code.trim}
       | ${hashEval.code.trim}
-      | if ($checkFallback) {
+      | if ($checkFallbackForBytesToBytesMap) {
       | // try to get the buffer from hash map
       | $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
       | }
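
Taken with the hunk above it, this wires one test counter through both lookup paths: checkFallbackForGeneratedHashMap guards the generated columnar map (which is now also skipped whenever any grouping key is null), while checkFallbackForBytesToBytesMap, the renamed $checkFallback, guards the regular hash map. Roughly, the four fragments come out as below; a sketch with a hypothetical counter name, not exact codegen output:

    // testFallbackStartsAt = Some((2, 3)); freshName would suffix the counter.
    val testMode = (
      "fallbackCounter < 2",   // generated hash map used only for the first 2 rows
      "fallbackCounter < 3",   // BytesToBytesMap used only for the first 3 rows
      "fallbackCounter = 0;",  // resetCounter
      "fallbackCounter += 1;") // incCounter

    // testFallbackStartsAt = None (production): both maps always eligible.
    val production = ("true", "true", "", "")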

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala

Lines changed: 4 additions & 4 deletions
@@ -85,7 +85,7 @@ class TungstenAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
     inputIter: Iterator[InternalRow],
-    testFallbackStartsAt: Option[Int],
+    testFallbackStartsAt: Option[(Int, Int)],
     numOutputRows: LongSQLMetric,
     dataSize: LongSQLMetric,
     spillSize: LongSQLMetric)
@@ -171,7 +171,7 @@ class TungstenAggregationIterator(
   // hashMap. If there is not enough memory, it will multiple hash-maps, spilling
   // after each becomes full then using sort to merge these spills, finally do sort
   // based aggregation.
-  private def processInputs(fallbackStartsAt: Int): Unit = {
+  private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
@@ -187,7 +187,7 @@ class TungstenAggregationIterator(
         val newInput = inputIter.next()
         val groupingKey = groupingProjection.apply(newInput)
         var buffer: UnsafeRow = null
-        if (i < fallbackStartsAt) {
+        if (i < fallbackStartsAt._2) {
           buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
         }
         if (buffer == null) {
@@ -352,7 +352,7 @@ class TungstenAggregationIterator(
   /**
    * Start processing input rows.
    */
-  processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue))
+  processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue)))
 
   // If we did not switch to sort-based aggregation in processInputs,
   // we pre-load the first key-value pair from the map (to make hasNext idempotent).
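
Inside the iterator only the second element of the pair matters: once i reaches fallbackStartsAt._2 the hash map is no longer consulted, buffer stays null, and processing switches to sort-based aggregation; the new default (Int.MaxValue, Int.MaxValue) keeps production behavior at "never force a fallback". A condensed, self-contained sketch of that decision, with a plain mutable map standing in for the real hash map and sorter:

    object FallbackSketch {
      private val map = scala.collection.mutable.Map.empty[String, Array[Long]]

      // Returns the keys that bypassed the hash map and would go to the sorter.
      def processInputs(keys: Seq[String], fallbackStartsAt: (Int, Int)): Seq[String] =
        keys.zipWithIndex.flatMap { case (key, i) =>
          val buffer =
            if (i < fallbackStartsAt._2) map.getOrElseUpdate(key, Array(0L)) // find or insert
            else null                                   // forced fallback: map not consulted
          if (buffer == null) Some(key)                 // would be fed to the external sorter
          else { buffer(0) += 1; None }                 // normal hash-aggregation update
        }
    }

    // FallbackSketch.processInputs(Seq("a", "b", "a", "c"), (0, 2))
    //   == Seq("a", "c")   // rows 2 and 3 skip the hash map entirely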

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 26 additions & 21 deletions
@@ -967,27 +967,32 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite
 class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
 
   override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
-    (0 to 2).foreach { fallbackStartsAt =>
-      withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) {
-        // Create a new df to make sure its physical operator picks up
-        // spark.sql.TungstenAggregate.testFallbackStartsAt.
-        // todo: remove it?
-        val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan)
-
-        QueryTest.checkAnswer(newActual, expectedAnswer) match {
-          case Some(errorMessage) =>
-            val newErrorMessage =
-              s"""
-                |The following aggregation query failed when using TungstenAggregate with
-                |controlled fallback (it falls back to sort-based aggregation once it has processed
-                |$fallbackStartsAt input rows). The query is
-                |${actual.queryExecution}
-                |
-                |$errorMessage
-              """.stripMargin
-
-            fail(newErrorMessage)
-          case None =>
+    Seq(false, true).foreach { enableColumnarHashMap =>
+      withSQLConf("spark.sql.codegen.aggregate.map.enabled" -> enableColumnarHashMap.toString) {
+        (1 to 3).foreach { fallbackStartsAt =>
+          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+            // Create a new df to make sure its physical operator picks up
+            // spark.sql.TungstenAggregate.testFallbackStartsAt.
+            // todo: remove it?
+            val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan)
+
+            QueryTest.checkAnswer(newActual, expectedAnswer) match {
+              case Some(errorMessage) =>
+                val newErrorMessage =
+                  s"""
+                    |The following aggregation query failed when using TungstenAggregate with
+                    |controlled fallback (it falls back to bytes to bytes map once it has processed
+                    |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
+                    |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
+                    |
+                    |$errorMessage
+                  """.stripMargin
+
+                fail(newErrorMessage)
+              case None => // Success
+            }
+          }
         }
       }
     }
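
Each query in the suite therefore now runs under six configurations: the columnar hash map toggled off and on, crossed with three threshold pairs in which the generated map gives up one row before the bytes-to-bytes map. An illustrative expansion of that matrix:

    for {
      enableColumnarHashMap <- Seq(false, true)
      fallbackStartsAt <- 1 to 3
    } println(s"map.enabled=$enableColumnarHashMap, " +
      s"testFallbackStartsAt=${fallbackStartsAt - 1}, $fallbackStartsAt")
    // map.enabled=false, testFallbackStartsAt=0, 1
    // map.enabled=false, testFallbackStartsAt=1, 2
    // ... up to map.enabled=true, testFallbackStartsAt=2, 3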
