54 changes: 39 additions & 15 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -304,6 +304,13 @@ private[hive] case class HiveGenericUDTF(
  * - `wrap()`/`wrapperFor()`: from 3 to 1
  * - `unwrap()`/`unwrapperFor()`: from 1 to 3
  * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that a Hive UDAF is initialized with an aggregate mode, and some Hive UDAFs can't mix
+ * UPDATE and MERGE actions during their life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, when hash aggregation falls back to sort-based aggregation. To work around
+ * this issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark
+ * does UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a
+ * different aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
@@ -312,7 +319,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {

@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null

   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)

-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, so we create the buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }

+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }

-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is an aggregate buffer, so we create the buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }

+    // It's possible that we've already called `update` on this Hive UDAF, and some Hive UDAF
+    // implementations can't mix `update` and `merge` calls during their life cycle. To work
+    // around this, we create a fresh buffer with the final evaluator, merge the existing
+    // buffer into it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // these `AggregationBuffer`s into this format before shuffling partial aggregation results,
     // and calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }

-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }

-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }

-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
[Inline review thread on the deserialize() change]

Reviewer: Since canDoMerge is always set to false after deserialization, the aggregation
buffer in merge() will always be re-created, even when the passed buffer parameter is
actually in the PARTIAL2 or FINAL state. Correct me if I am wrong, but this looks like a
flaw that could degrade performance. It may need non-trivial work in serialize() to include
that state.

Author: The deserialized buffer can only appear as the second parameter of merge(), so
canDoMerge doesn't matter here.

Reviewer: I see, except for the case of falling back from hash aggregation, and that is
exactly what this PR addresses; it does not impact Spark UDAFs. The logic looks clear and
good to me, thanks!

Reviewer: LGTM
   }

   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -506,3 +528,5 @@ private[hive] case class HiveUDAFFunction(
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg

-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils

 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -94,21 +94,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }

-  test("customized Hive UDAF with two aggregation buffers") {
-    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")

-    val aggs = df.queryExecution.executedPlan.collect {
-      case agg: ObjectHashAggregateExec => agg
-    }
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }

-    // There should be two aggregate operators, one for partial aggregation, and the other for
-    // global aggregation.
-    assert(aggs.length == 2)
+      // There should be two aggregate operators, one for partial aggregation and the other
+      // for global aggregation.
+      assert(aggs.length == 2)

-    checkAnswer(df, Seq(
-      Row(0, Row(1, 1)),
-      Row(1, Row(1, 1))
-    ))
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
   }

   test("call JAVA UDAF") {