From 520a17f0c21266a70c9ee957987fba7bbea42d4a Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Sat, 20 Aug 2016 00:34:56 +0800
Subject: [PATCH 1/2] object aggregation buffer

Revert: object aggregation buffer
---
 .../expressions/aggregate/interfaces.scala    |  86 ++++++++++
 .../sql/execution/aggregate/AggUtils.scala    |  11 +-
 .../SortBasedAggregationIterator.scala        |  33 +++-
 ...regateWithObjectAggregateBufferSuite.scala | 156 ++++++++++++++++++
 4 files changed, 282 insertions(+), 4 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 7a39e568fa28..ed91f9ea504f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -389,3 +389,89 @@ abstract class DeclarativeAggregate
     def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
   }
 }
+
+/**
+ * This traits allow user to define an AggregateFunction which can store **arbitrary** Java objects
+ * in Aggregation buffer during aggregation of each key group. This trait must be mixed with
+ * class ImperativeAggregate.
+ *
+ * Here is how it works in a typical aggregation flow (Partial mode aggregate at Mapper side, and
+ * Final mode aggregate at Reducer side).
+ *
+ * Stage 1: Partial aggregate at Mapper side:
+ *
+ * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores an arbitrary empty
+ * object, object A for example, in aggBuffer. The object A will be used to store the
+ * accumulated aggregation result.
+ * 1. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` in
+ * current group (group by key), user extracts object A from mutableAggBuffer, and then updates
+ * object A with current inputRow. After updating, object A is stored back to mutableAggBuffer.
+ * 1. After processing all rows of current group, the framework will call method
+ * `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to serialize object A
+ * to a serializable format in place.
+ * 1. The framework may spill the aggregationBuffer to disk if there is not enough memory.
+ * It is safe since we have already convert aggregationBuffer to serializable format.
+ * 1. Spark framework moves on to next group, until all groups have been
+ * processed.
+ *
+ * Shuffling exchange data to Reducer tasks...
+ *
+ * Stage 2: Final mode aggregate at Reducer side:
+ *
+ * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores a new empty object A1
+ * in aggBuffer. The object A1 will be used to store the accumulated aggregation result.
+ * 1. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, user
+ * extracts object A1 from mutableAggBuffer, and extracts object A2 from inputAggBuffer. then
+ * user needs to merge A1, and A2, and stores the merged result back to mutableAggBuffer.
+ * 1. After processing all inputAggBuffer of current group (group by key), the Spark framework will
+ * call method `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to
+ * serialize object A1 to a serializable format in place.
+ * 1. The Spark framework may spill the aggregationBuffer to disk if there is not enough memory.
+ * It is safe since we have already convert aggregationBuffer to serializable format.
+ * 1. Spark framework moves on to next group, until all groups have been processed.
+ */
+trait WithObjectAggregateBuffer {
+  this: ImperativeAggregate =>
+
+  /**
+   * Serializes and in-place replaces the object stored in Aggregation buffer. The framework
+   * calls this method every time after finishing updating/merging one group (group by key).
+   *
+   * aggregationBuffer before serialization:
+   *
+   * The object stored in aggregationBuffer can be **arbitrary** Java objects defined by user.
+   *
+   * aggregationBuffer after serialization:
+   *
+   * The object's type must be one of:
+   *
+   *  - Null
+   *  - Boolean
+   *  - Byte
+   *  - Short
+   *  - Int
+   *  - Long
+   *  - Float
+   *  - Double
+   *  - Array[Byte]
+   *  - org.apache.spark.sql.types.Decimal
+   *  - org.apache.spark.unsafe.types.UTF8String
+   *  - org.apache.spark.unsafe.types.CalendarInterval
+   *  - org.apache.spark.sql.catalyst.util.MapData
+   *  - org.apache.spark.sql.catalyst.util.ArrayData
+   *  - org.apache.spark.sql.catalyst.InternalRow
+   *
+   * Code example:
+   *
+   * {{{
+   *   override def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit = {
+   *     val obj = buffer.get(mutableAggBufferOffset, ObjectType(classOf[A])).asInstanceOf[A]
+   *     // Convert the obj to bytes, which is a serializable format.
+   *     buffer(mutableAggBufferOffset) = toBytes(obj)
+   *   }
+   * }}}
+   *
+   * @param aggregationBuffer aggregation buffer before serialization
+   */
+  def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d554c9b..8cb063ceb14e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -54,9 +54,16 @@ object AggUtils {
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
-    val useHash = HashAggregateExec.supportsAggregate(
+
+    val isUsingObjectAggregationBuffer: Boolean = aggregateExpressions.exists {
+      case AggregateExpression(agg: WithObjectAggregateBuffer, _, _, _) => true
+      case _ => false
+    }
+
+    val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
-    if (useHash) {
+
+    if (aggBufferAttributesSupportedByHashAggregate && !isUsingObjectAggregationBuffer) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
         groupingExpressions = groupingExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 3f7f84988594..c7c176f8fa7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, WithObjectAggregateBuffer}
 import org.apache.spark.sql.execution.metric.SQLMetric
 
 /**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
     val bufferRowSize: Int = bufferSchema.length
 
     val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-    val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+    val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+    val isUsingObjectAggregationBuffer = aggregateFunctions.exists {
+      case agg: WithObjectAggregateBuffer => true
+      case _ => false
+    }
+
+    val useUnsafeBuffer = allFieldsMutable && !isUsingObjectAggregationBuffer
 
     val buffer = if (useUnsafeBuffer) {
       val unsafeProjection =
@@ -90,6 +98,21 @@ class SortBasedAggregationIterator(
   // compared to MutableRow (aggregation buffer) directly.
   private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))
 
+  // AggregationFunction which stores generic object in AggregationBuffer.
+  // @see [[WithObjectAggregateBuffer]] for more information
+  private val aggFunctionsWithObjectAggregationBuffer = aggregateFunctions.collect {
+    case (ag: ImperativeAggregate with WithObjectAggregateBuffer) => ag
+  }
+
+  // For AggregateFunction with generic object stored in aggregation buffer, we need to
+  // call serializeObjectAggregationBufferInPlace() explicitly to convert the generic object
+  // stored in aggregation buffer to serializable format.
+  private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
+    aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
+      agg.serializeObjectAggregationBufferInPlace(sortBasedAggregationBuffer)
+    }
+  }
+
   protected def initialize(): Unit = {
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +154,10 @@ class SortBasedAggregationIterator(
         firstRowInNextGroup = currentRow.copy()
       }
     }
+
+    // Serializes the generic object stored in aggregation buffer.
+    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
+
     // We have not seen a new group. It means that there is no new row in the input
     // iter. The current group is the last group of the iter.
     if (!findNextPartition) {
@@ -162,6 +189,8 @@ class SortBasedAggregationIterator(
 
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
+    // Serializes the generic object stored in aggregation buffer.
+    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
     generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala
new file mode 100644
index 000000000000..8daaf41ecfe3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.AggregateWithObjectAggregateBufferSuite.MaxWithObjectAggregateBuffer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericMutableRow, MutableRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}
+import org.apache.spark.sql.execution.aggregate.{SortAggregateExec}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StructType}
+
+class AggregateWithObjectAggregateBufferSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))
+
+
+  test("aggregate with object aggregate buffer, should not use HashAggregate") {
+    val df = data.toDF("a", "b")
+    val max = new MaxWithObjectAggregateBuffer($"a".expr)
+
+    // Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate
+    // buffer attributes are mutable fields (every field can be mutated inline like int, long...)
+    val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+    val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
+    assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec])
+  }
+
+  test("aggregate with object aggregate buffer, no group by") {
+    val df = data.toDF("a", "b").coalesce(2)
+    checkAnswer(
+      df.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
+      Seq(Row(6, 7, 3, 7))
+    )
+  }
+
+  test("aggregate with object aggregate buffer, with group by") {
+    val df = data.toDF("a", "b").coalesce(2)
+    checkAnswer(
+      df.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
+      Seq(
+        Row(0, 5, 3, 5),
+        Row(1, 4, 3, 4),
+        Row(3, 6, 1, 6)
+      )
+    )
+  }
+
+  test("aggregate with object aggregate buffer, empty inputs, no group by") {
+    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      empty.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
+      Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
+  }
+
+  test("aggregate with object aggregate buffer, empty inputs, with group by") {
+    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      empty.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
+      Seq.empty[Row])
+  }
+
+  private def objectAggregateMax(column: Column): Column = {
+    val max = MaxWithObjectAggregateBuffer(column.expr)
+    Column(max.toAggregateExpression())
+  }
+}
+
+object AggregateWithObjectAggregateBufferSuite {
+
+  /**
+   * Calculate the max value with object aggregation buffer. This stores object of class MaxValue
+   * in aggregation buffer.
+   */
+  private case class MaxWithObjectAggregateBuffer(
+      child: Expression,
+      mutableAggBufferOffset: Int = 0,
+      inputAggBufferOffset: Int = 0) extends ImperativeAggregate with WithObjectAggregateBuffer {
+
+    override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
+      copy(mutableAggBufferOffset = newOffset)
+
+    override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
+      copy(inputAggBufferOffset = newOffset)
+
+    // Stores a generic object MaxValue in aggregation buffer.
+    override def initialize(buffer: MutableRow): Unit = {
+      // Makes sure the aggregation buffer is a GenericMutableRow instead of an UnsafeRow.
+      assert(buffer.isInstanceOf[GenericMutableRow])
+      buffer.update(mutableAggBufferOffset, new MaxValue(Int.MinValue))
+    }
+
+    override def update(buffer: MutableRow, input: InternalRow): Unit = {
+      val inputValue = child.eval(input).asInstanceOf[Int]
+      val maxValue = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
+      if (inputValue > maxValue.value) {
+        maxValue.value = inputValue
+      }
+    }
+
+    override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
+      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
+      val inputMax = deserialize(inputBuffer, inputAggBufferOffset)
+      if (inputMax.value > bufferMax.value) {
+        bufferMax.value = inputMax.value
+      }
+    }
+
+    private def deserialize(buffer: InternalRow, offset: Int): MaxValue = {
+      new MaxValue((buffer.getInt(offset)))
+    }
+
+    override def serializeObjectAggregationBufferInPlace(buffer: MutableRow): Unit = {
+      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
+      buffer(mutableAggBufferOffset) = bufferMax.value
+    }
+
+    override def eval(buffer: InternalRow): Any = {
+      val max = deserialize(buffer, mutableAggBufferOffset)
+      max.value
+    }
+
+    override val aggBufferAttributes: Seq[AttributeReference] =
+      Seq(AttributeReference("buf", IntegerType)())
+
+    override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+      aggBufferAttributes.map(_.newInstance())
+
+    override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+    override def dataType: DataType = IntegerType
+    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+    override def nullable: Boolean = true
+    override def deterministic: Boolean = false
+    override def children: Seq[Expression] = Seq(child)
+  }
+
+  private class MaxValue(var value: Int)
+}

From 9ae648c485819cba27bde3aad920970ccdc1ab4e Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Mon, 22 Aug 2016 14:13:42 +0800
Subject: [PATCH 2/2] fix review comments

---
 .../expressions/aggregate/interfaces.scala    | 143 +++++++++++-------
 .../SortBasedAggregationIterator.scala        |   3 +-
 ...regateWithObjectAggregateBufferSuite.scala |   4 +-
 3 files changed, 94 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ed91f9ea504f..d10a77ddc709 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -391,9 +391,23 @@ abstract class DeclarativeAggregate
 }
 
 /**
- * This traits allow user to define an AggregateFunction which can store **arbitrary** Java objects
- * in Aggregation buffer during aggregation of each key group. This trait must be mixed with
- * class ImperativeAggregate.
+ * This trait allows an AggregateFunction to store **arbitrary** Java objects in the internal
+ * aggregation buffer during aggregation of each key group. The **arbitrary** Java object can be
+ * used to store the accumulated aggregation result.
+ *
+ * This trait must be mixed with class ImperativeAggregate.
+ *
+ * {{{
+ *       aggregation buffer for function avg
+ *         |                   |
+ *         v                   v
+ *   +--------------+---------------+----------------------+
+ *   | sum1 (Long)  | count1 (Long) | generic java object  |
+ *   +--------------+---------------+----------------------+
+ *                                             ^
+ *                                             |
+ *   Aggregation buffer for aggregation-function with WithObjectAggregateBuffer
+ * }}}
  *
  * Here is how it works in a typical aggregation flow (Partial mode aggregate at Mapper side, and
  * Final mode aggregate at Reducer side).
@@ -401,77 +415,100 @@ abstract class DeclarativeAggregate
  * Stage 1: Partial aggregate at Mapper side:
  *
  * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores an arbitrary empty
- * object, object A for example, in aggBuffer. The object A will be used to store the
- * accumulated aggregation result.
- * 1. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` in
- * current group (group by key), user extracts object A from mutableAggBuffer, and then updates
- * object A with current inputRow. After updating, object A is stored back to mutableAggBuffer.
- * 1. After processing all rows of current group, the framework will call method
- * `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to serialize object A
- * to a serializable format in place.
- * 1. The framework may spill the aggregationBuffer to disk if there is not enough memory.
- * It is safe since we have already convert aggregationBuffer to serializable format.
- * 1. Spark framework moves on to next group, until all groups have been
- * processed.
+ * object, object A for example, in internal aggBuffer. The object A will be used to store the
+ * accumulated aggregation result.
+ * 2. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` in
+ * current group (group by key), user extracts object A from mutableAggBuffer, and then updates
+ * object A with current inputRow. After updating, object A is stored back to mutableAggBuffer.
+ * 3. After processing all rows of current group, the framework will call method
+ * {{{
+ *   serializeObjectAggregateBuffer(
+ *     objectAggregateBuffer: InternalRow,
+ *     targetBuffer: MutableRow)
+ * }}}
+ * to serialize object A stored in objectAggregateBuffer to a Spark SQL internally supported
+ * serializable format, and write the serialized format to the targetBuffer MutableRow.
+ * The framework may persist the targetBuffer to disk if there is not enough memory. This is safe
+ * as all fields of the targetBuffer MutableRow are serializable.
+ * 4. The framework moves on to next group, until all groups have been processed.
  *
  * Shuffling exchange data to Reducer tasks...
  *
  * Stage 2: Final mode aggregate at Reducer side:
  *
  * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, user stores a new empty object A1
- * in aggBuffer. The object A1 will be used to store the accumulated aggregation result.
- * 1. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, user
- * extracts object A1 from mutableAggBuffer, and extracts object A2 from inputAggBuffer. then
- * user needs to merge A1, and A2, and stores the merged result back to mutableAggBuffer.
- * 1. After processing all inputAggBuffer of current group (group by key), the Spark framework will
- * call method `serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow)` to
- * serialize object A1 to a serializable format in place.
- * 1. The Spark framework may spill the aggregationBuffer to disk if there is not enough memory.
- * It is safe since we have already convert aggregationBuffer to serializable format.
- * 1. Spark framework moves on to next group, until all groups have been processed.
+ * in internal aggBuffer. The object A1 will be used to store the accumulated aggregation result.
+ * 2. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, user
+ * extracts object A1 from mutableAggBuffer, and extracts object A2 from inputAggBuffer. Then
+ * the user needs to merge A1 and A2, and store the merged result back to mutableAggBuffer.
+ * 3. After processing all inputAggBuffer of current group (group by key), the framework will
+ * call method:
+ * {{{
+ *   serializeObjectAggregateBuffer(
+ *     objectAggregateBuffer: InternalRow,
+ *     targetBuffer: MutableRow)
+ * }}}
+ * to serialize object A1 stored in objectAggregateBuffer to a Spark SQL internally supported
+ * serializable format, and store the serialized format to the targetBuffer MutableRow. The
+ * framework may persist the targetBuffer to disk if there is not enough memory. It is safe as
+ * all fields of the targetBuffer MutableRow are serializable.
+ * 4. The framework moves on to next group, until all groups have been processed.
  */
 trait WithObjectAggregateBuffer {
   this: ImperativeAggregate =>
 
   /**
-   * Serializes and in-place replaces the object stored in Aggregation buffer. The framework
-   * calls this method every time after finishing updating/merging one group (group by key).
-   *
-   * aggregationBuffer before serialization:
-   *
-   * The object stored in aggregationBuffer can be **arbitrary** Java objects defined by user.
+   * Serializes the object stored at objectAggregateBuffer's index mutableAggBufferOffset to
+   * a Spark SQL internally supported serializable format, and writes the serialized format
+   * to targetBuffer's index mutableAggBufferOffset.
    *
-   * aggregationBuffer after serialization:
+   * The framework calls this method every time after finishing updating/merging one
+   * group (group by key).
    *
-   * The object's type must be one of:
+   * - Source object aggregation buffer before serialization:
+   *   The object stored in the object aggregation buffer can be of an **arbitrary** Java object
+   *   type defined by the user.
    *
+   * - Target mutable buffer after serialization:
+   *   The target mutable buffer is of type MutableRow. Each field's type needs to be one of the
+   *   Spark SQL internally supported serializable formats, which are:
-   *  - Null
-   *  - Boolean
-   *  - Byte
-   *  - Short
-   *  - Int
-   *  - Long
-   *  - Float
-   *  - Double
-   *  - Array[Byte]
-   *  - org.apache.spark.sql.types.Decimal
-   *  - org.apache.spark.unsafe.types.UTF8String
-   *  - org.apache.spark.unsafe.types.CalendarInterval
-   *  - org.apache.spark.sql.catalyst.util.MapData
-   *  - org.apache.spark.sql.catalyst.util.ArrayData
-   *  - org.apache.spark.sql.catalyst.InternalRow
+   *    - Null
+   *    - Boolean
+   *    - Byte
+   *    - Short
+   *    - Int
+   *    - Long
+   *    - Float
+   *    - Double
+   *    - Array[Byte]
+   *    - org.apache.spark.sql.types.Decimal
+   *    - org.apache.spark.unsafe.types.UTF8String
+   *    - org.apache.spark.unsafe.types.CalendarInterval
+   *    - org.apache.spark.sql.catalyst.util.MapData
+   *    - org.apache.spark.sql.catalyst.util.ArrayData
+   *    - org.apache.spark.sql.catalyst.InternalRow
    *
    * Code example:
   *
    * {{{
-   *   override def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit = {
+   *   def serializeObjectAggregateBuffer(
+   *       objectAggregateBuffer: InternalRow,
+   *       targetBuffer: MutableRow): Unit = {
    *     val obj = buffer.get(mutableAggBufferOffset, ObjectType(classOf[A])).asInstanceOf[A]
-   *     // Convert the obj to bytes, which is a serializable format.
-   *     buffer(mutableAggBufferOffset) = toBytes(obj)
+   *     // Convert the obj to a Spark SQL internally supported serializable format (here it is
+   *     // Array[Byte])
+   *     targetBuffer(mutableAggBufferOffset) = toBytes(obj)
    *   }
    * }}}
    *
-   * @param aggregationBuffer aggregation buffer before serialization
+   * @param objectAggregateBuffer Source object aggregation buffer. Please use the index
+   *                              mutableAggBufferOffset to get the buffered object of this
+   *                              aggregation function.
+   * @param targetBuffer Target buffer to hold the serialized format. Please use the index
+   *                     mutableAggBufferOffset to store the serialized format for this
+   *                     aggregation function.
    */
-  def serializeObjectAggregationBufferInPlace(aggregationBuffer: MutableRow): Unit
+  def serializeObjectAggregateBuffer(
+      objectAggregateBuffer: InternalRow,
+      targetBuffer: MutableRow): Unit
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index c7c176f8fa7e..b4bd18353b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -109,7 +109,8 @@ class SortBasedAggregationIterator(
   // stored in aggregation buffer to serializable format.
   private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
     aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
-      agg.serializeObjectAggregationBufferInPlace(sortBasedAggregationBuffer)
+      // Serializes and **in-place** replaces the object stored in sortBasedAggregationBuffer
+      agg.serializeObjectAggregateBuffer(sortBasedAggregationBuffer, sortBasedAggregationBuffer)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala
index 8daaf41ecfe3..6e57ce0ff1fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateWithObjectAggregateBufferSuite.scala
@@ -128,9 +128,9 @@ object AggregateWithObjectAggregateBufferSuite {
       new MaxValue((buffer.getInt(offset)))
     }
 
-    override def serializeObjectAggregationBufferInPlace(buffer: MutableRow): Unit = {
+    override def serializeObjectAggregateBuffer(buffer: InternalRow, target: MutableRow): Unit = {
       val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
-      buffer(mutableAggBufferOffset) = bufferMax.value
+      target(mutableAggBufferOffset) = bufferMax.value
     }
 
     override def eval(buffer: InternalRow): Any = {