@@ -58,7 +58,7 @@ class ObjectAggregationIterator(
5858
5959 private [this ] var aggBufferIterator : Iterator [AggregationBufferEntry ] = _
6060
61- val (sortBasedAggExpressions, sortBasedAggFunctions ): (
61+ val (sortBasedMergeAggExpressions, sortBasedMergeAggFunctions ): (
6262 Seq [AggregateExpression ], Array [AggregateFunction ]) = {
6363 val newExpressions = aggregateExpressions.map {
6464 case agg @ AggregateExpression (_, Partial , _, _) =>
@@ -72,8 +72,9 @@ class ObjectAggregationIterator(
7272
7373 // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
7474 private val mergeAggregationBuffers : (InternalRow , InternalRow ) => Unit = {
75- val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes)
76- generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes)
75+ val newInputAttributes = sortBasedMergeAggFunctions.flatMap(_.inputAggBufferAttributes)
76+ generateProcessRow(
77+ sortBasedMergeAggExpressions, sortBasedMergeAggFunctions, newInputAttributes)
7778 }
7879
7980 /**
@@ -184,7 +185,9 @@ class ObjectAggregationIterator(
184185 StructType .fromAttributes(groupingAttributes),
185186 processRow,
186187 mergeAggregationBuffers,
187- createNewAggregationBuffer(sortBasedAggFunctions))
188+ createNewAggregationBuffer(aggregateFunctions),
189+ createNewAggregationBuffer(sortBasedMergeAggFunctions),
190+ aggregateFunctions)
188191
189192 while (inputRows.hasNext) {
190193 // NOTE: The input row is always UnsafeRow
@@ -212,7 +215,12 @@ class ObjectAggregationIterator(
212215 * @param processRow Function to update the aggregation buffer with input rows
213216 * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing
214217 * aggregation buffers
215- * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
218+ * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation
219+ * buffer for update operation
220+ * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation
221+ * buffer for merge operation
222+ * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the
223+ * aggregation buffer
216224 *
217225 * @todo Try to eliminate this class by refactoring and reusing code paths in [[SortAggregateExec]].
218226 */
@@ -222,7 +230,9 @@ class SortBasedAggregator(
222230 groupingSchema : StructType ,
223231 processRow : (InternalRow , InternalRow ) => Unit ,
224232 mergeAggregationBuffers : (InternalRow , InternalRow ) => Unit ,
225- makeEmptyAggregationBuffer : => InternalRow ) {
233+ makeEmptyAggregationBufferForSortBasedUpdateAggFunctions : => InternalRow ,
234+ makeEmptyAggregationBufferForSortBasedMergeAggFunctions : => InternalRow ,
235+ sortBasedUpdateAggFunctions : Array [AggregateFunction ]) {
226236
227237 // external sorter to sort the input (grouping key + input row) with grouping key.
228238 private val inputSorter = createExternalSorterForInput()
@@ -231,6 +241,10 @@ class SortBasedAggregator(
231241 def addInput (groupingKey : UnsafeRow , inputRow : UnsafeRow ): Unit = {
232242 inputSorter.insertKV(groupingKey, inputRow)
233243 }
244+ private def serializeBuffer (buffer : InternalRow ): Unit = {
245+ sortBasedUpdateAggFunctions.collect { case f : TypedImperativeAggregate [_] => f }.foreach(
246+ _.serializeAggregateBufferInPlace(buffer))
247+ }
234248
235249 /**
236250 * Returns a destructive iterator of AggregationBufferEntry.
@@ -241,16 +255,18 @@ class SortBasedAggregator(
241255 val inputIterator = inputSorter.sortedIterator()
242256 var hasNextInput : Boolean = inputIterator.next()
243257 var hasNextAggBuffer : Boolean = initialAggBufferIterator.next()
244- private var result : AggregationBufferEntry = _
258+ private var updateResult : AggregationBufferEntry = _
259+ private var finalResult : AggregationBufferEntry = _
245260 private var groupingKey : UnsafeRow = _
246261
247262 override def hasNext (): Boolean = {
248- result != null || findNextSortedGroup()
263+ updateResult != null || finalResult != null || findNextSortedGroup()
249264 }
250265
251266 override def next (): AggregationBufferEntry = {
252- val returnResult = result
253- result = null
267+ val returnResult = finalResult
268+ updateResult = null
269+ finalResult = null
254270 returnResult
255271 }
256272
@@ -259,21 +275,31 @@ class SortBasedAggregator(
259275 if (hasNextInput || hasNextAggBuffer) {
260276 // Find the smaller key of the inputIterator and initialAggBufferIterator
261277 groupingKey = findGroupingKey()
262- result = new AggregationBufferEntry (groupingKey, makeEmptyAggregationBuffer)
278+ updateResult = new AggregationBufferEntry (
279+ groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions)
280+ finalResult = new AggregationBufferEntry (
281+ groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions)
263282
264283 // Firstly, update the aggregation buffer with input rows.
265284 while (hasNextInput &&
266285 groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0 ) {
267- processRow(result .aggregationBuffer, inputIterator.getValue)
286+ processRow(updateResult .aggregationBuffer, inputIterator.getValue)
268287 hasNextInput = inputIterator.next()
269288 }
270289
290+ // If aggregation buffers remain to be merged, serialize the updateResult buffer in place
291+ // and merge it into the finalResult buffer before the existing buffers are merged in.
292+ if (hasNextAggBuffer) {
293+ serializeBuffer(updateResult.aggregationBuffer)
294+ mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer)
295+ }
271296 // Secondly, merge the aggregation buffer with existing aggregation buffers.
272297 // NOTE: the ordering of these two while-blocks matters; mergeAggregationBuffers() should
273298 // be called after calling processRow.
274299 while (hasNextAggBuffer &&
275300 groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0 ) {
276- mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
301+ mergeAggregationBuffers(
302+ finalResult.aggregationBuffer, initialAggBufferIterator.getValue)
277303 hasNextAggBuffer = initialAggBufferIterator.next()
278304 }
279305
0 commit comments