@@ -389,3 +389,126 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}

/**
 * This trait allows an AggregateFunction to store **arbitrary** Java objects in the internal
[Review comment, Contributor] Nit: traits => trait

* aggregation buffer during aggregation of each key group. The **arbitrary** Java object can be
* used to store the accumulated aggregation result.
*
* This trait must be mixed into class ImperativeAggregate.
*
* {{{
* aggregation buffer for function avg
* | |
* v v
* +--------------+---------------+----------------------+
* | sum1 (Long) | count1 (Long) | generic java object |
* +--------------+---------------+----------------------+
* ^
* |
* Aggregation buffer for aggregation-function with WithObjectAggregateBuffer
* }}}
*
* Here is how it works in a typical aggregation flow (Partial mode aggregate at Mapper side, and
* Final mode aggregate at Reducer side).
*
* Stage 1: Partial aggregate at Mapper side:
*
* 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores an arbitrary empty
* object, object A for example, in the internal aggBuffer. Object A will be used to store the
* accumulated aggregation result.
* 2. Upon calling method `update(mutableAggBuffer: MutableRow, inputRow: InternalRow)` for the
* current group (group by key), the user extracts object A from mutableAggBuffer, updates
* object A with the current inputRow, and then stores object A back into mutableAggBuffer.
* 3. After processing all rows of the current group, the framework will call method
* {{{
* serializeObjectAggregateBuffer(
* objectAggregateBuffer: InternalRow,
* targetBuffer: MutableRow)
* }}}
* to serialize object A stored in objectAggregateBuffer to a Spark SQL internally supported
* serializable format, and write the serialized result to the targetBuffer MutableRow.
* The framework may persist the targetBuffer to disk if there is not enough memory; this is
* safe because all fields of the targetBuffer MutableRow are serializable.
* 4. The framework moves on to the next group, until all groups have been processed.
*
* Shuffle exchange then sends the data to the Reducer tasks ...
*
* Stage 2: Final mode aggregate at Reducer side:
*
* 1. Upon calling method `initialize(aggBuffer: MutableRow)`, the user stores a new empty object
* A1 in the internal aggBuffer. Object A1 will be used to store the accumulated aggregation
* result.
* 2. Upon calling method `merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow)`, the
* user extracts object A1 from mutableAggBuffer and object A2 from inputAggBuffer, merges A1
* and A2, and stores the merged result back into mutableAggBuffer.
* 3. After processing all inputAggBuffers of the current group (group by key), the framework
* will call method:
* {{{
* serializeObjectAggregateBuffer(
* objectAggregateBuffer: InternalRow,
* targetBuffer: MutableRow)
* }}}
* to serialize object A1 stored in objectAggregateBuffer to a Spark SQL internally supported
* serializable format, and store the serialized result in the targetBuffer MutableRow. The
* framework may persist the targetBuffer to disk if there is not enough memory; this is safe
* because all fields of the targetBuffer MutableRow are serializable.
* 4. The framework moves on to the next group, until all groups have been processed.
* (A minimal driver sketch exercising both stages follows the trait definition below.)
*/
trait WithObjectAggregateBuffer {
this: ImperativeAggregate =>
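// Note: the self-type annotation above does not add any members to the trait; it only
// requires that a concrete implementation also extends ImperativeAggregate, as
// MaxWithObjectAggregateBuffer in AggregateWithObjectAggregateBufferSuite does:
//   case class MaxWithObjectAggregateBuffer(...)
//     extends ImperativeAggregate with WithObjectAggregateBuffer { ... }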
[Review comment, Contributor] Seems we do not really need this line.

[Review comment, Contributor] I guess having this line will make this trait hard to use in Java.

[Review comment, Contributor] Oh, it seems this trait will still be a Java interface. But in general, I think we do not really need to have this line.

/**
* Serializes the object stored at objectAggregateBuffer's index mutableAggBufferOffset to a
* Spark SQL internally supported serializable format, and writes the serialized result
* to targetBuffer's index mutableAggBufferOffset.
*
* The framework calls this method every time after finishing updating/merging one
* group (group by key).
*
* - Source object aggregation buffer, before serialization:
* The object stored in the object aggregation buffer can be of an **arbitrary** Java object
* type defined by the user.
*
* - Target mutable buffer, after serialization:
* The target mutable buffer is of type MutableRow. Each field's type needs to be one of the
* Spark SQL internally supported serializable formats, which are:
* - Null
* - Boolean
* - Byte
* - Short
* - Int
* - Long
* - Float
* - Double
* - Array[Byte]
* - org.apache.spark.sql.types.Decimal
* - org.apache.spark.unsafe.types.UTF8String
* - org.apache.spark.unsafe.types.CalendarInterval
* - org.apache.spark.sql.catalyst.util.MapData
* - org.apache.spark.sql.catalyst.util.ArrayData
* - org.apache.spark.sql.catalyst.InternalRow
*
* Code example:
*
* {{{
* def serializeObjectAggregateBuffer(
* objectAggregateBuffer: InternalRow,
* targetBuffer: MutableRow): Unit = {
* val obj = objectAggregateBuffer.get(mutableAggBufferOffset, ObjectType(classOf[A]))
*   .asInstanceOf[A]
* // Convert obj to a Spark SQL internally supported serializable format (here it is
* // Array[Byte])
* targetBuffer(mutableAggBufferOffset) = toBytes(obj)
* }
* }}}
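*
* The helper `toBytes` above is not part of this API; as a rough sketch, assuming class A
* implements java.io.Serializable, it could use plain Java serialization:
*
* {{{
* import java.io.{ByteArrayOutputStream, ObjectOutputStream}
*
* def toBytes(obj: AnyRef): Array[Byte] = {
*   val byteStream = new ByteArrayOutputStream()
*   val out = new ObjectOutputStream(byteStream)
*   out.writeObject(obj)
*   out.close()
*   // Array[Byte] is one of the supported serializable formats listed above
*   byteStream.toByteArray
* }
* }}}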
*
* @param objectAggregateBuffer Source object aggregation buffer. Please use the index
* mutableAggBufferOffset to get the buffered object of this aggregation
* function.
* @param targetBuffer Target buffer to hold the serialized format. Please use the index
* mutableAggBufferOffset to store the serialized format for this aggregation
* function.
*/
def serializeObjectAggregateBuffer(
objectAggregateBuffer: InternalRow,
targetBuffer: MutableRow): Unit
}
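
To make the two-stage flow described above concrete, here is a minimal driver sketch. It is illustrative only, not Spark's actual execution code: the helper names (`runPartialForGroup`, `runFinalForGroup`) and the `bufferSize` parameter are assumptions made for this example, and only the calls documented on the trait are used.

{{{
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}

// Stage 1 (Partial mode, Mapper side): aggregate all rows of one group.
def runPartialForGroup(
    agg: ImperativeAggregate with WithObjectAggregateBuffer,
    rows: Iterator[InternalRow],
    bufferSize: Int): MutableRow = {
  val buffer = new GenericMutableRow(bufferSize)
  agg.initialize(buffer)                              // step 1: store the empty object A
  rows.foreach(row => agg.update(buffer, row))        // step 2: fold every input row into A
  agg.serializeObjectAggregateBuffer(buffer, buffer)  // step 3: replace A with a serializable form
  buffer                                              // now safe to spill to disk or shuffle
}

// Stage 2 (Final mode, Reducer side): merge the shuffled partial buffers of one group.
def runFinalForGroup(
    agg: ImperativeAggregate with WithObjectAggregateBuffer,
    partialBuffers: Iterator[InternalRow],
    bufferSize: Int): Any = {
  val buffer = new GenericMutableRow(bufferSize)
  agg.initialize(buffer)                              // step 1: store a new empty object A1
  partialBuffers.foreach(pb => agg.merge(buffer, pb)) // step 2: merge each partial buffer into A1
  agg.serializeObjectAggregateBuffer(buffer, buffer)  // step 3: serialize before producing output
  agg.eval(buffer)                                    // compute the final result
}
}}}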
@@ -54,9 +54,16 @@ object AggUtils {
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(

val isUsingObjectAggregationBuffer: Boolean = aggregateExpressions.exists {
case AggregateExpression(agg: WithObjectAggregateBuffer, _, _, _) => true
case _ => false
}

val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {

if (aggBufferAttributesSupportedByHashAggregate && !isUsingObjectAggregationBuffer) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.metric.SQLMetric

/**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
val bufferRowSize: Int = bufferSchema.length

val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val isUsingObjectAggregationBuffer = aggregateFunctions.exists {
case agg: WithObjectAggregateBuffer => true
case _ => false
}

val useUnsafeBuffer = allFieldsMutable && !isUsingObjectAggregationBuffer

val buffer = if (useUnsafeBuffer) {
val unsafeProjection =
@@ -90,6 +98,22 @@ class SortBasedAggregationIterator(
// compared to MutableRow (aggregation buffer) directly.
private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

// Aggregate functions which store a generic object in the aggregation buffer.
// @see [[WithObjectAggregateBuffer]] for more information
private val aggFunctionsWithObjectAggregationBuffer = aggregateFunctions.collect {
case (ag: ImperativeAggregate with WithObjectAggregateBuffer) => ag
[Review comment, Contributor] How about we make WithObjectAggregateBuffer extend ImperativeAggregate?

[Review comment, Contributor Author] ImperativeAggregate is an abstract class; that would make WithObjectAggregateBuffer quite heavy.

[Review comment, Contributor] Heavy? In what sense?

}

// For aggregate functions that store a generic object in the aggregation buffer, we need to
// call serializeObjectAggregateBuffer() explicitly to convert the generic object stored in
// the aggregation buffer to a serializable format.
private def serializeObjectAggregationBuffer(aggregationBuffer: MutableRow): Unit = {
aggFunctionsWithObjectAggregationBuffer.foreach { agg =>
// Serializes and **in-place** replaces the object stored in the aggregation buffer
agg.serializeObjectAggregateBuffer(aggregationBuffer, aggregationBuffer)
}
}

protected def initialize(): Unit = {
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +155,10 @@ class SortBasedAggregationIterator(
firstRowInNextGroup = currentRow.copy()
}
}

// Serializes the generic object stored in aggregation buffer.
serializeObjectAggregationBuffer(sortBasedAggregationBuffer)

// We have not seen a new group. It means that there is no new row in the input
// iter. The current group is the last group of the iter.
if (!findNextPartition) {
@@ -162,6 +190,8 @@ class SortBasedAggregationIterator(

def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
initializeBuffer(sortBasedAggregationBuffer)
// Serializes the generic object stored in aggregation buffer.
serializeObjectAggregationBuffer(sortBasedAggregationBuffer)
generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
}
}
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.AggregateWithObjectAggregateBufferSuite.MaxWithObjectAggregateBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GenericMutableRow, MutableRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, WithObjectAggregateBuffer}
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StructType}

class AggregateWithObjectAggregateBufferSuite extends QueryTest with SharedSQLContext {
[Review comment, Contributor] We should also put a basic test in HashAggregationQueryWithControlledFallbackSuite, to test the fallback.

[Review comment, Contributor Author] We will not use HashAggregateExec, so there is no point in falling back from HashAggregateExec?

[Review comment, Contributor] Oh right, I misread the code.

import testImplicits._

private val data = Seq((1, 0), (3, 1), (2, 0), (6, 3), (3, 1), (4, 1), (5, 0))


test("aggregate with object aggregate buffer, should not use HashAggregate") {
val df = data.toDF("a", "b")
val max = new MaxWithObjectAggregateBuffer($"a".expr)

// Always use SortAggregateExec instead of HashAggregateExec for planning even if the aggregate
// buffer attributes are mutable fields (every field can be mutated inline like int, long...)
val allFieldsMutable = max.aggBufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
assert(allFieldsMutable && sparkPlan.isInstanceOf[SortAggregateExec])
}

test("aggregate with object aggregate buffer, no group by") {
val df = data.toDF("a", "b").coalesce(2)
checkAnswer(
df.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
Seq(Row(6, 7, 3, 7))
)
}

test("aggregate with object aggregate buffer, with group by") {
val df = data.toDF("a", "b").coalesce(2)
checkAnswer(
df.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
Seq(
Row(0, 5, 3, 5),
Row(1, 4, 3, 4),
Row(3, 6, 1, 6)
)
)
}

test("aggregate with object aggregate buffer, empty inputs, no group by") {
val empty = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
empty.select(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"b"), count($"b")),
Seq(Row(Int.MinValue, 0, Int.MinValue, 0)))
}

test("aggregate with object aggregate buffer, empty inputs, with group by") {
val empty = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
empty.groupBy($"b").agg(objectAggregateMax($"a"), count($"a"), objectAggregateMax($"a")),
Seq.empty[Row])
}

private def objectAggregateMax(column: Column): Column = {
val max = MaxWithObjectAggregateBuffer(column.expr)
Column(max.toAggregateExpression())
}
}

object AggregateWithObjectAggregateBufferSuite {
[Review comment, Contributor] (We do not need to put the example class inside this object.)

[Review comment, Contributor Author] I use the companion object to define a private scope.

/**
* Calculates the max value with an object aggregation buffer. This stores an object of class
* MaxValue in the aggregation buffer.
*/
private case class MaxWithObjectAggregateBuffer(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends ImperativeAggregate with WithObjectAggregateBuffer {

override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newOffset)

override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newOffset)

// Stores a generic object MaxValue in aggregation buffer.
override def initialize(buffer: MutableRow): Unit = {
// Makes sure we are using a GenericMutableRow, not an UnsafeRow, as the aggregation buffer,
// since only a generic row can hold an arbitrary Java object.
assert(buffer.isInstanceOf[GenericMutableRow])
buffer.update(mutableAggBufferOffset, new MaxValue(Int.MinValue))
}

override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputValue = child.eval(input).asInstanceOf[Int]
val maxValue = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
if (inputValue > maxValue.value) {
maxValue.value = inputValue
}
}

override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
val inputMax = deserialize(inputBuffer, inputAggBufferOffset)
if (inputMax.value > bufferMax.value) {
bufferMax.value = inputMax.value
}
}

private def deserialize(buffer: InternalRow, offset: Int): MaxValue = {
new MaxValue(buffer.getInt(offset))
}

override def serializeObjectAggregateBuffer(buffer: InternalRow, target: MutableRow): Unit = {
val bufferMax = buffer.get(mutableAggBufferOffset, null).asInstanceOf[MaxValue]
target(mutableAggBufferOffset) = bufferMax.value
}

override def eval(buffer: InternalRow): Any = {
val max = deserialize(buffer, mutableAggBufferOffset)
max.value
}

override val aggBufferAttributes: Seq[AttributeReference] =
Seq(AttributeReference("buf", IntegerType)())

override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
override def nullable: Boolean = true
override def deterministic: Boolean = false
override def children: Seq[Expression] = Seq(child)
}

private class MaxValue(var value: Int)
}