
Commit 4dc1007

pgandhi committed
[SPARK-27207]: Fix SortBasedAggregator to run with different aggregate functions and add a unit test
Fix SortBasedAggregator to ensure that update and merge are performed with two different sets of aggregate functions: one for update and one for merge, respectively.
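In short: when ObjectHashAggregateExec falls back to sort-based aggregation, update must run over buffers built from raw input rows, while merge must run over buffers already serialized by a partial aggregation. Sharing one set of aggregate functions (and one buffer) between the two phases lets merge see an unserialized object buffer. A minimal, self-contained sketch of that contract (MaxState and these helpers are hypothetical, not Spark's API):

```scala
// Sketch of the update/serialize/merge contract this commit enforces.
object SortFallbackSketch extends App {
  case class MaxState(var value: Int)

  def update(state: MaxState, input: Int): MaxState = {
    if (input > state.value) state.value = input   // works on the in-memory object
    state
  }

  def serialize(state: MaxState): Array[Byte] =
    BigInt(state.value).toByteArray                // the form merge's input must take

  def merge(state: MaxState, serialized: Array[Byte]): MaxState = {
    val other = BigInt(serialized).toInt           // merge expects serialized input
    if (other > state.value) state.value = other
    state
  }

  // Correct flow: update into one buffer, serialize it, merge into a second.
  val updateBuffer = MaxState(Int.MinValue)
  Seq(3, 7, 5).foreach(update(updateBuffer, _))
  val finalBuffer = merge(MaxState(Int.MinValue), serialize(updateBuffer))
  println(finalBuffer.value)                       // 7
}
```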
1 parent b4eaf31 commit 4dc1007

File tree

4 files changed (+73, -114 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala

Lines changed: 39 additions & 13 deletions
@@ -58,7 +58,7 @@ class ObjectAggregationIterator(
 
   private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
 
-  val (sortBasedAggExpressions, sortBasedAggFunctions): (
+  val (sortBasedMergeAggExpressions, sortBasedMergeAggFunctions): (
       Seq[AggregateExpression], Array[AggregateFunction]) = {
     val newExpressions = aggregateExpressions.map {
       case agg @ AggregateExpression(_, Partial, _, _) =>
@@ -72,8 +72,9 @@ class ObjectAggregationIterator(
 
   // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
   private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
-    val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes)
-    generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes)
+    val newInputAttributes = sortBasedMergeAggFunctions.flatMap(_.inputAggBufferAttributes)
+    generateProcessRow(
+      sortBasedMergeAggExpressions, sortBasedMergeAggFunctions, newInputAttributes)
   }
 
   /**
@@ -184,7 +185,9 @@ class ObjectAggregationIterator(
       StructType.fromAttributes(groupingAttributes),
       processRow,
       mergeAggregationBuffers,
-      createNewAggregationBuffer(sortBasedAggFunctions))
+      createNewAggregationBuffer(aggregateFunctions),
+      createNewAggregationBuffer(sortBasedMergeAggFunctions),
+      aggregateFunctions)
 
     while (inputRows.hasNext) {
       // NOTE: The input row is always UnsafeRow
@@ -212,7 +215,12 @@ class ObjectAggregationIterator(
  * @param processRow Function to update the aggregation buffer with input rows
  * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing
  *                                aggregation buffers
- * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
+ * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation
+ *                                                                 buffer for update operation
+ * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation
+ *                                                                buffer for merge operation
+ * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the
+ *                                    aggregation buffer
  *
 * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]].
 */
@@ -222,7 +230,9 @@ class SortBasedAggregator(
     groupingSchema: StructType,
     processRow: (InternalRow, InternalRow) => Unit,
     mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
-    makeEmptyAggregationBuffer: => InternalRow) {
+    makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow,
+    makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow,
+    sortBasedUpdateAggFunctions: Array[AggregateFunction]) {
 
   // external sorter to sort the input (grouping key + input row) with grouping key.
   private val inputSorter = createExternalSorterForInput()
@@ -231,6 +241,10 @@ class SortBasedAggregator(
   def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
     inputSorter.insertKV(groupingKey, inputRow)
   }
+
+  private def serializeBuffer(buffer: InternalRow): Unit = {
+    sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach(
+      _.serializeAggregateBufferInPlace(buffer))
+  }
 
   /**
    * Returns a destructive iterator of AggregationBufferEntry.
@@ -241,16 +255,18 @@ class SortBasedAggregator(
       val inputIterator = inputSorter.sortedIterator()
       var hasNextInput: Boolean = inputIterator.next()
       var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
-      private var result: AggregationBufferEntry = _
+      private var updateResult: AggregationBufferEntry = _
+      private var finalResult: AggregationBufferEntry = _
       private var groupingKey: UnsafeRow = _
 
       override def hasNext(): Boolean = {
-        result != null || findNextSortedGroup()
+        updateResult != null || finalResult != null || findNextSortedGroup()
       }
 
       override def next(): AggregationBufferEntry = {
-        val returnResult = result
-        result = null
+        val returnResult = finalResult
+        updateResult = null
+        finalResult = null
        returnResult
       }
 
@@ -259,21 +275,31 @@ class SortBasedAggregator(
        if (hasNextInput || hasNextAggBuffer) {
          // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator
          groupingKey = findGroupingKey()
-          result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer)
+          updateResult = new AggregationBufferEntry(
+            groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions)
+          finalResult = new AggregationBufferEntry(
+            groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions)
 
          // Firstly, update the aggregation buffer with input rows.
          while (hasNextInput &&
            groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
-            processRow(result.aggregationBuffer, inputIterator.getValue)
+            processRow(updateResult.aggregationBuffer, inputIterator.getValue)
            hasNextInput = inputIterator.next()
          }
 
+          // This step ensures that the contents of the updateResult aggregation buffer are
+          // merged with the finalResult aggregation buffer to maintain consistency
+          if (hasNextAggBuffer) {
+            serializeBuffer(updateResult.aggregationBuffer)
+            mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer)
+          }
          // Secondly, merge the aggregation buffer with existing aggregation buffers.
          // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should
          // be called after calling processRow.
          while (hasNextAggBuffer &&
            groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
-            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+            mergeAggregationBuffers(
+              finalResult.aggregationBuffer, initialAggBufferIterator.getValue)
            hasNextAggBuffer = initialAggBufferIterator.next()
          }
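Condensed, the reworked group loop does the following for each grouping key (a simplified sketch, not the literal Spark code; as the diff above shows, the real implementation runs the serialize-and-merge step conditionally, only when spilled partial buffers remain):

```scala
object TwoBufferFlow {
  // Sketch of the new per-group flow with simplified types. Update and merge
  // each get their own buffer, so merge never sees an unserialized update buffer.
  def nextGroup[Row, Buf](
      inputsForKey: Iterator[Row],
      spilledBuffersForKey: Iterator[Buf],
      newUpdateBuffer: => Buf,                 // update-shaped empty buffer
      newMergeBuffer: => Buf,                  // merge-shaped empty buffer
      processRow: (Buf, Row) => Unit,
      serializeInPlace: Buf => Unit,           // stands in for serializeBuffer above
      mergeBuffers: (Buf, Buf) => Unit): Buf = {
    val updateResult = newUpdateBuffer
    val finalResult = newMergeBuffer
    inputsForKey.foreach(processRow(updateResult, _))          // 1. update with raw rows
    serializeInPlace(updateResult)                             // 2. serialize the update buffer
    mergeBuffers(finalResult, updateResult)                    //    and fold it into finalResult
    spilledBuffersForKey.foreach(mergeBuffers(finalResult, _)) // 3. merge spilled partial buffers
    finalResult                                                // next() hands out finalResult only
  }
}
```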

sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala

Lines changed: 0 additions & 100 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
-import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -211,20 +210,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
-  test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") {
-    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") {
-      val df = data.toDF("value", "key").coalesce(2)
-      val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value"))
-      val expected = data.groupBy(_._2).toSeq.map { group =>
-        val (key, values) = group
-        val valueMax = values.map(_._1).max
-        val countValue = values.size
-        Row(key, valueMax, countValue, valueMax)
-      }
-      checkAnswer(query, expected)
-    }
-  }
-
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr, nullable = false)
     Column(max.toAggregateExpression())
@@ -235,10 +220,6 @@
     Column(max.toAggregateExpression())
   }
 
-  private def typedMax2(column: Column): Column = {
-    val max = TypedMax2(column.expr, nullable = false)
-    Column(max.toAggregateExpression())
-  }
 }
 
 object TypedImperativeAggregateSuite {
@@ -319,86 +300,5 @@ object TypedImperativeAggregateSuite {
     }
   }
 
-  /**
-   * Calculate the max value with object aggregation buffer. This stores class MaxValue
-   * in aggregation buffer.
-   */
-  private case class TypedMax2(
-      child: Expression,
-      nullable: Boolean = false,
-      mutableAggBufferOffset: Int = 0,
-      inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
-
-    var maxValueBuffer: MaxValue = null
-    override def createAggregationBuffer(): MaxValue = {
-      // Returns Int.MinValue if all inputs are null
-      maxValueBuffer = new MaxValue(Int.MinValue)
-      maxValueBuffer
-    }
-
-    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
-      child.eval(input) match {
-        case inputValue: Int =>
-          if (inputValue > buffer.value) {
-            buffer.value = inputValue
-            buffer.isValueSet = true
-          }
-        case null => // skip
-      }
-      buffer
-    }
-
-    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
-      // The below if condition will throw a Null Pointer Exception if initialize() is not called
-      if (maxValueBuffer.isValueSet) {
-        // do nothing
-      }
-      if (inputMax.value > bufferMax.value) {
-        bufferMax.value = inputMax.value
-        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
-      }
-      bufferMax
-    }
-
-    override def eval(bufferMax: MaxValue): Any = {
-      if (nullable && bufferMax.isValueSet == false) {
-        null
-      } else {
-        bufferMax.value
-      }
-    }
-
-    override lazy val deterministic: Boolean = true
-
-    override def children: Seq[Expression] = Seq(child)
-
-    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
-
-    override def dataType: DataType = IntegerType
-
-    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
-      copy(mutableAggBufferOffset = newOffset)
-
-    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
-      copy(inputAggBufferOffset = newOffset)
-
-    override def serialize(buffer: MaxValue): Array[Byte] = {
-      val out = new ByteArrayOutputStream()
-      val stream = new DataOutputStream(out)
-      stream.writeBoolean(buffer.isValueSet)
-      stream.writeInt(buffer.value)
-      out.toByteArray
-    }
-
-    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
-      val in = new ByteArrayInputStream(storageFormat)
-      val stream = new DataInputStream(in)
-      val isValueSet = stream.readBoolean()
-      val value = stream.readInt()
-      new MaxValue(value, isValueSet)
-    }
-  }
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala

Lines changed: 4 additions & 1 deletion
@@ -25,6 +25,7 @@ import org.apache.spark._
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.unsafe.KVIterator
 
@@ -78,7 +79,9 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte
       groupingSchema,
       updateInputRow,
       mergeAggBuffer,
-      createNewAggregationBuffer)
+      createNewAggregationBuffer,
+      createNewAggregationBuffer,
+      sortBasedUpdateAggFunctions = new Array[AggregateFunction](5))
 
     (5000 to 100000).foreach { _ =>
       randomKV(inputRow, group)
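One note on the `new Array[AggregateFunction](5)` placeholder above: the array is null-filled, and the new serializeBuffer method filters it with a type pattern, which never matches null, so serialization becomes a harmless no-op in this suite. A plain-Scala sketch of that behavior (no Spark required):

```scala
object NullCollectSketch extends App {
  // new Array[T](n) on a reference type yields n null slots.
  val slots = new Array[String](5)
  // A type pattern (case s: String) never matches null, so collect drops them all.
  val matched = slots.collect { case s: String => s }
  assert(matched.isEmpty)
  println(s"matched ${matched.length} of ${slots.length} slots")
}
```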

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 30 additions & 0 deletions
@@ -49,6 +49,16 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       (2: Integer) -> null,
       (3: Integer) -> null
     ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
+    Seq(
+      (0: Integer) -> "val_0",
+      (1: Integer) -> "val_1",
+      (2: Integer) -> "val_2",
+      (3: Integer) -> "val_3",
+      (4: Integer) -> "val_4",
+      (5: Integer) -> "val_5",
+      (6: Integer) -> null,
+      (7: Integer) -> null
+    ).toDF("key", "value").repartition(2).createOrReplaceTempView("t2")
   }
 
   protected override def afterAll(): Unit = {
@@ -111,6 +121,26 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("SPARK-27207: customized Hive UDAF with two aggregation buffers for Sort" +
+    " Based Aggregation") {
+    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "2") {
+      val df = sql("SELECT key % 2, mock2(value) FROM t2 GROUP BY key % 2")
+
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }
+
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)
+
+      checkAnswer(df, Seq(
+        Row(0, Row(3, 1)),
+        Row(1, Row(3, 1))
+      ))
+    }
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
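For manual reproduction, the same fallback can be triggered interactively; a sketch assuming a session where the suite's `mock2` Hive UDAF is registered (the conf key and threshold are taken verbatim from the test above):

```scala
// Hypothetical spark-shell session; assumes `spark` and the mock2 UDAF are available.
// With the threshold at 2, ObjectHashAggregateExec gives up on its hash map after
// two distinct grouping keys and falls back to the sort-based path exercised here.
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "2")
spark.sql("SELECT key % 2, mock2(value) FROM t2 GROUP BY key % 2").show()
```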
