From 10861b207e8cac0b7348b374d9054c4de03b7965 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Sat, 20 Aug 2016 00:34:56 +0800 Subject: [PATCH 01/12] object aggregation buffer --- .../expressions/aggregate/interfaces.scala | 145 ++++++++++++++ .../sql/execution/aggregate/AggUtils.scala | 11 +- .../SortBasedAggregationIterator.scala | 38 +++- .../sql/TypedImperativeAggregateSuite.scala | 178 ++++++++++++++++++ 4 files changed, 368 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 7a39e568fa289..ff55b95c2ad5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -389,3 +389,148 @@ abstract class DeclarativeAggregate def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } } + +/** + * Aggregation function which allows **arbitrary** user-defined java object to be used as internal + * aggregation buffer object. + * + * {{{ + * aggregation buffer for normal aggregation function `avg` + * | + * v + * +--------------+---------------+-----------------------------------+ + * | sum1 (Long) | count1 (Long) | generic user-defined java objects | + * +--------------+---------------+-----------------------------------+ + * ^ + * | + * Aggregation buffer object for `TypedImperativeAggregate` aggregation function + * }}} + * + * Work flow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side): + * + * Stage 1: Partial aggregate at Mapper side: + * + * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation + * buffer object. + * 2. Upon each input row, the framework calls + * `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T. + * 3. After processing all rows of current group (group by key), the framework will serialize + * aggregation buffer object T to SparkSQL internally supported underlying storage format, and + * persist the serializable format to disk if needed. + * 4. The framework moves on to next group, until all groups have been processed. + * + * Shuffling exchange data to Reducer tasks... + * + * Stage 2: Final mode aggregate at Reducer side: + * + * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation + * buffer object (type T) for merging. + * 2. For each aggregation output of Stage 1, The framework de-serializes the storage + * format and generates one input aggregation object (type T). + * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit` + * to merge the input aggregation object into aggregation buffer object. + * 4. After processing all input aggregation objects of current group (group by key), the framework + * calls method `eval(buffer: T)` to generate the final output for this group. + * 5. The framework moves on to next group, until all groups have been processed. + */ +abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { + + /** + * Spark Sql type of user-defined aggregation buffer object. 
It needs to be an `UserDefinedType` + * so that the framework knows how to serialize the aggregation buffer object to Spark sql + * internally supported storage format. + */ + def aggregationBufferType: UserDefinedType[T] + + /** + * Creates an empty aggregation buffer object. This is called before processing each key group + * (group by key). + * + * @return an aggregation buffer object + */ + def createAggregationBuffer(): T + + /** + * In-place updates the aggregation buffer object with an input row. buffer = buffer + input. + * This is typically called when doing Partial or Complete mode aggregation. + * + * @param buffer The aggregation buffer object. + * @param input an input row + */ + def update(buffer: T, input: InternalRow): Unit + + /** + * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input. + * This is typically called when doing PartialMerge or Final mode aggregation. + * + * @param buffer the aggregation buffer object used to store the aggregation result. + * @param input an input aggregation object. Input aggregation object can be produced by + * de-serializing the partial aggregate's output from Mapper side. + */ + def merge(buffer: T, input: T): Unit + + /** + * Generates the final aggregation result value for current key group with the aggregation buffer + * object. + * + * @param buffer aggregation buffer object. + * @return The aggregation result of current key group + */ + def eval(buffer: T): Any + + final override def initialize(buffer: MutableRow): Unit = { + val bufferObject = createAggregationBuffer() + buffer.update(mutableAggBufferOffset, bufferObject) + } + + final override def update(buffer: MutableRow, input: InternalRow): Unit = { + val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] + update(bufferObject, input) + } + + final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] + val inputObject = deserialize(field(inputBuffer, inputAggBufferOffset)) + merge(bufferObject, inputObject) + } + + final override def eval(buffer: InternalRow): Any = { + val bufferObject = field(buffer, mutableAggBufferOffset) + if (bufferObject.getClass == aggregationBufferType.userClass) { + // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly + // on the object aggregation buffer without intermediate serializing/de-serializing. + eval(bufferObject.asInstanceOf[T]) + } else { + eval(deserialize(bufferObject)) + } + } + + private def deserialize(input: AnyRef): T = { + aggregationBufferType.deserialize(input) + } + + private def field(input: InternalRow, offset: Int): AnyRef = { + input.get(offset, null) + } + + final override val aggBufferAttributes: Seq[AttributeReference] = { + // Underlying storage type for the aggregation buffer object + Seq(AttributeReference("buf", aggregationBufferType.sqlType)()) + } + + final override lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** + * In-place replaces the aggregation buffer object stored at buffer's index + * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format. + * + * The framework calls this method every time after updating/merging one group (group by key). 
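   *
   * For illustration only (an editor's sketch, not part of this patch), the expected calling
   * sequence on the mapper side for one group looks roughly like the following, where `agg` is
   * any concrete TypedImperativeAggregate whose `mutableAggBufferOffset` is 0 and `rows` are the
   * group's input rows:
   *
   * {{{
   *   val bufferRow = new GenericMutableRow(1)         // one slot holds the buffer object
   *   agg.initialize(bufferRow)                        // stores createAggregationBuffer() in the slot
   *   rows.foreach(row => agg.update(bufferRow, row))  // mutates the buffer object in place
   *   agg.serializeAggregateBufferInPlace(bufferRow)   // swaps in the serializable storage format
   * }}}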
+ */ + final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { + val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] + buffer(mutableAggBufferOffset) = aggregationBufferType.serialize(bufferObject) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 4fbb9d554c9bf..def48528da891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -54,9 +54,16 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { - val useHash = HashAggregateExec.supportsAggregate( + + val hasTypedImperativeAggregate: Boolean = aggregateExpressions.exists { + case AggregateExpression(agg: TypedImperativeAggregate[_], _, _, _) => true + case _ => false + } + + val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - if (useHash) { + + if (aggBufferAttributesSupportedByHashAggregate && !hasTypedImperativeAggregate) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 3f7f84988594a..7448b1dbf9a51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, TypedImperativeAggregate} import org.apache.spark.sql.execution.metric.SQLMetric /** @@ -54,7 +54,15 @@ class SortBasedAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + + val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + + val hasTypedImperativeAggregate = aggregateFunctions.exists { + case agg: TypedImperativeAggregate[_] => true + case _ => false + } + + val useUnsafeBuffer = allFieldsMutable && !hasTypedImperativeAggregate val buffer = if (useUnsafeBuffer) { val unsafeProjection = @@ -90,6 +98,24 @@ class SortBasedAggregationIterator( // compared to MutableRow (aggregation buffer) directly. private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + // Aggregation function which uses generic aggregation buffer object. 
+ // @see [[TypedImperativeAggregate]] for more information + private val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = { + aggregateFunctions.collect { + case (ag: TypedImperativeAggregate[_]) => ag + } + } + + // For TypedImperativeAggregate with generic aggregation buffer object, we need to call + // serializeAggregateBufferInPlace(...) explicitly to convert the aggregation buffer object + // to Spark Sql internally supported serializable storage format. + private def serializeTypedAggregateBuffer(aggregationBuffer: MutableRow): Unit = { + typedImperativeAggregates.foreach { agg => + // In-place serialization + agg.serializeAggregateBufferInPlace(sortBasedAggregationBuffer) + } + } + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -131,6 +157,11 @@ class SortBasedAggregationIterator( firstRowInNextGroup = currentRow.copy() } } + + // Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate + // aggregation functions. + serializeTypedAggregateBuffer(sortBasedAggregationBuffer) + // We have not seen a new group. It means that there is no new row in the input // iter. The current group is the last group of the iter. if (!findNextPartition) { @@ -162,6 +193,9 @@ class SortBasedAggregationIterator( def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) + // Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate + // aggregation functions. + serializeTypedAggregateBuffer(sortBasedAggregationBuffer) generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala new file mode 100644 index 0000000000000..9b8740c10a351 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, UserDefinedType} + +class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0)) + + test("aggregate with object aggregate buffer") { + val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false)) + + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => + val input = InternalRow(data(index)._1, data(index)._2) + agg.update(group1Buffer, input) + } + + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)._1, data(index)._2) + agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + + assert(mergeBuffer.value == data.map(_._1).max) + assert(agg.eval(mergeBuffer) == data.map(_._1).max) + } + + test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") { + val df = data.toDF("a", "b") + val max = new TypedMax($"a".expr) + + // Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate + // buffer attributes are mutable fields (every field can be mutated inline like int, long...) + val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan + assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec]) + } + + test("dataframe aggregate with object aggregate buffer, no group by") { + val df = data.toDF("a", "b").coalesce(2) + checkAnswer( + df.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")), + Seq(Row(6, 7, 3, 7)) + ) + } + + test("dataframe aggregate with object aggregate buffer, with group by") { + val df = data.toDF("a", "b").coalesce(2) + checkAnswer( + df.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")), + Seq( + Row(0, 5, 3, 5), + Row(1, 4, 3, 4), + Row(3, 6, 1, 6) + ) + ) + } + + test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") { + val empty = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + empty.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")), + Seq(Row(Int.MinValue, 0, Int.MinValue, 0))) + } + + test("dataframe aggregate with object aggregate buffer, empty inputs, with group by") { + val empty = Seq.empty[(Int, Int)].toDF("a", "b") + checkAnswer( + empty.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")), + Seq.empty[Row]) + } + + private def typedMax(column: Column): Column = { + val max = TypedMax(column.expr) + Column(max.toAggregateExpression()) + } +} + +object TypedImperativeAggregateSuite { + + /** + * Calculate the max value with object aggregation buffer. 
This stores class MaxValue + * in aggregation buffer. + */ + private case class TypedMax( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] { + + override lazy val aggregationBufferType: UserDefinedType[MaxValue] = new MaxValueUDT() + + override def createAggregationBuffer(): MaxValue = { + new MaxValue(Int.MinValue) + } + + override def update(buffer: MaxValue, input: InternalRow): Unit = { + val inputValue = child.eval(input).asInstanceOf[Int] + if (inputValue > buffer.value) { + buffer.value = inputValue + } + } + + override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = { + if (inputMax.value > bufferMax.value) { + bufferMax.value = inputMax.value + } + } + + override def eval(bufferMax: MaxValue): Any = { + bufferMax.value + } + + override def nullable: Boolean = true + + override def deterministic: Boolean = false + + override def children: Seq[Expression] = Seq(child) + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) + + override def dataType: DataType = IntegerType + + override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newOffset) + + } + + private class MaxValue(var value: Int) + + private class MaxValueUDT extends UserDefinedType[MaxValue] { + override def sqlType: DataType = IntegerType + + override def serialize(obj: MaxValue): Any = obj.value + + override def userClass: Class[MaxValue] = classOf[MaxValue] + + override def deserialize(datum: Any): MaxValue = { + datum match { + case i: Int => new MaxValue(i) + } + } + } +} From 0fdc1eadf46c6db96cf479fd317e8e0b89e65b05 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 23 Aug 2016 09:29:28 +0800 Subject: [PATCH 02/12] fix comments --- .../expressions/aggregate/interfaces.scala | 58 ++++++++++++++----- .../sql/TypedImperativeAggregateSuite.scala | 26 ++++----- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ff55b95c2ad5b..aea2f381c782a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -433,14 +433,7 @@ abstract class DeclarativeAggregate * calls method `eval(buffer: T)` to generate the final output for this group. * 5. The framework moves on to next group, until all groups have been processed. */ -abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { - - /** - * Spark Sql type of user-defined aggregation buffer object. It needs to be an `UserDefinedType` - * so that the framework knows how to serialize the aggregation buffer object to Spark sql - * internally supported storage format. - */ - def aggregationBufferType: UserDefinedType[T] +abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { /** * Creates an empty aggregation buffer object. 
This is called before processing each key group @@ -478,6 +471,43 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { */ def eval(buffer: T): Any + /** Returns the class of aggregation buffer object */ + def aggregationBufferClass: Class[T] + + /** Serializes the aggregation buffer object T to Spark-sql internally supported storage format */ + def serialize(buffer: T): Any + + /** De-serializes the storage format, and produces aggregation buffer object T */ + def deserialize(storageFormat: Any): T + + /** + * Returns the aggregation-buffer-object storage format's Sql type. + * + * Here is a list of supported storage format and corresponding Sql type: + * + * {{{ + * aggregation buffer object's Storage format | storage format's Sql type + * ------------------------------------------------------------------------------------------ + * Array[Byte] (*) | BinaryType (*) + * Null | NullType + * Boolean | BooleanType + * Byte | ByteType + * Short | ShortType + * Int | IntegerType + * Long | LongType + * Float | FloatType + * Double | DoubleType + * org.apache.spark.sql.types.Decimal | DecimalType + * org.apache.spark.unsafe.types.UTF8String | StringType + * org.apache.spark.unsafe.types.CalendarInterval| CalendarIntervalType + * org.apache.spark.sql.catalyst.util.MapData | MapType + * org.apache.spark.sql.catalyst.util.ArrayData | ArrayType + * org.apache.spark.sql.catalyst.InternalRow | + * }}} + * + */ + def aggregationBufferStorageFormatSqlType: DataType + final override def initialize(buffer: MutableRow): Unit = { val bufferObject = createAggregationBuffer() buffer.update(mutableAggBufferOffset, bufferObject) @@ -496,7 +526,7 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { final override def eval(buffer: InternalRow): Any = { val bufferObject = field(buffer, mutableAggBufferOffset) - if (bufferObject.getClass == aggregationBufferType.userClass) { + if (bufferObject.getClass == aggregationBufferClass) { // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly // on the object aggregation buffer without intermediate serializing/de-serializing. 
eval(bufferObject.asInstanceOf[T]) @@ -505,17 +535,13 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { } } - private def deserialize(input: AnyRef): T = { - aggregationBufferType.deserialize(input) - } - private def field(input: InternalRow, offset: Int): AnyRef = { input.get(offset, null) } - final override val aggBufferAttributes: Seq[AttributeReference] = { + final override lazy val aggBufferAttributes: Seq[AttributeReference] = { // Underlying storage type for the aggregation buffer object - Seq(AttributeReference("buf", aggregationBufferType.sqlType)()) + Seq(AttributeReference("buf", aggregationBufferStorageFormatSqlType)()) } final override lazy val inputAggBufferAttributes: Seq[AttributeReference] = @@ -531,6 +557,6 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate { */ final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] - buffer(mutableAggBufferOffset) = aggregationBufferType.serialize(bufferObject) + buffer(mutableAggBufferOffset) = serialize(bufferObject) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 9b8740c10a351..eaf33c737299b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{TypedImperativeAggregate} import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, UserDefinedType} +import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType} class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { @@ -119,7 +119,6 @@ object TypedImperativeAggregateSuite { mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] { - override lazy val aggregationBufferType: UserDefinedType[MaxValue] = new MaxValueUDT() override def createAggregationBuffer(): MaxValue = { new MaxValue(Int.MinValue) @@ -152,27 +151,24 @@ object TypedImperativeAggregateSuite { override def dataType: DataType = IntegerType - override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = copy(mutableAggBufferOffset = newOffset) - override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = copy(inputAggBufferOffset = newOffset) - } - - private class MaxValue(var value: Int) + override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue] - private class MaxValueUDT extends UserDefinedType[MaxValue] { - override def sqlType: DataType = IntegerType + override def serialize(buffer: MaxValue): 
Any = buffer.value - override def serialize(obj: MaxValue): Any = obj.value + override def aggregationBufferStorageFormatSqlType: DataType = IntegerType - override def userClass: Class[MaxValue] = classOf[MaxValue] - - override def deserialize(datum: Any): MaxValue = { - datum match { + override def deserialize(storageFormat: Any): MaxValue = { + storageFormat match { case i: Int => new MaxValue(i) } } } + + private class MaxValue(var value: Int) } From d3108ab7ea1e10b8de31f1fd6546cc3275d6e48a Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 23 Aug 2016 09:42:22 +0800 Subject: [PATCH 03/12] fix review comments --- .../expressions/aggregate/interfaces.scala | 56 +++------ .../sql/TypedImperativeAggregateSuite.scala | 119 +++++++++++++----- 2 files changed, 104 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index aea2f381c782a..fc563933059b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -415,8 +415,8 @@ abstract class DeclarativeAggregate * 2. Upon each input row, the framework calls * `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T. * 3. After processing all rows of current group (group by key), the framework will serialize - * aggregation buffer object T to SparkSQL internally supported underlying storage format, and - * persist the serializable format to disk if needed. + * aggregation buffer object T to storage format (Array[Byte]) and persist the Array[Byte] + * to disk if needed. * 4. The framework moves on to next group, until all groups have been processed. * * Shuffling exchange data to Reducer tasks... @@ -426,7 +426,7 @@ abstract class DeclarativeAggregate * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation * buffer object (type T) for merging. * 2. For each aggregation output of Stage 1, The framework de-serializes the storage - * format and generates one input aggregation object (type T). + * format (Array[Byte]) and produces one input aggregation object (type T). * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit` * to merge the input aggregation object into aggregation buffer object. * 4. After processing all input aggregation objects of current group (group by key), the framework @@ -474,39 +474,11 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { /** Returns the class of aggregation buffer object */ def aggregationBufferClass: Class[T] - /** Serializes the aggregation buffer object T to Spark-sql internally supported storage format */ - def serialize(buffer: T): Any + /** Serializes the aggregation buffer object T to Array[Byte] */ + def serialize(buffer: T): Array[Byte] - /** De-serializes the storage format, and produces aggregation buffer object T */ - def deserialize(storageFormat: Any): T - - /** - * Returns the aggregation-buffer-object storage format's Sql type. 
- * - * Here is a list of supported storage format and corresponding Sql type: - * - * {{{ - * aggregation buffer object's Storage format | storage format's Sql type - * ------------------------------------------------------------------------------------------ - * Array[Byte] (*) | BinaryType (*) - * Null | NullType - * Boolean | BooleanType - * Byte | ByteType - * Short | ShortType - * Int | IntegerType - * Long | LongType - * Float | FloatType - * Double | DoubleType - * org.apache.spark.sql.types.Decimal | DecimalType - * org.apache.spark.unsafe.types.UTF8String | StringType - * org.apache.spark.unsafe.types.CalendarInterval| CalendarIntervalType - * org.apache.spark.sql.catalyst.util.MapData | MapType - * org.apache.spark.sql.catalyst.util.ArrayData | ArrayType - * org.apache.spark.sql.catalyst.InternalRow | - * }}} - * - */ - def aggregationBufferStorageFormatSqlType: DataType + /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */ + def deserialize(storageFormat: Array[Byte]): T final override def initialize(buffer: MutableRow): Unit = { val bufferObject = createAggregationBuffer() @@ -519,29 +491,29 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { } final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { - val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] - val inputObject = deserialize(field(inputBuffer, inputAggBufferOffset)) + val bufferObject = field[T](buffer, mutableAggBufferOffset) + val inputObject = deserialize(field[Array[Byte]](inputBuffer, inputAggBufferOffset)) merge(bufferObject, inputObject) } final override def eval(buffer: InternalRow): Any = { - val bufferObject = field(buffer, mutableAggBufferOffset) + val bufferObject = field[AnyRef](buffer, mutableAggBufferOffset) if (bufferObject.getClass == aggregationBufferClass) { // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly // on the object aggregation buffer without intermediate serializing/de-serializing. 
eval(bufferObject.asInstanceOf[T]) } else { - eval(deserialize(bufferObject)) + eval(deserialize(bufferObject.asInstanceOf[Array[Byte]])) } } - private def field(input: InternalRow, offset: Int): AnyRef = { - input.get(offset, null) + private def field[U](input: InternalRow, fieldIndex: Int): U = { + input.get(fieldIndex, null).asInstanceOf[U] } final override lazy val aggBufferAttributes: Seq[AttributeReference] = { // Underlying storage type for the aggregation buffer object - Seq(AttributeReference("buf", aggregationBufferStorageFormatSqlType)()) + Seq(AttributeReference("buf", BinaryType)()) } final override lazy val inputAggBufferAttributes: Seq[AttributeReference] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index eaf33c737299b..1ab14019fd33b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -17,21 +17,27 @@ package org.apache.spark.sql +import com.google.common.primitives.Ints + import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.aggregate.{TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType} class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ - private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0)) + private val random = new java.util.Random() + private val data = (0 until 1000).map { _ => + (random.nextInt(10), random.nextInt(100)) + } test("aggregate with object aggregate buffer") { val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false)) @@ -55,37 +61,66 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { assert(mergeBuffer.value == data.map(_._1).max) assert(agg.eval(mergeBuffer) == data.map(_._1).max) + + // Tests low level eval(row: InternalRow) API. + val array: Array[Any] = Array(mergeBuffer) + val row = new GenericMutableRow(array) + + // Evaluates directly on row consist of aggregation buffer object. + assert(agg.eval(row) == data.map(_._1).max) + + // Serializes the aggregation buffer object and then evals. 
+ agg.serializeAggregateBufferInPlace(row) + assert(agg.eval(row) == data.map(_._1).max) + } + + test("supports SpecificMutableRow as mutable row") { + val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType) + val aggBufferOffset = 2 + val inputBufferObject = 1 + val buffer = new SpecificMutableRow(aggregationBufferSchema) + val agg = new TypedMax(BoundReference(inputBufferObject, IntegerType, nullable = false)) + .withNewMutableAggBufferOffset(aggBufferOffset) + .withNewInputAggBufferOffset(inputBufferObject) + + agg.initialize(buffer) + data.foreach { kv => + val input = InternalRow(kv._1, kv._2) + agg.update(buffer, input) + } + assert(agg.eval(buffer) == data.map(_._2).max) } test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") { val df = data.toDF("a", "b") val max = new TypedMax($"a".expr) - // Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate - // buffer attributes are mutable fields (every field can be mutated inline like int, long...) - val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + // Always uses SortAggregateExec val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan - assert(allFieldsMutable == true && sparkPlan.isInstanceOf[SortAggregateExec]) + assert(sparkPlan.isInstanceOf[SortAggregateExec]) } test("dataframe aggregate with object aggregate buffer, no group by") { - val df = data.toDF("a", "b").coalesce(2) - checkAnswer( - df.select(typedMax($"a"), count($"a"), typedMax($"b"), count($"b")), - Seq(Row(6, 7, 3, 7)) - ) + val df = data.toDF("key", "value").coalesce(2) + val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"), count($"value")) + val maxKey = data.map(_._1).max + val countKey = data.size + val maxValue = data.map(_._2).max + val countValue = data.size + val expected = Seq(Row(maxKey, countKey, maxValue, countValue)) + checkAnswer(query, expected) } test("dataframe aggregate with object aggregate buffer, with group by") { - val df = data.toDF("a", "b").coalesce(2) - checkAnswer( - df.groupBy($"b").agg(typedMax($"a"), count($"a"), typedMax($"a")), - Seq( - Row(0, 5, 3, 5), - Row(1, 4, 3, 4), - Row(3, 6, 1, 6) - ) - ) + val df = data.toDF("value", "key").coalesce(2) + val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value")) + val expected = data.groupBy(_._2).toSeq.map { group => + val (key, values) = group + val valueMax = values.map(_._1).max + val countValue = values.size + Row(key, valueMax, countValue, valueMax) + } + checkAnswer(query, expected) } test("dataframe aggregate with object aggregate buffer, empty inputs, no group by") { @@ -102,6 +137,36 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { Seq.empty[Row]) } + test("TypedImperativeAggregate should not break Window function") { + val df = data.toDF("key", "value") + // OVER (PARTITION BY a ORDER BY b ROW BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + val w = Window.orderBy("value").partitionBy("key").rowsBetween(Long.MinValue, 0) + + val query = df.select(sum($"key").over(w), typedMax($"key").over(w), sum($"value").over(w), + typedMax($"value").over(w)) + + val expected = data.groupBy(_._1).toSeq.flatMap { group => + val (key, values) = group + val sortedValues = values.map(_._2).sorted + + var outputRows = Seq.empty[Row] + var i = 0 + while (i < sortedValues.size) { + val unboundedPrecedingAndCurrent = sortedValues.slice(0, i + 1) + val sumKey 
= key * unboundedPrecedingAndCurrent.size + val maxKey = key + val sumValue = unboundedPrecedingAndCurrent.sum + val maxValue = unboundedPrecedingAndCurrent.max + + outputRows :+= Row(sumKey, maxKey, sumValue, maxValue) + i += 1 + } + + outputRows + } + checkAnswer(query, expected) + } + private def typedMax(column: Column): Column = { val max = TypedMax(column.expr) Column(max.toAggregateExpression()) @@ -159,14 +224,10 @@ object TypedImperativeAggregateSuite { override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue] - override def serialize(buffer: MaxValue): Any = buffer.value + override def serialize(buffer: MaxValue): Array[Byte] = Ints.toByteArray(buffer.value) - override def aggregationBufferStorageFormatSqlType: DataType = IntegerType - - override def deserialize(storageFormat: Any): MaxValue = { - storageFormat match { - case i: Int => new MaxValue(i) - } + override def deserialize(storageFormat: Array[Byte]): MaxValue = { + new MaxValue(Ints.fromByteArray(storageFormat)) } } From 2873765dcc3cb2d57935a68f77f8e6e2585929c9 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 24 Aug 2016 07:23:39 +0800 Subject: [PATCH 04/12] fix review comments --- .../catalyst/expressions/aggregate/interfaces.scala | 12 ++++++------ .../spark/sql/TypedImperativeAggregateSuite.scala | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index fc563933059b7..e9c4e5d13fa66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -486,18 +486,18 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { } final override def update(buffer: MutableRow, input: InternalRow): Unit = { - val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] + val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T] update(bufferObject, input) } final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { - val bufferObject = field[T](buffer, mutableAggBufferOffset) - val inputObject = deserialize(field[Array[Byte]](inputBuffer, inputAggBufferOffset)) + val bufferObject = getField[T](buffer, mutableAggBufferOffset) + val inputObject = deserialize(getField[Array[Byte]](inputBuffer, inputAggBufferOffset)) merge(bufferObject, inputObject) } final override def eval(buffer: InternalRow): Any = { - val bufferObject = field[AnyRef](buffer, mutableAggBufferOffset) + val bufferObject = getField[AnyRef](buffer, mutableAggBufferOffset) if (bufferObject.getClass == aggregationBufferClass) { // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly // on the object aggregation buffer without intermediate serializing/de-serializing. @@ -507,7 +507,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { } } - private def field[U](input: InternalRow, fieldIndex: Int): U = { + private def getField[U](input: InternalRow, fieldIndex: Int): U = { input.get(fieldIndex, null).asInstanceOf[U] } @@ -528,7 +528,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { * The framework calls this method every time after updating/merging one group (group by key). 
*/ final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { - val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T] + val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T] buffer(mutableAggBufferOffset) = serialize(bufferObject) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 1ab14019fd33b..252449dc15ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -63,8 +63,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { assert(agg.eval(mergeBuffer) == data.map(_._1).max) // Tests low level eval(row: InternalRow) API. - val array: Array[Any] = Array(mergeBuffer) - val row = new GenericMutableRow(array) + val row = new GenericMutableRow(Array(mergeBuffer): Array[Any]) // Evaluates directly on row consist of aggregation buffer object. assert(agg.eval(row) == data.map(_._1).max) From 7190eb0c2a4dce2c5b84c29fb90bb2def23a3520 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 24 Aug 2016 11:20:56 +0800 Subject: [PATCH 05/12] fix review comments --- .../catalyst/expressions/aggregate/interfaces.scala | 9 ++++++++- .../spark/sql/execution/aggregate/AggUtils.scala | 11 ++--------- .../aggregate/SortBasedAggregationIterator.scala | 9 +-------- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e9c4e5d13fa66..1300e9a6e4c29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -432,6 +432,12 @@ abstract class DeclarativeAggregate * 4. After processing all input aggregation objects of current group (group by key), the framework * calls method `eval(buffer: T)` to generate the final output for this group. * 5. The framework moves on to next group, until all groups have been processed. + * + * NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation, + * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation + * buffer's storage format, which is not supported by hash based aggregation. 
Hash based + * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have + * fixed length and can be mutated in place in UnsafeRow) */ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { @@ -507,8 +513,9 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { } } + private[this] val anyObjectType = ObjectType(classOf[AnyRef]) private def getField[U](input: InternalRow, fieldIndex: Int): U = { - input.get(fieldIndex, null).asInstanceOf[U] + input.get(fieldIndex, anyObjectType).asInstanceOf[U] } final override lazy val aggBufferAttributes: Seq[AttributeReference] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index def48528da891..4fbb9d554c9bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -54,16 +54,9 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { - - val hasTypedImperativeAggregate: Boolean = aggregateExpressions.exists { - case AggregateExpression(agg: TypedImperativeAggregate[_], _, _, _) => true - case _ => false - } - - val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate( + val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - - if (aggBufferAttributesSupportedByHashAggregate && !hasTypedImperativeAggregate) { + if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 7448b1dbf9a51..9a28eaa67104b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -55,14 +55,7 @@ class SortBasedAggregationIterator( val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) - - val hasTypedImperativeAggregate = aggregateFunctions.exists { - case agg: TypedImperativeAggregate[_] => true - case _ => false - } - - val useUnsafeBuffer = allFieldsMutable && !hasTypedImperativeAggregate + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { val unsafeProjection = From 5904bcd2eb523b6f3e744925a0e9d9da52f6ae0b Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 24 Aug 2016 13:06:15 +0800 Subject: [PATCH 06/12] on viirya's comment --- .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 4 ++-- .../execution/aggregate/SortBasedAggregationIterator.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 1300e9a6e4c29..82eb0a1a5f1ca 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -492,7 +492,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { } final override def update(buffer: MutableRow, input: InternalRow): Unit = { - val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T] + val bufferObject = getField[T](buffer, mutableAggBufferOffset) update(bufferObject, input) } @@ -535,7 +535,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { * The framework calls this method every time after updating/merging one group (group by key). */ final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = { - val bufferObject = getField(buffer, mutableAggBufferOffset).asInstanceOf[T] + val bufferObject = getField[T](buffer, mutableAggBufferOffset) buffer(mutableAggBufferOffset) = serialize(bufferObject) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 9a28eaa67104b..7547c4f0c6daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -105,7 +105,7 @@ class SortBasedAggregationIterator( private def serializeTypedAggregateBuffer(aggregationBuffer: MutableRow): Unit = { typedImperativeAggregates.foreach { agg => // In-place serialization - agg.serializeAggregateBufferInPlace(sortBasedAggregationBuffer) + agg.serializeAggregateBufferInPlace(aggregationBuffer) } } From 8c8bd9a293bafabb4d077928f7a76cd10f36c772 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 24 Aug 2016 17:07:44 +0800 Subject: [PATCH 07/12] on yin's comment --- .../expressions/aggregate/interfaces.scala | 14 ++---- .../aggregate/AggregationIterator.scala | 13 +++++ .../SortBasedAggregationIterator.scala | 29 +---------- .../sql/TypedImperativeAggregateSuite.scala | 49 +++++++++++++++---- 4 files changed, 56 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 82eb0a1a5f1ca..114c9b96bd743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -477,9 +477,6 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { */ def eval(buffer: T): Any - /** Returns the class of aggregation buffer object */ - def aggregationBufferClass: Class[T] - /** Serializes the aggregation buffer object T to Array[Byte] */ def serialize(buffer: T): Array[Byte] @@ -498,19 +495,14 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { val bufferObject = getField[T](buffer, mutableAggBufferOffset) + // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate val inputObject = deserialize(getField[Array[Byte]](inputBuffer, inputAggBufferOffset)) merge(bufferObject, inputObject) } final override def eval(buffer: InternalRow): Any = 
{ - val bufferObject = getField[AnyRef](buffer, mutableAggBufferOffset) - if (bufferObject.getClass == aggregationBufferClass) { - // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly - // on the object aggregation buffer without intermediate serializing/de-serializing. - eval(bufferObject.asInstanceOf[T]) - } else { - eval(deserialize(bufferObject.asInstanceOf[Array[Byte]])) - } + val bufferObject = getField[T](buffer, mutableAggBufferOffset) + eval(bufferObject) } private[this] val anyObjectType = ObjectType(classOf[AnyRef]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 34de76dd4ab4e..cf55f70b81e26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -234,7 +234,20 @@ abstract class AggregationIterator( val resultProjection = UnsafeProjection.create( groupingAttributes ++ bufferAttributes, groupingAttributes ++ bufferAttributes) + + // TypedImperativeAggregate stores generic object in aggregation buffer, and requires + // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info. + val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = { + aggregateFunctions.collect { + case (ag: TypedImperativeAggregate[_]) => ag + } + } + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Serializes the generic object stored in aggregation buffer + typedImperativeAggregates.foreach { agg => + agg.serializeAggregateBufferInPlace(currentBuffer) + } resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 7547c4f0c6daa..3f7f84988594a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.metric.SQLMetric /** @@ -54,7 +54,6 @@ class SortBasedAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { @@ -91,24 +90,6 @@ class SortBasedAggregationIterator( // compared to MutableRow (aggregation buffer) directly. private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - // Aggregation function which uses generic aggregation buffer object. 
- // @see [[TypedImperativeAggregate]] for more information - private val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = { - aggregateFunctions.collect { - case (ag: TypedImperativeAggregate[_]) => ag - } - } - - // For TypedImperativeAggregate with generic aggregation buffer object, we need to call - // serializeAggregateBufferInPlace(...) explicitly to convert the aggregation buffer object - // to Spark Sql internally supported serializable storage format. - private def serializeTypedAggregateBuffer(aggregationBuffer: MutableRow): Unit = { - typedImperativeAggregates.foreach { agg => - // In-place serialization - agg.serializeAggregateBufferInPlace(aggregationBuffer) - } - } - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -150,11 +131,6 @@ class SortBasedAggregationIterator( firstRowInNextGroup = currentRow.copy() } } - - // Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate - // aggregation functions. - serializeTypedAggregateBuffer(sortBasedAggregationBuffer) - // We have not seen a new group. It means that there is no new row in the input // iter. The current group is the last group of the iter. if (!findNextPartition) { @@ -186,9 +162,6 @@ class SortBasedAggregationIterator( def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - // Serializes the generic object stored in aggregation buffer for TypedImperativeAggregate - // aggregation functions. - serializeTypedAggregateBuffer(sortBasedAggregationBuffer) generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 252449dc15ba0..91bd89f0756fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -38,6 +38,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { private val data = (0 until 1000).map { _ => (random.nextInt(10), random.nextInt(100)) } + test("aggregate with object aggregate buffer") { val agg = new TypedMax(BoundReference(0, IntegerType, nullable = false)) @@ -67,10 +68,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { // Evaluates directly on row consist of aggregation buffer object. assert(agg.eval(row) == data.map(_._1).max) - - // Serializes the aggregation buffer object and then evals. 
-    agg.serializeAggregateBufferInPlace(row)
-    assert(agg.eval(row) == data.map(_._1).max)
   }
 
   test("supports SpecificMutableRow as mutable row") {
@@ -110,6 +107,36 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
+  test("dataframe aggregate with object aggregate buffer, null expression, no group by") {
+    val df = data.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax(lit(null)), count($"key"), typedMax(lit(null)),
+      count($"value"))
+    val maxNull = Int.MinValue
+    val countKey = data.size
+    val countValue = data.size
+    val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+    checkAnswer(query, expected)
+  }
+
+  test("dataframe aggregation with object aggregate buffer, input row contains null") {
+
+    val nullableData = (0 until 1000).map {id =>
+      val nullableKey: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+      val nullableValue: Integer = if (random.nextBoolean()) null else random.nextInt(100)
+      (nullableKey, nullableValue)
+    }
+
+    val df = nullableData.toDF("key", "value").coalesce(2)
+    val query = df.select(typedMax($"key"), count($"key"), typedMax($"value"),
+      count($"value"))
+    val maxKey = nullableData.map(_._1).filter(_ != null).max
+    val countKey = nullableData.map(_._1).filter(_ != null).size
+    val maxValue = nullableData.map(_._2).filter(_ != null).max
+    val countValue = nullableData.map(_._2).filter(_ != null).size
+    val expected = Seq(Row(maxKey, countKey, maxValue, countValue))
+    checkAnswer(query, expected)
+  }
+
   test("dataframe aggregate with object aggregate buffer, with group by") {
     val df = data.toDF("value", "key").coalesce(2)
     val query = df.groupBy($"key").agg(typedMax($"value"), count($"value"), typedMax($"value"))
@@ -185,13 +212,17 @@ object TypedImperativeAggregateSuite {
 
     override def createAggregationBuffer(): MaxValue = {
+      // Returns Int.MinValue if all inputs are null
       new MaxValue(Int.MinValue)
     }
 
     override def update(buffer: MaxValue, input: InternalRow): Unit = {
-      val inputValue = child.eval(input).asInstanceOf[Int]
-      if (inputValue > buffer.value) {
-        buffer.value = inputValue
+      child.eval(input) match {
+        case inputValue: Int =>
+          if (inputValue > buffer.value) {
+            buffer.value = inputValue
+          }
+        case null => buffer
       }
     }
@@ -207,7 +238,7 @@ object TypedImperativeAggregateSuite {
 
     override def nullable: Boolean = true
 
-    override def deterministic: Boolean = false
+    override def deterministic: Boolean = true
 
     override def children: Seq[Expression] = Seq(child)
@@ -221,8 +252,6 @@ object TypedImperativeAggregateSuite {
     override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
       copy(inputAggBufferOffset = newOffset)
 
-    override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue]
-
     override def serialize(buffer: MaxValue): Array[Byte] = Ints.toByteArray(buffer.value)
 
     override def deserialize(storageFormat: Array[Byte]): MaxValue = {

From 7e7cb8546731bd3a50ce3c32d2ba275a8c25aafb Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Wed, 24 Aug 2016 21:29:38 +0800
Subject: [PATCH 08/12] On wenchen's comment

---
 .../sql/catalyst/expressions/aggregate/interfaces.scala  | 2 --
 .../apache/spark/sql/TypedImperativeAggregateSuite.scala | 6 ++----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 114c9b96bd743..a6cb243f8bb14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -523,8 +523,6 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   /**
    * In-place replaces the aggregation buffer object stored at buffer's index
    * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format.
-   *
-   * The framework calls this method every time after updating/merging one group (group by key).
    */
   final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
     val bufferObject = getField[T](buffer, mutableAggBufferOffset)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 91bd89f0756fa..097d4ac7e3716 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -73,11 +73,9 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
   test("supports SpecificMutableRow as mutable row") {
     val aggregationBufferSchema = Seq(IntegerType, LongType, BinaryType, IntegerType)
     val aggBufferOffset = 2
-    val inputBufferObject = 1
     val buffer = new SpecificMutableRow(aggregationBufferSchema)
-    val agg = new TypedMax(BoundReference(inputBufferObject, IntegerType, nullable = false))
+    val agg = new TypedMax(BoundReference(ordinal = 1, dataType = IntegerType, nullable = false))
       .withNewMutableAggBufferOffset(aggBufferOffset)
-      .withNewInputAggBufferOffset(inputBufferObject)
 
     agg.initialize(buffer)
     data.foreach { kv =>
@@ -222,7 +220,7 @@ object TypedImperativeAggregateSuite {
           if (inputValue > buffer.value) {
             buffer.value = inputValue
           }
-        case null => buffer
+        case null => // skip
       }
     }

From 86166a12a1cc411f93bb96a5cf3c284a1ffb651c Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Wed, 24 Aug 2016 21:42:56 +0800
Subject: [PATCH 09/12] On wenchen's comment

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index a6cb243f8bb14..8ef5311bf923a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -522,7 +522,8 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 
   /**
    * In-place replaces the aggregation buffer object stored at buffer's index
-   * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format.
+   * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format
+   * (BinaryType).
    */
   final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
     val bufferObject = getField[T](buffer, mutableAggBufferOffset)

From e060d213a287385eec6f9f568d1aae1f9c7ddfee Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Thu, 25 Aug 2016 06:39:32 +0800
Subject: [PATCH 10/12] On wenchen's comment

---
 .../spark/sql/catalyst/expressions/aggregate/interfaces.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8ef5311bf923a..ecbaa2f4669b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -496,7 +496,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
     val bufferObject = getField[T](buffer, mutableAggBufferOffset)
     // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
-    val inputObject = deserialize(getField[Array[Byte]](inputBuffer, inputAggBufferOffset))
+    val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
     merge(bufferObject, inputObject)
   }

From ac8e36ae768cf872a96f4c076ad471ab688470ad Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Thu, 25 Aug 2016 10:44:06 +0800
Subject: [PATCH 11/12] add test for nullable aggregation function

---
 .../aggregate/AggregationIterator.scala     |  4 ++
 .../sql/TypedImperativeAggregateSuite.scala | 59 +++++++++++++++----
 2 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index cf55f70b81e26..4412862e35cce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -245,6 +245,10 @@ abstract class AggregationIterator(
       (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
         // Serializes the generic object stored in aggregation buffer
+        var i = 0
+        while (i < typedImperativeAggregates.length) {
+          i += 1
+        }
         typedImperativeAggregates.foreach { agg =>
           agg.serializeAggregateBufferInPlace(currentBuffer)
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 097d4ac7e3716..b5eb16b6f650b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql
 
-import com.google.common.primitives.Ints
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericMutableRow, SpecificMutableRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.expressions.Window
@@ -105,10 +105,14 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
-  test("dataframe aggregate with object aggregate buffer, null expression, no group by") {
+  test("dataframe aggregate with object aggregate buffer, non-nullable aggregator") {
     val df = data.toDF("key", "value").coalesce(2)
+
+    // Test non-nullable typedMax
     val query = df.select(typedMax(lit(null)), count($"key"), typedMax(lit(null)),
       count($"value"))
+
+    // typedMax is not nullable
     val maxNull = Int.MinValue
     val countKey = data.size
     val countValue = data.size
@@ -116,6 +120,21 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
+  test("dataframe aggregate with object aggregate buffer, nullable aggregator") {
+    val df = data.toDF("key", "value").coalesce(2)
+
+    // Test nullable nullableTypedMax
+    val query = df.select(nullableTypedMax(lit(null)), count($"key"), nullableTypedMax(lit(null)),
+      count($"value"))
+
+    // nullableTypedMax is nullable
+    val maxNull = null
+    val countKey = data.size
+    val countValue = data.size
+    val expected = Seq(Row(maxNull, countKey, maxNull, countValue))
+    checkAnswer(query, expected)
+  }
+
   test("dataframe aggregation with object aggregate buffer, input row contains null") {
 
     val nullableData = (0 until 1000).map {id =>
@@ -192,7 +211,12 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   private def typedMax(column: Column): Column = {
-    val max = TypedMax(column.expr)
+    val max = TypedMax(column.expr, nullable = false)
+    Column(max.toAggregateExpression())
+  }
+
+  private def nullableTypedMax(column: Column): Column = {
+    val max = TypedMax(column.expr, nullable = true)
     Column(max.toAggregateExpression())
   }
 }
@@ -205,6 +229,7 @@ object TypedImperativeAggregateSuite {
    */
   private case class TypedMax(
       child: Expression,
+      nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
 
@@ -219,6 +244,7 @@ object TypedImperativeAggregateSuite {
         case inputValue: Int =>
           if (inputValue > buffer.value) {
             buffer.value = inputValue
+            buffer.isValueSet = true
           }
         case null => // skip
       }
@@ -227,15 +253,18 @@ object TypedImperativeAggregateSuite {
     override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = {
       if (inputMax.value > bufferMax.value) {
         bufferMax.value = inputMax.value
+        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
       }
     }
 
     override def eval(bufferMax: MaxValue): Any = {
-      bufferMax.value
+      if (nullable && bufferMax.isValueSet == false) {
+        null
+      } else {
+        bufferMax.value
+      }
     }
 
-    override def nullable: Boolean = true
-
     override def deterministic: Boolean = true
 
     override def children: Seq[Expression] = Seq(child)
@@ -250,12 +279,22 @@ object TypedImperativeAggregateSuite {
     override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
       copy(inputAggBufferOffset = newOffset)
 
-    override def serialize(buffer: MaxValue): Array[Byte] = Ints.toByteArray(buffer.value)
+    override def serialize(buffer: MaxValue): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val stream = new DataOutputStream(out)
+      stream.writeBoolean(buffer.isValueSet)
+      stream.writeInt(buffer.value)
+      out.toByteArray
+    }
 
     override def deserialize(storageFormat: Array[Byte]): MaxValue = {
-      new MaxValue(Ints.fromByteArray(storageFormat))
+      val in = new ByteArrayInputStream(storageFormat)
+      val stream = new DataInputStream(in)
+      val isValueSet = stream.readBoolean()
+      val value = stream.readInt()
+      new MaxValue(value, isValueSet)
     }
   }
 
-  private class MaxValue(var value: Int)
+  private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }

From ca574e145543c6fc555220fa8080bf7fbe152ba5 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Thu, 25 Aug 2016 08:35:20 -0700
Subject: [PATCH 12/12] use while loop

---
 .../spark/sql/execution/aggregate/AggregationIterator.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 4412862e35cce..dfed084fe64a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -247,11 +247,9 @@ abstract class AggregationIterator(
         // Serializes the generic object stored in aggregation buffer
         var i = 0
         while (i < typedImperativeAggregates.length) {
+          typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
           i += 1
         }
-        typedImperativeAggregates.foreach { agg =>
-          agg.serializeAggregateBufferInPlace(currentBuffer)
-        }
         resultProjection(joinedRow(currentGroupingKey, currentBuffer))
       }
     } else {