@@ -524,23 +524,142 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
/** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
def deserialize(storageFormat: Array[Byte]): T

override def initialize(buffer: InternalRow): Unit = {
buffer(mutableAggBufferOffset) = createAggregationBuffer()
}

override def update(buffer: InternalRow, input: InternalRow): Unit = {
buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input)
}

override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
val bufferObject = getBufferObject(buffer)
// The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
}

override def eval(buffer: InternalRow): Any = {
eval(getBufferObject(buffer))
}

private[this] val anyObjectType = ObjectType(classOf[AnyRef])
private def getBufferObject(bufferRow: InternalRow): T = {
bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[T]
}

override lazy val aggBufferAttributes: Seq[AttributeReference] = {
// Underlying storage type for the aggregation buffer object
Seq(AttributeReference("buf", BinaryType)())
}

override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

/**
* In-place replaces the aggregation buffer object stored at buffer's index
* `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format
* (BinaryType).
*
* This is only called when doing Partial or PartialMerge mode aggregation, before the framework
* shuffle out aggregate buffers.
*/
def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
}
}

/**
* Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
* internal aggregation buffer for Hive.
*/
abstract class HiveTypedImperativeAggregate[T] extends TypedImperativeAggregate[T] {

/**
* Creates an empty aggregation buffer object for partial 1 mode. This is called
* before processing each key group (group by key).
*
* @return an aggregation buffer object
*/
def createAggregationBuffer(): T

/**
* Creates an empty aggregation buffer object for partial 2 mode.
*
* @return an aggregation buffer object
*/
def createPartial2ModeAggregationBuffer(): T

var partial2ModeBuffer: InternalRow = _

/**
* Updates the aggregation buffer object with an input row and returns a new buffer object. For
* performance, the function may do in-place update and return it instead of constructing new
* buffer object.
*
* This is typically called when doing Partial or Complete mode aggregation.
*
* @param buffer The aggregation buffer object.
* @param input an input row
*/
def update(buffer: T, input: InternalRow): T

/**
* Merges an input aggregation object into aggregation buffer object and returns a new buffer
* object. For performance, the function may do in-place merge and return it instead of
* constructing new buffer object.
*
* This is typically called when doing PartialMerge or Final mode aggregation.
*
* @param buffer the aggregation buffer object used to store the aggregation result.
* @param input an input aggregation object. Input aggregation object can be produced by
* de-serializing the partial aggregate's output from Mapper side.
*/
def merge(buffer: T, input: T): T

/**
* Generates the final aggregation result value for current key group with the aggregation buffer
* object.
*
* Developer note: the only return types accepted by Spark are:
* - primitive types
* - InternalRow and subclasses
* - ArrayData
* - MapData
*
* @param buffer aggregation buffer object.
* @return The aggregation result of current key group
*/
def eval(buffer: T): Any

/** Serializes the aggregation buffer object T to Array[Byte] */
def serialize(buffer: T): Array[Byte]

/** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */
def deserialize(storageFormat: Array[Byte]): T

final override def initialize(buffer: InternalRow): Unit = {
partial2ModeBuffer = buffer.copy()
partial2ModeBuffer(mutableAggBufferOffset) = createPartial2ModeAggregationBuffer()
Contributor

I'm a little lost here. So this HiveTypedImperativeAggregate has 2 buffers? What's the difference between partial2ModeBuffer and buffer?

Author

Spark Catalyst designs UDAF execution so that a single aggregation buffer performs the aggregation for all aggregate operators (sort-based, object-hash-based, etc.), which makes sense from Spark's point of view. From Hive's point of view, however, two aggregation buffers are expected: one for the PARTIAL1/COMPLETE modes and one for the PARTIAL2/FINAL modes. Since I did not want to redesign the Catalyst UDAF structure only for Hive, I have kept the original calls and buffer as they are for PARTIAL1/COMPLETE mode and created partial2ModeBuffer exclusively for PARTIAL2/FINAL mode operations. So, to answer your question: buffer is used for Partial1 mode operations and partial2ModeBuffer is used for Partial2 mode operations. I hope that answers your question. Thank you once again for reviewing @cloud-fan.
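For concreteness, here is a self-contained toy sketch of that two-buffer idea (all names are made up for illustration; this is not code from the patch): one piece of state is fed by per-row updates (Partial1/Complete), the other by merges of partial results (Partial2/Final), and the final answer is always read from the second one.

```scala
// Toy illustration of the two-buffer routing; not the patch itself.
object TwoBufferSumSketch {
  class TwoBufferSum {
    private var partial1Sum: Long = 0L   // analogous to `buffer` (PARTIAL1/COMPLETE modes)
    private var partial2Sum: Long = 0L   // analogous to `partial2ModeBuffer` (PARTIAL2/FINAL modes)

    def update(value: Long): Unit = {
      partial1Sum += value               // iterate on the partial-1 state...
      partial2Sum = partial1Sum          // ...and mirror it into the partial-2 state
    }
    def merge(partial: Long): Unit = partial2Sum += partial   // merges only touch the partial-2 state
    def eval(): Long = partial2Sum                            // terminate only reads the partial-2 state
  }

  def main(args: Array[String]): Unit = {
    val agg = new TwoBufferSum
    Seq(1L, 2L, 3L).foreach(agg.update)  // iterate path
    agg.merge(10L)                       // a partial sum produced elsewhere
    println(agg.eval())                  // 16
  }
}
```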

Contributor

But I don't quite understand how you make the Hive UDAF work with Spark's two-phase aggregation.

Author

I see. Can you please elaborate on the two-phase aggregation in Spark? That will help me understand and answer your question better. Thank you.

Contributor

Let's start with the 5 phases of a UDAF:

  • Initialize: The aggregation buffers for PARTIAL1 Mode and PARTIAL2 Mode are created in this phase.
  • Iterate (Update): Processes a new row of data into the aggregation buffer created for PARTIAL1.
  • TerminatePartial: Returns the contents of the aggregation buffer.
  • Merge: Merges a partial aggregation returned by calling terminatePartial() on PARTIAL1 aggregation buffer into the current aggregation happening on PARTIAL2 aggregation buffer.
  • Terminate: Returns the final result of the aggregation stored in PARTIAL2 buffer to Hive.

In Spark, a UDAF will be run twice in two adjacent aggregate operators, called partial aggregate and final aggregate. In the partial aggregate, there are 3 steps:

  1. initialize the UDAF
  2. update UDAF with input data (so-called Iterate)
  3. return the UDAF buffer (so-called TerminatePartial)

In the final aggregate, also 3 steps:

  1. initialize the UDAF
  2. update UDAF with buffer data from the partial aggregate (so-called Merge)
  3. return final result (so-called Terminate)

But this doesn't work for the 3-phase UDAF which doesn't support partial aggregate.
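For concreteness, a rough driver-style sketch of those two call sequences, written against the TypedImperativeAggregate methods shown in the diff above (this is not Spark's actual executor code; `agg`, `rows` and `shuffledBuffers` are placeholders):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate

object TwoPhaseSketch {
  // Partial aggregate: initialize -> update ("Iterate") -> serialize the buffer ("TerminatePartial").
  def partialAggregate[T](agg: TypedImperativeAggregate[T],
                          buffer: InternalRow,
                          rows: Iterator[InternalRow]): InternalRow = {
    agg.initialize(buffer)
    rows.foreach(agg.update(buffer, _))
    agg.serializeAggregateBufferInPlace(buffer)  // this row is what gets shuffled to the final aggregate
    buffer
  }

  // Final aggregate: initialize -> merge shuffled buffers ("Merge") -> eval ("Terminate").
  def finalAggregate[T](agg: TypedImperativeAggregate[T],
                        buffer: InternalRow,
                        shuffledBuffers: Iterator[InternalRow]): Any = {
    agg.initialize(buffer)
    shuffledBuffers.foreach(agg.merge(buffer, _))
    agg.eval(buffer)
  }
}
```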

Contributor (@cloud-fan, Mar 6, 2019)

In a Hive UDAF, when do we use which agg buffer? I think this is the most important information to justify your patch. It would be better if you could point to some Hive docs/code comments.

Author

So I went through the Hive docs and asked a couple of people; officially, Hive does not say anything about using two different aggregation buffers. The main point is to have some kind of distinction between the different phases of Hive.

Consider a classic map-reduce process. There are two phases: map and reduce (sometimes an optional combine phase in between). The phases can run on different nodes. The state lives within a phase and does not cross the boundaries. The map phase corresponds to the "partial1" mode (init + iterate + terminate partial). The reduce phase corresponds to the "final" mode (init + merge + terminate). The combine phase corresponds to the "partial2" mode (init + merge + terminate partial). The "complete" mode is a special shortcut to run the whole thing as a single phase (init + iterate + terminate). The bug here is about a state crossing the boundaries between the phases: initialized for one phase (mode), but then passed to a different phase. So by using different aggregation buffers, I am trying to encapsulate the corresponding state within a particular phase. The solution can also be modified to have a single aggregation buffer supporting states of different phases.
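In code, that phase-to-mode mapping corresponds roughly to the following call sequences on Hive's GenericUDAFEvaluator (a sketch for illustration only; it assumes evaluator.init(mode, inputOIs) has already been called, and `rows`/`partials` are placeholder inputs):

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode

object HiveModeSketch {
  // One buffer per mode: it is created, fed and terminated within a single phase.
  def runMode(evaluator: GenericUDAFEvaluator,
              mode: Mode,
              rows: Seq[Array[AnyRef]],          // raw input rows (map / single-phase)
              partials: Seq[AnyRef]): AnyRef = { // partial results from an earlier phase
    val buf = evaluator.getNewAggregationBuffer()
    mode match {
      case Mode.PARTIAL1 =>  // map phase: init + iterate + terminatePartial
        rows.foreach(evaluator.iterate(buf, _)); evaluator.terminatePartial(buf)
      case Mode.PARTIAL2 =>  // combine phase: init + merge + terminatePartial
        partials.foreach(evaluator.merge(buf, _)); evaluator.terminatePartial(buf)
      case Mode.FINAL =>     // reduce phase: init + merge + terminate
        partials.foreach(evaluator.merge(buf, _)); evaluator.terminate(buf)
      case Mode.COMPLETE =>  // single-phase shortcut: init + iterate + terminate
        rows.foreach(evaluator.iterate(buf, _)); evaluator.terminate(buf)
    }
  }
}
```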

In my PR above, the assumption is that the Partial1 aggregation buffer supports phases PARTIAL1/COMPLETE and the Partial2 aggregation buffer supports phases PARTIAL2/FINAL.

I shall also paste a link to a good blog that explains the usage of aggregation buffers in a generic Hive UDAF: https://blog.dataiku.com/2013/05/01/a-complete-guide-to-writing-hive-udf

As this is also a kind of design-change problem, it is completely open to further discussion and improvement. My solution is just one of many possible solutions to achieve the same thing. However, as far as I can tell, it is relatively clean and easy to understand, and it does not change in any way how existing aggregation functions work with Spark SQL (it does not break compatibility).

Contributor

So there are 4 ways to execute a UDAF

  1. init + iterate + terminate partial
  2. init + merge + terminate final
  3. init + merge + terminate partial
  4. init + iterate + terminate final

Spark doesn't really have terminate partial. The agg buffer needs to fit the Spark schema, so Spark can read the agg buffer directly. A Spark UDAF is flexible: after initialization, the buffer can be updated via either iterate or merge, and the buffer can always be terminated.
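A sketch of that flexibility (placeholder names, not Spark code): every one of the 4 paths is initialize, then consume either rows via update or buffers via merge, then either serialize the buffer ("terminate partial") or eval it ("terminate final").

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate

object FlexiblePathSketch {
  sealed trait AggInput
  case class Rows(rows: Iterator[InternalRow]) extends AggInput        // consumed via update (iterate)
  case class Buffers(buffers: Iterator[InternalRow]) extends AggInput  // consumed via merge

  def runPath[T](agg: TypedImperativeAggregate[T],
                 buffer: InternalRow,
                 input: AggInput,
                 terminateFinal: Boolean): Any = {
    agg.initialize(buffer)
    input match {
      case Rows(rs)    => rs.foreach(agg.update(buffer, _))  // paths 1 and 4
      case Buffers(bs) => bs.foreach(agg.merge(buffer, _))   // paths 2 and 3
    }
    if (terminateFinal) {
      agg.eval(buffer)                             // paths 2 and 4: terminate final
    } else {
      agg.serializeAggregateBufferInPlace(buffer)  // paths 1 and 3: "terminate partial"
      buffer
    }
  }
}
```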

IIUC init + merge + terminate final is pretty common in GROUP BY queries, and Hive UDAF works in this case. Do you know why?

And your test case is init + iterate + terminate final. What are the correct steps to do it? Is it:

1. create partial1 buffer
2. iterate
3. turn partial1 buffer into partial2 buffer
4. terminate final

Author

@cloud-fan Sorry for the delayed response; I agree with your point above. However, I am not sure I have understood the question correctly.

IIUC init + merge + terminate final is pretty common in GROUP BY queries, and Hive UDAF works in this case. Do you know why?

And your test case is init + iterate + terminate final
Actually, my test case is init + iterate + terminate partial, along with init + iterate + merge + terminate partial, and finally ending with init + merge + terminate final. So, in my view, the correct steps here would roughly be:

1. create partial1 buffer
2. iterate
3. merge partial1 buffer into partial2 buffer
4. terminate final

Apologies if I have misread your comment above and not answered it appropriately; please let me know. Thank you.
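Read against the patch, those steps look roughly like this for the init + iterate + terminate final case (a sketch of my reading, with placeholder names; the package of HiveTypedImperativeAggregate is assumed from the diff above):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.HiveTypedImperativeAggregate

object CompleteModeSketch {
  def iterateThenTerminateFinal(agg: HiveTypedImperativeAggregate[_],
                                buffer: InternalRow,
                                rows: Iterator[InternalRow]): Any = {
    agg.initialize(buffer)     // 1. creates the partial1 buffer object in `buffer` and the
                               //    partial2 buffer object in `partial2ModeBuffer`
    rows.foreach { row =>
      agg.update(buffer, row)  // 2. iterate on the partial1 buffer; the override then
    }                          // 3. carries the updated state over into `partial2ModeBuffer`
    agg.eval(buffer)           // 4. terminate final, which reads `partial2ModeBuffer`
  }
}
```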

Contributor

To make it clear: the partial 1 buffer can only be used in iterate to consume records, and the partial 2 buffer can only be used in merge to consume buffers. Do I understand it correctly?

buffer(mutableAggBufferOffset) = createAggregationBuffer()
}

final override def update(buffer: InternalRow, input: InternalRow): Unit = {
buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input)
partial2ModeBuffer = buffer.copy()
}

final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
val bufferObject = getBufferObject(buffer)
val bufferObject = getBufferObject(partial2ModeBuffer)
// The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
partial2ModeBuffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
}

final override def eval(buffer: InternalRow): Any = {
eval(getBufferObject(buffer))
eval(getBufferObject(partial2ModeBuffer))
}

private[this] val anyObjectType = ObjectType(classOf[AnyRef])
@@ -566,7 +685,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
* This is only called when doing Partial or PartialMerge mode aggregation, before the framework
* shuffle out aggregate buffers.
*/
final def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
final override def serializeAggregateBufferInPlace(buffer: InternalRow): Unit = {
buffer(mutableAggBufferOffset) = serialize(getBufferObject(buffer))
}
}
@@ -59,17 +59,21 @@ class ObjectAggregationIterator(
private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _

// Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
var (sortBasedAggExpressions, sortBasedAggFunctions): (
Seq[AggregateExpression], Array[AggregateFunction]) = {
val newExpressions = aggregateExpressions.map {
case agg @ AggregateExpression(_, Partial, _, _) =>
agg.copy(mode = PartialMerge)
case agg @ AggregateExpression(_, Complete, _, _) =>
agg.copy(mode = Final)
case other => other
}
val newFunctions = initializeAggregateFunctions(newExpressions, 0)
val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
generateProcessRow(newExpressions, newFunctions, newInputAttributes)
(newExpressions, initializeAggregateFunctions(newExpressions, 0))
}

private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes)
generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes)
}

/**
@@ -93,7 +97,7 @@
*/
def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
if (groupingExpressions.isEmpty) {
val defaultAggregationBuffer = createNewAggregationBuffer()
val defaultAggregationBuffer = createNewAggregationBuffer(aggregateFunctions)
generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
} else {
throw new IllegalStateException(
@@ -106,26 +110,28 @@
//
// - when creating aggregation buffer for a new group in the hash map, and
// - when creating the re-used buffer for sort-based aggregation
private def createNewAggregationBuffer(): SpecificInternalRow = {
val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
private def createNewAggregationBuffer(
functions: Array[AggregateFunction]): SpecificInternalRow = {
val bufferFieldTypes = functions.flatMap(_.aggBufferAttributes.map(_.dataType))
val buffer = new SpecificInternalRow(bufferFieldTypes)
initAggregationBuffer(buffer)
initAggregationBuffer(buffer, functions)
buffer
}

private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
private def initAggregationBuffer(
buffer: SpecificInternalRow, functions: Array[AggregateFunction]): Unit = {
// Initializes declarative aggregates' buffer values
expressionAggInitialProjection.target(buffer)(EmptyRow)
// Initializes imperative aggregates' buffer values
aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
functions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
}

private def getAggregationBufferByKey(
hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
var aggBuffer = hashMap.getAggregationBuffer(groupingKey)

if (aggBuffer == null) {
aggBuffer = createNewAggregationBuffer()
aggBuffer = createNewAggregationBuffer(aggregateFunctions)
hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
}

@@ -183,7 +189,7 @@ class ObjectAggregationIterator(
StructType.fromAttributes(groupingAttributes),
processRow,
mergeAggregationBuffers,
createNewAggregationBuffer())
createNewAggregationBuffer(sortBasedAggFunctions))

while (inputRows.hasNext) {
// NOTE: The input row is always UnsafeRow
@@ -311,7 +311,7 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
extends HiveTypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
with HiveInspectors
with UserDefinedExpression {

@@ -404,6 +404,9 @@ private[hive] case class HiveUDAFFunction(
override def createAggregationBuffer(): AggregationBuffer =
partial1HiveEvaluator.evaluator.getNewAggregationBuffer

override def createPartial2ModeAggregationBuffer(): AggregationBuffer =
partial2ModeEvaluator.getNewAggregationBuffer

@transient
private lazy val inputProjection = UnsafeProjection.create(children)
