From 3d523b6996157026d61edc705054437dee362f6d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 25 Apr 2019 21:22:04 +0800
Subject: [PATCH] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter

---
 .../org/apache/spark/sql/hive/hiveUDFs.scala | 54 +++++++++++++------
 .../sql/hive/execution/HiveUDAFSuite.scala   | 38 ++++++++-----
 2 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 0938576a71472..76e4085712bdd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -304,6 +304,13 @@ private[hive] case class HiveGenericUDTF(
  * - `wrap()`/`wrapperFor()`: from 3 to 1
  * - `unwrap()`/`unwrapperFor()`: from 1 to 3
  * - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that a Hive UDAF is initialized with an aggregate mode, and some Hive UDAFs can't mix
+ * UPDATE and MERGE actions during their life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, when hash aggregation falls back to sort-based aggregation. To work around
+ * this issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ * aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
@@ -312,7 +319,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {
@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }
 
+    assert(!nonNullBuffer.canDoMerge, "cannot call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }
 
+    // It's possible that `update` was called on this Hive UDAF before, and some Hive UDAF
+    // implementations can't mix `update` and `merge` calls during their life cycle. To work
+    // around it, we create a fresh buffer with the final evaluator, merge the existing buffer
+    // into it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
   }
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -506,3 +528,5 @@ private[hive] case class HiveUDAFFunction(
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index ef40323c61315..3252cdafa1be1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 
 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -94,21 +94,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
-  test("customized Hive UDAF with two aggregation buffers") {
-    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")
 
-    val aggs = df.queryExecution.executedPlan.collect {
-      case agg: ObjectHashAggregateExec => agg
-    }
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }
 
-    // There should be two aggregate operators, one for partial aggregation, and the other for
-    // global aggregation.
-    assert(aggs.length == 2)
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)
 
-    checkAnswer(df, Seq(
-      Row(0, Row(1, 1)),
-      Row(1, Row(1, 1))
-    ))
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
   }
 
   test("call JAVA UDAF") {
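
Reviewer note, not part of the patch: the canDoMerge bookkeeping above is easier to see outside the Hive adapter, so below is a minimal, self-contained Scala sketch of the same INIT -> UPDATE -> MERGE -> FINISH handling. SimpleBuffer, UDAFBuffer and the update/merge/finish methods are hypothetical stand-ins for Hive's AggregationBuffer, the HiveUDAFBuffer wrapper and the evaluator's iterate/merge/terminate calls; the field-by-field copy stands in for terminatePartial() plus merge(), and none of this is the actual Spark code.

// Minimal model of the Hive UDAF life-cycle handling (illustrative only).
final case class SimpleBuffer(var sum: Long, var count: Long)
final case class UDAFBuffer(buf: SimpleBuffer, canDoMerge: Boolean)

object LifeCycleSketch {

  // UPDATE path: buffers created here must not be fed to MERGE directly.
  def update(buffer: UDAFBuffer, value: Long): UDAFBuffer = {
    val nonNull =
      if (buffer == null) UDAFBuffer(SimpleBuffer(0, 0), canDoMerge = false) else buffer
    assert(!nonNull.canDoMerge, "cannot call `merge` then `update`")
    nonNull.buf.sum += value
    nonNull.buf.count += 1
    nonNull
  }

  // MERGE path: if the existing buffer was built by UPDATE, fold it into a fresh
  // merge-capable buffer first (the adapter does this via terminatePartial + merge).
  def merge(buffer: UDAFBuffer, input: UDAFBuffer): UDAFBuffer = {
    val nonNull =
      if (buffer == null) UDAFBuffer(SimpleBuffer(0, 0), canDoMerge = true) else buffer
    val mergeable = if (!nonNull.canDoMerge) {
      UDAFBuffer(SimpleBuffer(nonNull.buf.sum, nonNull.buf.count), canDoMerge = true)
    } else {
      nonNull
    }
    mergeable.buf.sum += input.buf.sum
    mergeable.buf.count += input.buf.count
    mergeable
  }

  // FINISH path: produce the final result from whatever buffer we ended up with.
  def finish(buffer: UDAFBuffer): Double =
    if (buffer.buf.count == 0) 0.0 else buffer.buf.sum.toDouble / buffer.buf.count

  def main(args: Array[String]): Unit = {
    // Hash aggregation starts with UPDATE, then falls back and switches to MERGE.
    var b: UDAFBuffer = null
    b = update(b, 1)
    b = update(b, 2)
    b = merge(b, update(null, 3))  // UPDATE -> MERGE transition detected and handled
    println(finish(b))             // prints 2.0
  }
}

The sort-based fallback that triggers this UPDATE -> MERGE transition can be forced in tests by lowering the object-hash-aggregate fallback threshold (SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD), which is exactly what the new HiveUDAFSuite case does.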