64 changes: 41 additions & 23 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -352,29 +352,21 @@ private[hive] case class HiveUDAFFunction(
HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
}

// The UDAF evaluator used to merge partial aggregation results.
// The UDAF evaluator used to consume partial aggregation results and produce final results.
// Hive `ObjectInspector` used to inspect final results.
@transient
private lazy val partial2ModeEvaluator = {
Contributor Author commented:

We don't need a partial2 evaluator and a final evaluator. We just need one final evaluator.

The partial2 evaluator consumes an agg buffer and produces an agg buffer, while the final evaluator consumes an agg buffer and produces the final result. Since the final evaluator can also execute merge, we don't need the partial2 evaluator.
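
For context, here is a minimal, self-contained sketch of why a FINAL-mode evaluator alone suffices. It uses Hive's built-in GenericUDAFSumLong purely for illustration; the object name and all value names are made up, not part of this patch:

import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFEvaluator, GenericUDAFSum}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

object FinalModeMergeSketch {
  def main(args: Array[String]): Unit = {
    val inputOI: ObjectInspector = PrimitiveObjectInspectorFactory.javaLongObjectInspector

    // PARTIAL1 mode: consumes original data, produces partial aggregation buffers.
    val partial1 = new GenericUDAFSum.GenericUDAFSumLong
    val partialOI = partial1.init(GenericUDAFEvaluator.Mode.PARTIAL1, Array(inputOI))
    val buf1 = partial1.getNewAggregationBuffer
    partial1.iterate(buf1, Array[AnyRef](java.lang.Long.valueOf(1L)))
    val buf2 = partial1.getNewAggregationBuffer
    partial1.iterate(buf2, Array[AnyRef](java.lang.Long.valueOf(2L)))

    // FINAL mode: merge() consumes partial results and terminate() produces the
    // final value, so a separate PARTIAL2-mode evaluator is unnecessary.
    val finalEval = new GenericUDAFSum.GenericUDAFSumLong
    finalEval.init(GenericUDAFEvaluator.Mode.FINAL, Array(partialOI))
    val finalBuf = finalEval.getNewAggregationBuffer
    finalEval.merge(finalBuf, partial1.terminatePartial(buf1))
    finalEval.merge(finalBuf, partial1.terminatePartial(buf2))
    println(finalEval.terminate(finalBuf)) // prints 3
  }
}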

private lazy val finalHiveEvaluator = {
val evaluator = newEvaluator()
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
evaluator
HiveEvaluator(
evaluator,
evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
}

// Spark SQL data type of partial aggregation results
@transient
private lazy val partialResultDataType =
inspectorToDataType(partial1HiveEvaluator.objectInspector)

// The UDAF evaluator used to compute the final result from a partial aggregation result objects.
// Hive `ObjectInspector` used to inspect the final aggregation result object.
@transient
private lazy val finalHiveEvaluator = {
val evaluator = newEvaluator()
HiveEvaluator(
evaluator,
evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
}

// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
@transient
private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -401,25 +393,43 @@ private[hive] case class HiveUDAFFunction(
s"$name($distinct${children.map(_.sql).mkString(", ")})"
}

override def createAggregationBuffer(): AggregationBuffer =
partial1HiveEvaluator.evaluator.getNewAggregationBuffer
// A Hive UDAF may create different buffers to handle different kinds of input: original data
// or aggregation buffers. However, the Spark UDAF framework does not expose this information
// when creating the buffer. Here we return null and create the buffer on demand in `update`
// and `merge`, so that we know which kind of input we are dealing with.
override def createAggregationBuffer(): AggregationBuffer = null

@transient
private lazy val inputProjection = UnsafeProjection.create(children)

override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
// The input is original data, so we create the buffer with the partial1 evaluator.
val nonNullBuffer = if (buffer == null) {
partial1HiveEvaluator.evaluator.getNewAggregationBuffer
} else {
buffer
}

partial1HiveEvaluator.evaluator.iterate(
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
buffer
nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
nonNullBuffer
}

override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
// The input is an aggregation buffer, so we create the buffer with the final evaluator.
val nonNullBuffer = if (buffer == null) {
finalHiveEvaluator.evaluator.getNewAggregationBuffer
} else {
buffer
}

// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
// `AggregationBuffer`s into this format before shuffling partial aggregation results, calling
// `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
buffer
finalHiveEvaluator.evaluator.merge(
nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
nonNullBuffer
}
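
The create-on-demand pattern used by update() and merge() above generalizes; here is a tiny self-contained sketch with toy types (none of these names are Spark's):

object LazyBufferSketch {
  // Toy sketch: each code path knows which buffer kind its input requires,
  // so a null buffer is materialized lazily by the first call that needs it.
  sealed trait Buffer
  final class RawInputBuffer extends Buffer // for update(): original rows
  final class PartialBuffer extends Buffer  // for merge(): partial agg results

  def update(buffer: Buffer): Buffer = {
    val nonNull = if (buffer == null) new RawInputBuffer else buffer
    // ... consume one original input row here ...
    nonNull
  }

  def merge(buffer: Buffer): Buffer = {
    val nonNull = if (buffer == null) new PartialBuffer else buffer
    // ... consume one partial aggregation result here ...
    nonNull
  }
}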

override def eval(buffer: AggregationBuffer): Any = {
@@ -450,11 +460,19 @@ private[hive] case class HiveUDAFFunction(
private val mutableRow = new GenericInternalRow(1)

def serialize(buffer: AggregationBuffer): Array[Byte] = {
// The buffer may be null if there is no input. It's unclear whether the Hive UDAF accepts a
// null buffer, so for safety we create an empty buffer here.
val nonNullBuffer = if (buffer == null) {
partial1HiveEvaluator.evaluator.getNewAggregationBuffer
} else {
buffer
}

// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
// Then we can unwrap it to a Spark SQL value.
mutableRow.update(0, partialResultUnwrapper(
partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
val unsafeRow = projection(mutableRow)
val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
unsafeRow.writeTo(bytes)
@@ -466,11 +484,11 @@ private[hive] case class HiveUDAFFunction(
// returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
// workaround here is creating an initial `AggregationBuffer` first and then merge the
// deserialized object into the buffer.
val buffer = partial2ModeEvaluator.getNewAggregationBuffer
val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
val unsafeRow = new UnsafeRow(1)
unsafeRow.pointTo(bytes, bytes.length)
val partialResult = unsafeRow.get(0, partialResultDataType)
partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
buffer
}
}
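
The serializer above rides on an UnsafeRow byte round trip; a minimal sketch of just that round trip, assuming Spark's catalyst classes are on the classpath (the schema and the value 42L are illustrative):

import java.nio.ByteBuffer
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object UnsafeRowRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Write a single-field row to bytes, as serialize() does with the partial result...
    val projection = UnsafeProjection.create(StructType(Seq(StructField("partial", LongType))))
    val row = new GenericInternalRow(1)
    row.update(0, 42L)
    val unsafeRow = projection(row)
    val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
    unsafeRow.writeTo(bytes)

    // ...and read it back, as deserialize() does before wrapping and merging.
    val restored = new UnsafeRow(1)
    restored.pointTo(bytes.array(), bytes.array().length)
    assert(restored.getLong(0) == 42L)
  }
}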
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
import test.org.apache.spark.sql.MyDoubleAvg

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -40,6 +41,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
super.beforeAll()
sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")

Seq(
(0: Integer) -> "val_0",
@@ -92,6 +94,23 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
))
}

test("customized Hive UDAF with two aggregation buffers") {
val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")

val aggs = df.queryExecution.executedPlan.collect {
case agg: ObjectHashAggregateExec => agg
}

// There should be two aggregate operators, one for partial aggregation, and the other for
// global aggregation.
assert(aggs.length == 2)

checkAnswer(df, Seq(
Row(0, Row(1, 1)),
Row(1, Row(1, 1))
))
}

test("call JAVA UDAF") {
withTempView("temp") {
withUserDefinedFunction("myDoubleAvg" -> false) {
@@ -127,12 +146,22 @@ class MockUDAF extends AbstractGenericUDAFResolver {
override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
}

class MockUDAF2 extends AbstractGenericUDAFResolver {
override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
}

class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
extends GenericUDAFEvaluator.AbstractAggregationBuffer {

override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
}

class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
extends GenericUDAFEvaluator.AbstractAggregationBuffer {

override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
}

class MockUDAFEvaluator extends GenericUDAFEvaluator {
private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector

@@ -184,3 +213,80 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {

override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
}

// Same as MockUDAFEvaluator but using two aggregation buffers, one for PARTIAL1 and the other
// for PARTIAL2.
class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector

private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
private var aggMode: Mode = null

private val bufferOI = {
val fieldNames = Seq("nonNullCount", "nullCount").asJava
val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}

private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")

private val nullCountField = bufferOI.getStructFieldRef("nullCount")

override def getNewAggregationBuffer: AggregationBuffer = {
// These 2 modes consume original data.
if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
new MockUDAFBuffer(0L, 0L)
} else {
new MockUDAFBuffer2(0L, 0L)
}
}

override def reset(agg: AggregationBuffer): Unit = {
val buffer = agg.asInstanceOf[MockUDAFBuffer]
buffer.nonNullCount = 0L
buffer.nullCount = 0L
}

override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
aggMode = mode
bufferOI
}

override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
val buffer = agg.asInstanceOf[MockUDAFBuffer]
if (parameters.head eq null) {
buffer.nullCount += 1L
} else {
buffer.nonNullCount += 1L
}
}

override def merge(agg: AggregationBuffer, partial: Object): Unit = {
if (partial ne null) {
val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
val buffer = agg.asInstanceOf[MockUDAFBuffer2]
buffer.nonNullCount += nonNullCount
buffer.nullCount += nullCount
}
}

// Since this method is called in both PARTIAL1 and PARTIAL2 modes, it has to check the class
// of the aggregation buffer to know which buffer type it is dealing with.
override def terminatePartial(agg: AggregationBuffer): AnyRef = {
var result: AnyRef = null
if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
val buffer = agg.asInstanceOf[MockUDAFBuffer2]
result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
} else {
val buffer = agg.asInstanceOf[MockUDAFBuffer]
result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
}
result
}

override def terminate(agg: AggregationBuffer): AnyRef = {
val buffer = agg.asInstanceOf[MockUDAFBuffer2]
Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
}
}
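
As an aside, the class-name string check in terminatePartial above could be expressed as a pattern match; a sketch of that variant, as a drop-in body inside MockUDAFEvaluator2 (not part of the patch):

  // Sketch: pattern-match on the buffer class instead of inspecting its name.
  override def terminatePartial(agg: AggregationBuffer): AnyRef = agg match {
    case b: MockUDAFBuffer2 => Array[Object](b.nonNullCount: java.lang.Long, b.nullCount: java.lang.Long)
    case b: MockUDAFBuffer  => Array[Object](b.nonNullCount: java.lang.Long, b.nullCount: java.lang.Long)
  }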