
Commit 0cfefa7

pgandhi and cloud-fan authored and committed
[SPARK-24935][SQL] fix Hive UDAF with two aggregation buffers
## What changes were proposed in this pull request?

A Hive UDAF knows the aggregation mode when creating its aggregation buffer, so it can create different buffers for different inputs: the original data or a partial aggregation buffer. Please see an example in the [sketches library](https://github.com/DataSketches/sketches-hive/blob/7f9e76e9e03807277146291beb2c7bec40e8672b/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107). However, the Hive UDAF adapter in Spark always creates the buffer with PARTIAL1 mode, which can only deal with one input: the original data. This PR fixes it.

All credits go to pgandhi999, who investigated the problem, studied the Hive UDAF behaviors, and wrote the tests.

close #23778

## How was this patch tested?

A new test.

Closes #24144 from cloud-fan/hive.

Lead-authored-by: pgandhi <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
(cherry picked from commit a6c207c)
Signed-off-by: gatorsmile <[email protected]>
1 parent 3fc626d commit 0cfefa7
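
For context, a Hive `GenericUDAFEvaluator` is told the aggregation mode in `init()`, and a UDAF may key its choice of buffer class on that mode, as the linked sketches library does. Below is a minimal, hypothetical sketch of that pattern; the class and buffer names are illustrative, not from this PR or the sketches library:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AbstractAggregationBuffer, AggregationBuffer, Mode}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical two-buffer evaluator: the buffer class depends on the mode
// recorded in init(), so a PARTIAL1-mode buffer is NOT interchangeable with
// a PARTIAL2/FINAL-mode buffer.
class ModeAwareEvaluator extends GenericUDAFEvaluator {
  private var aggMode: Mode = _

  class InputBuffer extends AbstractAggregationBuffer  // consumes original rows
  class MergeBuffer extends AbstractAggregationBuffer  // consumes partial results

  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
    aggMode = mode
    PrimitiveObjectInspectorFactory.javaLongObjectInspector // placeholder inspector
  }

  override def getNewAggregationBuffer: AggregationBuffer =
    // PARTIAL1 and COMPLETE consume original input; PARTIAL2 and FINAL
    // consume partial aggregation results.
    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) new InputBuffer
    else new MergeBuffer

  // The real logic is elided; each method would cast to the buffer class
  // matching the mode it runs in.
  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = ()
  override def merge(agg: AggregationBuffer, partial: AnyRef): Unit = ()
  override def terminatePartial(agg: AggregationBuffer): AnyRef = null
  override def terminate(agg: AggregationBuffer): AnyRef = null
  override def reset(agg: AggregationBuffer): Unit = ()
}
```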

File tree: 2 files changed, +147 -23 lines


sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 41 additions & 23 deletions
@@ -352,29 +352,21 @@ private[hive] case class HiveUDAFFunction(
     HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
   }
 
-  // The UDAF evaluator used to merge partial aggregation results.
+  // The UDAF evaluator used to consume partial aggregation results and produce final results.
+  // Hive `ObjectInspector` used to inspect final results.
   @transient
-  private lazy val partial2ModeEvaluator = {
+  private lazy val finalHiveEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
-    evaluator
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
   private lazy val partialResultDataType =
     inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
-  // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private lazy val finalHiveEvaluator = {
-    val evaluator = newEvaluator()
-    HiveEvaluator(
-      evaluator,
-      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
-  }
-
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
   private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -401,25 +393,43 @@
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
 
-  override def createAggregationBuffer(): AggregationBuffer =
-    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+  // The hive UDAF may create different buffers to handle different inputs: original data or
+  // aggregate buffer. However, the Spark UDAF framework does not expose this information when
+  // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
+  // on demand, so that we can know what input we are dealing with.
+  override def createAggregationBuffer(): AggregationBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+    // The input is original data, we create buffer with the partial1 evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     partial1HiveEvaluator.evaluator.iterate(
-      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
-    buffer
+      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    nonNullBuffer
   }
 
   override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+    // The input is aggregate buffer, we create buffer with the final evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    buffer
+    finalHiveEvaluator.evaluator.merge(
+      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
+    nonNullBuffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
@@ -450,11 +460,19 @@ private[hive] case class HiveUDAFFunction(
     private val mutableRow = new GenericInternalRow(1)
 
     def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // The buffer may be null if there is no input. It's unclear if the hive UDAF accepts null
+      // buffer, for safety we create an empty buffer here.
+      val nonNullBuffer = if (buffer == null) {
+        partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      } else {
+        buffer
+      }
+
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
       mutableRow.update(0, partialResultUnwrapper(
-        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
+        partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
@@ -466,11 +484,11 @@ private[hive] case class HiveUDAFFunction(
       // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
      // workaround here is creating an initial `AggregationBuffer` first and then merge the
       // deserialized object into the buffer.
-      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
       val unsafeRow = new UnsafeRow(1)
       unsafeRow.pointTo(bytes, bytes.length)
       val partialResult = unsafeRow.get(0, partialResultDataType)
-      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
       buffer
     }
   }
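
To see why the old adapter broke such UDAFs: `createAggregationBuffer()` always asked the PARTIAL1-mode evaluator for a buffer, but `merge()` ran on an evaluator initialized in PARTIAL2 mode, which expects its merge-side buffer class. A rough, hypothetical walkthrough using the `ModeAwareEvaluator` sketch above (not Spark source):

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

object OldAdapterMismatch {
  def main(args: Array[String]): Unit = {
    val inputInspectors: Array[ObjectInspector] =
      Array(PrimitiveObjectInspectorFactory.javaLongObjectInspector)

    // Old behavior: every buffer came from the PARTIAL1-mode evaluator...
    val partial1Evaluator = new ModeAwareEvaluator
    val partialOI = partial1Evaluator.init(Mode.PARTIAL1, inputInspectors)
    val buffer = partial1Evaluator.getNewAggregationBuffer // an InputBuffer

    // ...yet merge() ran on a PARTIAL2-mode evaluator, which would cast the
    // buffer to its merge-side class. A real two-buffer UDAF (e.g. the
    // DataSketches one) throws ClassCastException here; the fix instead
    // creates the buffer lazily in update()/merge() with the right evaluator.
    val partial2Evaluator = new ModeAwareEvaluator
    partial2Evaluator.init(Mode.PARTIAL2, Array(partialOI))
    partial2Evaluator.merge(buffer, null) // no-op in the sketch, fatal in a real UDAF
  }
}
```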

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 106 additions & 0 deletions
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -39,6 +40,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   protected override def beforeAll(): Unit = {
     sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
     sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")
 
     Seq(
       (0: Integer) -> "val_0",
@@ -91,6 +93,23 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("customized Hive UDAF with two aggregation buffers") {
+    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
@@ -126,12 +145,22 @@ class MockUDAF extends AbstractGenericUDAFResolver {
   override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
 }
 
+class MockUDAF2 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
+}
+
 class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
   extends GenericUDAFEvaluator.AbstractAggregationBuffer {
 
   override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
 }
 
+class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
 class MockUDAFEvaluator extends GenericUDAFEvaluator {
   private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
 
@@ -183,3 +212,80 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {
 
   override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
 }
+
+// Same as MockUDAFEvaluator but using two aggregation buffers, one for PARTIAL1 and the other
+// for PARTIAL2.
+class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    // These 2 modes consume original data.
+    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
+      new MockUDAFBuffer(0L, 0L)
+    } else {
+      new MockUDAFBuffer2(0L, 0L)
+    }
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  // As this method is called for both states, Partial1 and Partial2, the hack in the method
+  // to check for class of aggregation buffer was necessary.
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    var result: AnyRef = null
+    if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    } else {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    }
+    result
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+}
