[SQL][WIP][Test] Supports object-based aggregation function which can store arbitrary objects in aggregation buffer. #14723
@@ -389,3 +389,126 @@ abstract class DeclarativeAggregate
    def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
  }
}

/**
 * This trait allows an AggregateFunction to store **arbitrary** Java objects in the internal
 * aggregation buffer during aggregation of each key group. The **arbitrary** Java object can be
 * used to store the accumulated aggregation result.
 *
 * This trait must be mixed into a class that extends ImperativeAggregate.
 *
 * {{{
 *     aggregation buffer for function avg
 *       |              |
 *       v              v
 *    +--------------+---------------+----------------------+
 *    | sum1 (Long)  | count1 (Long) | generic java object  |
 *    +--------------+---------------+----------------------+
 *                                               ^
 *                                               |
 *      aggregation buffer for an aggregate function with WithObjectAggregateBuffer
 * }}}
 *
 * Here is how it works in a typical aggregation flow (Partial mode aggregation at the mapper
 * side, and Final mode aggregation at the reducer side).
 *
 * Stage 1: Partial aggregation at the mapper side:
 *
 * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores an arbitrary empty
 *    object, object A for example, in the internal aggBuffer. Object A will be used to store the
 *    accumulated aggregation result.
 * 2. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` for the
 *    current group (group by key), the user extracts object A from mutableAggBuffer, updates
 *    object A with the current inputRow, and then stores object A back into mutableAggBuffer.
 * 3. After processing all rows of the current group, the framework calls method
 *    {{{
 *    serializeObjectAggregateBuffer(
 *        objectAggregateBuffer: InternalRow,
 *        targetBuffer: MutableRow)
 *    }}}
 *    to serialize object A stored in objectAggregateBuffer to a Spark SQL internally supported
 *    serializable format, and writes the serialized result to the targetBuffer MutableRow.
 *    The framework may persist targetBuffer to disk if there is not enough memory; this is safe
 *    because all fields of the targetBuffer MutableRow are serializable.
 * 4. The framework moves on to the next group, until all groups have been processed.
 *
 * Shuffle exchanges data to the reducer tasks...
 *
 * Stage 2: Final mode aggregation at the reducer side:
 *
 * 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores a new empty object
 *    A1 in the internal aggBuffer. Object A1 will be used to store the accumulated aggregation
 *    result.
 * 2. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, the
 *    user extracts object A1 from mutableAggBuffer and object A2 from inputAggBuffer, merges A1
 *    and A2, and stores the merged result back into mutableAggBuffer.
 * 3. After processing all inputAggBuffers of the current group (group by key), the framework
 *    calls method
 *    {{{
 *    serializeObjectAggregateBuffer(
 *        objectAggregateBuffer: InternalRow,
 *        targetBuffer: MutableRow)
 *    }}}
 *    to serialize object A1 stored in objectAggregateBuffer to a Spark SQL internally supported
 *    serializable format, and writes the serialized result to the targetBuffer MutableRow. The
 *    framework may persist targetBuffer to disk if there is not enough memory; this is safe
 *    because all fields of the targetBuffer MutableRow are serializable.
 * 4. The framework moves on to the next group, until all groups have been processed.
 */
trait WithObjectAggregateBuffer {
  this: ImperativeAggregate =>

Contributor
Seems we do not really need this line.

Contributor
I guess having this line will make this trait hard to use from Java.

Contributor
oh, seems this trait will still be a Java ...
  /**
   * Serializes the object stored at objectAggregateBuffer's index mutableAggBufferOffset to a
   * Spark SQL internally supported serializable format, and writes the serialized result to
   * targetBuffer's index mutableAggBufferOffset.
   *
   * The framework calls this method every time after it finishes updating/merging one group
   * (group by key).
   *
   * - Source object aggregation buffer, before serialization:
   *   The object stored in the object aggregation buffer can be of **arbitrary** Java object
   *   type, as defined by the user.
   *
   * - Target mutable buffer, after serialization:
   *   The target mutable buffer is of type MutableRow. Each field's type needs to be one of
   *   the Spark SQL internally supported serializable formats, which are:
   *   - Null
   *   - Boolean
   *   - Byte
   *   - Short
   *   - Int
   *   - Long
   *   - Float
   *   - Double
   *   - Array[Byte]
   *   - org.apache.spark.sql.types.Decimal
   *   - org.apache.spark.unsafe.types.UTF8String
   *   - org.apache.spark.unsafe.types.CalendarInterval
   *   - org.apache.spark.sql.catalyst.util.MapData
   *   - org.apache.spark.sql.catalyst.util.ArrayData
   *   - org.apache.spark.sql.catalyst.InternalRow
   *
   * Code example:
   *
   * {{{
   * def serializeObjectAggregateBuffer(
   *     objectAggregateBuffer: InternalRow,
   *     targetBuffer: MutableRow): Unit = {
   *   val obj = objectAggregateBuffer
   *     .get(mutableAggBufferOffset, ObjectType(classOf[A])).asInstanceOf[A]
   *   // Convert obj to a Spark SQL internally supported serializable format (here,
   *   // Array[Byte]).
   *   targetBuffer(mutableAggBufferOffset) = toBytes(obj)
   * }
   * }}}
   *
   * @param objectAggregateBuffer Source object aggregation buffer. Please use the index
   *                              mutableAggBufferOffset to get the buffered object of this
   *                              aggregate function.
   * @param targetBuffer Target buffer to hold the serialized result. Please use the index
   *                     mutableAggBufferOffset to store the serialized result for this
   *                     aggregate function.
   */
  def serializeObjectAggregateBuffer(
      objectAggregateBuffer: InternalRow,
      targetBuffer: MutableRow): Unit
}
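The Scaladoc example above calls a `toBytes` helper that the patch does not define. For illustration only, here is a minimal sketch of such a helper pair, assuming the buffered object implements `java.io.Serializable` and that plain Java serialization is acceptable; the object name and approach are assumptions, not part of this PR:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Round-trips an arbitrary Serializable object through Array[Byte], one of the
// internally supported buffer field formats listed in the Scaladoc above.
object ObjectBufferSerde {
  def toBytes(obj: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    try oos.writeObject(obj) finally oos.close()
    bos.toByteArray
  }

  def fromBytes[T](bytes: Array[Byte]): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject().asInstanceOf[T] finally ois.close()
  }
}
```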
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, WithObjectAggregateBuffer}
 import org.apache.spark.sql.execution.metric.SQLMetric

 /**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
     val bufferRowSize: Int = bufferSchema.length

     val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-    val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+    val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+    val isUsingObjectAggregationBuffer = aggregateFunctions.exists {
+      case agg: WithObjectAggregateBuffer => true
+      case _ => false
+    }
+
+    val useUnsafeBuffer = allFieldsMutable && !isUsingObjectAggregationBuffer

     val buffer = if (useUnsafeBuffer) {
       val unsafeProjection =
@@ -90,6 +98,22 @@ class SortBasedAggregationIterator(
  // compared to MutableRow (aggregation buffer) directly.
  private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

  // Aggregate functions that store a generic object in the aggregation buffer.
  // @see [[WithObjectAggregateBuffer]] for more information
  private val aggFunctionsWithObjectAggregationBuffer = aggregateFunctions.collect {
    case (ag: ImperativeAggregate with WithObjectAggregateBuffer) => ag
Contributor
how about we make ...

Contributor (Author)
ImperativeAggregate is an abstract class; that will make ...

Contributor
Heavy? In what sense?
  }

  // For an AggregateFunction with a generic object stored in its aggregation buffer, we need
  // to call serializeObjectAggregateBuffer() explicitly to convert the generic object stored
  // in the aggregation buffer to a serializable format.
  private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
    aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
      // Serializes and **in-place** replaces the object stored in the aggregation buffer.
      agg.serializeObjectAggregateBuffer(aggregationBuffer, aggregationBuffer)
    }
  }

  protected def initialize(): Unit = {
    if (inputIterator.hasNext) {
      initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +155,10 @@ class SortBasedAggregationIterator(
        firstRowInNextGroup = currentRow.copy()
      }
    }

    // Serializes the generic object stored in the aggregation buffer.
    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)

    // We have not seen a new group. It means that there is no new row in the input
    // iter. The current group is the last group of the iter.
    if (!findNextPartition) {
@@ -162,6 +190,8 @@

  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
    initializeBuffer(sortBasedAggregationBuffer)
    // Serializes the generic object stored in the aggregation buffer.
    serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
    generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
  }
}
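The buffer-selection rule added above can be read as a single predicate: fall back to the safe, generic buffer whenever any aggregate function mixes in WithObjectAggregateBuffer, even when every buffer field would otherwise be mutable inside an UnsafeRow. A hedged distillation in plain Scala, with the types simplified (`bufferFieldsMutable` stands in for the `UnsafeRow.isMutable` check; the names here are illustrative only):

```scala
// Simplified model of the decision made in SortBasedAggregationIterator above.
case class AggFunctionInfo(bufferFieldsMutable: Boolean, storesRawObject: Boolean)

def useUnsafeBuffer(functions: Seq[AggFunctionInfo]): Boolean = {
  val allFieldsMutable = functions.forall(_.bufferFieldsMutable)
  val usesObjectBuffer = functions.exists(_.storesRawObject)
  // An UnsafeRow cannot hold an arbitrary Java object, so a single object-based
  // function forces the generic (safe) mutable buffer for the whole aggregation.
  allFieldsMutable && !usesObjectBuffer
}
```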
@@ -0,0 +1,156 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.AggregateWithObjectAggregateBufferSuite.MaxWithObjectAggregateBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericMutableRow, MutableRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StructType}

class AggregateWithObjectAggregateBufferSuite extends QueryTest with SharedSQLContext {
Contributor
We should also put a basic test in ...

Contributor (Author)
We will not use HashAggregationExec, so there is no point to fall back from HashAggregationExec?

Contributor
oh right, I misread the code.
  import testImplicits._

  private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))

  test("aggregate with object aggregate buffer, should not use HashAggregate") {
    val df = data.toDF("a", "b")
    val max = new MaxWithObjectAggregateBuffer($"a".expr)

    // Always use SortAggregateExec instead of HashAggregateExec for planning, even if the
    // aggregate buffer attributes are mutable fields (every field can be mutated inline,
    // like int, long, ...).
    val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
    val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
    assert(allFieldsMutable && sparkPlan.isInstanceOf[SortAggregateExec])
  }

  test("aggregate with object aggregate buffer, no group by") {
    val df = data.toDF("a", "b").coalesce(2)
    checkAnswer(
      df.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
      Seq(Row(6, 7, 3, 7))
    )
  }

  test("aggregate with object aggregate buffer, with group by") {
    val df = data.toDF("a", "b").coalesce(2)
    checkAnswer(
      df.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
      Seq(
        Row(0, 5, 3, 5),
        Row(1, 4, 3, 4),
        Row(3, 6, 1, 6)
      )
    )
  }

  test("aggregate with object aggregate buffer, empty inputs, no group by") {
    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
    checkAnswer(
      empty.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
      Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
  }

  test("aggregate with object aggregate buffer, empty inputs, with group by") {
    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
    checkAnswer(
      empty.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
      Seq.empty[Row])
  }

  private def objectAggregateMax(column: Column): Column = {
    val max = MaxWithObjectAggregateBuffer(column.expr)
    Column(max.toAggregateExpression())
  }
}

object AggregateWithObjectAggregateBufferSuite {
Contributor
(we do not need to put the example class inside this object.)

Contributor (Author)
I use the companion object to define a private scope.
  /**
   * Calculates the max value with an object aggregation buffer. This stores an object of class
   * MaxValue in the aggregation buffer.
   */
  private case class MaxWithObjectAggregateBuffer(
      child: Expression,
      mutableAggBufferOffset: Int = 0,
      inputAggBufferOffset: Int = 0) extends ImperativeAggregate with WithObjectAggregateBuffer {

    override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
      copy(mutableAggBufferOffset = newOffset)

    override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
      copy(inputAggBufferOffset = newOffset)

    // Stores a generic MaxValue object in the aggregation buffer.
    override def initialize(buffer: MutableRow): Unit = {
      // Makes sure we are NOT using an UnsafeRow for the aggregation buffer: only a generic
      // mutable row can hold an arbitrary object.
      assert(buffer.isInstanceOf[GenericMutableRow])
      buffer.update(mutableAggBufferOffset, new MaxValue(Int.MinValue))
    }

    override def update(buffer: MutableRow, input: InternalRow): Unit = {
      val inputValue = child.eval(input).asInstanceOf[Int]
      val maxValue = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      if (inputValue > maxValue.value) {
        maxValue.value = inputValue
      }
    }

    override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      val inputMax = deserialize(inputBuffer, inputAggBufferOffset)
      if (inputMax.value > bufferMax.value) {
        bufferMax.value = inputMax.value
      }
    }

    private def deserialize(buffer: InternalRow, offset: Int): MaxValue = {
      new MaxValue(buffer.getInt(offset))
    }

    override def serializeObjectAggregateBuffer(buffer: InternalRow, target: MutableRow): Unit = {
      val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
      target(mutableAggBufferOffset) = bufferMax.value
    }

    override def eval(buffer: InternalRow): Any = {
      val max = deserialize(buffer, mutableAggBufferOffset)
      max.value
    }

    override val aggBufferAttributes: Seq[AttributeReference] =
      Seq(AttributeReference("buf", IntegerType)())

    override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
      aggBufferAttributes.map(_.newInstance())

    override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
    override def dataType: DataType = IntegerType
    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
    override def nullable: Boolean = true
    override def deterministic: Boolean = false
    override def children: Seq[Expression] = Seq(child)
  }

  private class MaxValue(var value: Int)
}
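To make the in-place serialization contract concrete, here is a hedged usage sketch of the example class above, mirroring what `outputForEmptyGroupingKeyWithoutInput` does; it assumes `MaxWithObjectAggregateBuffer` is visible from the calling scope and relies on the internal catalyst row classes used elsewhere in this PR:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericMutableRow}
import org.apache.spark.sql.types.IntegerType

val agg = MaxWithObjectAggregateBuffer(AttributeReference("a", IntegerType)())
val buffer = new GenericMutableRow(1)

agg.initialize(buffer)                              // slot 0 now holds a raw MaxValue object
agg.serializeObjectAggregateBuffer(buffer, buffer)  // in place: MaxValue -> Int
assert(buffer.getInt(0) == Int.MinValue)            // buffer is now safe to spill or shuffle
```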
Nit: traits => trait