199 changes: 149 additions & 50 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -17,16 +17,18 @@

package org.apache.spark.sql.hive

import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector,
ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions

import org.apache.spark.internal.Logging
@@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF(

@transient
private lazy val isUDFDeterministic = {
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic()
}

@@ -75,7 +77,7 @@

@transient
lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
method.getGenericReturnType, ObjectInspectorOptions.JAVA))

@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF(
}

/**
* Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
* performance a lot.
* While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following
* three formats:
*
* 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class
*
* This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator`
* methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format.
* We call these methods to evaluate Hive UDAFs.
*
* 2. A Java object that can be inspected using the `ObjectInspector` returned by the
* `GenericUDAFEvaluator.init()` method.
Contributor:
Besides explaining what these three formats are, let's also explain when we will use each of them.

Contributor:
(We can just put the PR description here.)

Contributor:
(Is the doc below enough?)

Contributor Author:
Thanks, added.

*
* Hive uses this format to produce a serializable aggregation state so that it can shuffle
* partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance
* into a Spark SQL value, we have to convert it to this format first and then do the conversion
* with the help of `ObjectInspector`s.
*
* 3. A Spark SQL value
*
* We use this format for serializing Hive UDAF aggregation states on Spark side. To be more
* specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into
* `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization
* results.
*
* We may use the following methods to convert the aggregation state back and forth:
*
* - `wrap()`/`wrapperFor()`: from 3 to 2
* - `unwrap()`/`unwrapperFor()`: from 2 to 3
* - `GenericUDAFEvaluator.terminatePartial()`: from 1 to 2
*/
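As a reading aid (not from the patch itself), here is a rough sketch of how the `TypedImperativeAggregate` callbacks implemented below line up with the three formats; the `HiveUDAFCallbackSketch` trait name is made up, and the comments reflect the conversions described above.

import org.apache.spark.sql.catalyst.InternalRow

// Simplified stand-in for Spark's TypedImperativeAggregate[T], shown only to map each
// callback onto the formats above (1 = AggregationBuffer, 2 = inspectable Java object,
// 3 = Spark SQL value).
trait HiveUDAFCallbackSketch[B] {
  def createAggregationBuffer(): B                  // a fresh format-1 buffer
  def update(buffer: B, input: InternalRow): Unit   // wrap input values and feed them to iterate()
  def merge(buffer: B, other: B): Unit              // other: 1 -> 2 via terminatePartial(), then Hive merge()
  def eval(buffer: B): Any                          // 1 -> 2 via terminate(), then 2 -> 3 via unwrap()
  def serialize(buffer: B): Array[Byte]             // 1 -> 2 -> 3, then the bytes behind an UnsafeRow
  def deserialize(bytes: Array[Byte]): B            // bytes -> 3 -> 2 via wrap(), merged into a fresh buffer
}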
private[hive] case class HiveUDAFFunction(
name: String,
@@ -273,89 +302,89 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with HiveInspectors {
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

// Hive `ObjectInspector`s for all child expressions (input parameters of the function).
@transient
private lazy val resolver =
if (isUDAFBridgeRequired) {
private lazy val inputInspectors = children.map(toInspector).toArray
Contributor:
Let's add docs to explain when these internal vals are used (like which vals are needed for a given mode).

// Spark SQL data types of input parameters.
@transient
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray

private def newEvaluator(): GenericUDAFEvaluator = {
val resolver = if (isUDAFBridgeRequired) {
new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}

@transient
private lazy val inspectors = children.map(toInspector).toArray

@transient
private lazy val functionAndInspector = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
val f = resolver.getEvaluator(parameterInfo)
f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
resolver.getEvaluator(parameterInfo)
}

// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
@transient
private lazy val function = functionAndInspector._1
private lazy val partial1ModeEvaluator = newEvaluator()
Contributor:
Do we need to make it a lazy val, since partialResultInspector uses it right below?

Contributor Author:
Yes. It has to be transient and lazy so that it's also available on the executor side, since Hive UDAF evaluators are not serializable.

Contributor:
OK. I think in general we should avoid using this pattern. If we have to use it now, let's explain it in the comment.

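For readers unfamiliar with the pattern under discussion, a minimal self-contained sketch (the `NonSerializableEvaluator` and `Holder` names are made up): a `@transient lazy val` is dropped by Java serialization and rebuilt from serializable state on first access, which is why the evaluator can be re-created on the executor side.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Stand-in for a Hive evaluator class that is not serializable.
class NonSerializableEvaluator(val name: String)

case class Holder(name: String) {
  // Dropped during serialization; recomputed from `name` on whichever JVM touches it first.
  @transient private lazy val evaluator = new NonSerializableEvaluator(name)
  def evaluatorName: String = evaluator.name
}

object TransientLazyValDemo {
  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    new ObjectOutputStream(out).writeObject(Holder("max"))
    val restored = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
      .readObject().asInstanceOf[Holder]
    println(restored.evaluatorName)  // prints "max"; the evaluator was rebuilt after deserialization
  }
}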

// Hive `ObjectInspector` used to inspect partial aggregation results.
Contributor:
The partial aggregation result is the aggregation buffer, right?

Contributor Author:
Yeah. They are the objects returned by terminatePartial(), which are the inspectable version of a Hive AggregationBuffer.

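A tiny self-contained illustration of that point, using Hive's built-in `GenericUDAFMax` purely as an example (this is not code from the patch): `terminatePartial()` turns the opaque `AggregationBuffer` into a plain Java object that the `ObjectInspector` returned by `init()` can inspect.

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory

val inputOIs: Array[ObjectInspector] = Array(PrimitiveObjectInspectorFactory.javaLongObjectInspector)
val evaluator = new GenericUDAFMax().getEvaluator(Array(TypeInfoFactory.longTypeInfo))
// init() returns the ObjectInspector describing what this mode emits (the partial result).
val partialOI = evaluator.init(Mode.PARTIAL1, inputOIs)

val buffer = evaluator.getNewAggregationBuffer                // format 1: opaque AggregationBuffer
evaluator.iterate(buffer, Array[AnyRef](java.lang.Long.valueOf(42L)))
val inspectable = evaluator.terminatePartial(buffer)          // format 2: plain object readable via partialOI
println(inspectable)                                          // e.g. 42, a value partialOI knows how to read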
@transient
private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
private val partialResultInspector = partial1ModeEvaluator.init(
GenericUDAFEvaluator.Mode.PARTIAL1,
inputInspectors
)

// The UDAF evaluator used to merge partial aggregation results.
@transient
private lazy val returnInspector = functionAndInspector._2
private lazy val partial2ModeEvaluator = {
val evaluator = newEvaluator()
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
evaluator
}

// Spark SQL data type of partial aggregation results
@transient
private lazy val unwrapper = unwrapperFor(returnInspector)
private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)

// The UDAF evaluator used to compute the final result from partial aggregation result objects.
@transient
private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _

override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
private lazy val finalModeEvaluator = newEvaluator()

// Hive `ObjectInspector` used to inspect the final aggregation result object.
@transient
private lazy val inputProjection = new InterpretedProjection(children)
private val returnInspector = finalModeEvaluator.init(
GenericUDAFEvaluator.Mode.FINAL,
Array(partialResultInspector)
)

// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
@transient
private lazy val cached = new Array[AnyRef](children.length)
private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray

// Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
// Spark SQL specific format.
@transient
private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray

// Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
// buffer for it.
override def aggBufferSchema: StructType = StructType(Nil)

override def update(_buffer: InternalRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
}

override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
throw new UnsupportedOperationException(
"Hive UDAF doesn't support partial aggregate")
}
private lazy val resultUnwrapper = unwrapperFor(returnInspector)

override def initialize(_buffer: InternalRow): Unit = {
buffer = function.getNewAggregationBuffer
}

override val aggBufferAttributes: Seq[AttributeReference] = Nil
@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)

// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
@transient
private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe
Contributor:
doc

// We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
// catalyst type checking framework.
override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)

override def nullable: Boolean = true

override def supportsPartial: Boolean = false
override def supportsPartial: Boolean = true
Contributor:
Is there any Hive UDAF that does not support partial aggregation?

Contributor Author:
I don't think so. Hive doesn't have an equivalent flag, and all UDAFs inherently support partial aggregation since they have to implement the callbacks for all phases.

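A rough end-to-end sketch of those phase callbacks, again using `GenericUDAFMax` only as an illustration: a map-side PARTIAL1 evaluator feeding a reduce-side FINAL evaluator, mirroring what the `partial1ModeEvaluator` and `finalModeEvaluator` fields below do.

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory

val inputOIs: Array[ObjectInspector] = Array(PrimitiveObjectInspectorFactory.javaLongObjectInspector)
def newEvaluator() = new GenericUDAFMax().getEvaluator(Array(TypeInfoFactory.longTypeInfo))

// Map side (PARTIAL1): consume raw input and emit a shuffle-friendly partial result.
val partial1 = newEvaluator()
val partialOI = partial1.init(Mode.PARTIAL1, inputOIs)
val mapBuffer = partial1.getNewAggregationBuffer
Seq(1L, 7L, 3L).foreach(v => partial1.iterate(mapBuffer, Array[AnyRef](java.lang.Long.valueOf(v))))
val partialResult = partial1.terminatePartial(mapBuffer)

// Reduce side (FINAL): merge shuffled partial results and emit the final value.
val finalEval = newEvaluator()
finalEval.init(Mode.FINAL, Array[ObjectInspector](partialOI))
val reduceBuffer = finalEval.getNewAggregationBuffer
finalEval.merge(reduceBuffer, partialResult)
println(finalEval.terminate(reduceBuffer))  // 7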
override lazy val dataType: DataType = inspectorToDataType(returnInspector)

@@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction(
val distinct = if (isDistinct) "DISTINCT " else " "
s"$name($distinct${children.map(_.sql).mkString(", ")})"
}

override def createAggregationBuffer(): AggregationBuffer =
partial1ModeEvaluator.getNewAggregationBuffer

@transient
private lazy val inputProjection = UnsafeProjection.create(children)

override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
partial1ModeEvaluator.iterate(
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
}

override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
// The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
// buffer in the 2nd format mentioned in the ScalaDoc of this class. Originally, Hive converts
// these `AggregationBuffer`s into this format before shuffling partial aggregation results, and
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
Contributor:
Let's explain what we are trying to do using partial1ModeEvaluator.terminatePartial(input).

Contributor Author:
Comment added.

Commenter:
If we follow the code flow from interfaces.scala, we see that the PARTIAL2-mode aggregation buffer is merged with the PARTIAL1-mode aggregation buffer. I am new to Spark and Hive, so I just wanted to know the reason behind this behaviour. If there are any docs explaining this, do let me know. Thank you.

}

override def eval(buffer: AggregationBuffer): Any = {
resultUnwrapper(finalModeEvaluator.terminate(buffer))
}

override def serialize(buffer: AggregationBuffer): Array[Byte] = {
// Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
// shuffle it for global aggregation later.
aggBufferSerDe.serialize(buffer)
}

override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
// Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
// for global aggregation by merging multiple partial aggregation results within a single group.
aggBufferSerDe.deserialize(bytes)
}
Contributor:
Let's add docs to explain what these functions are doing.

Contributor Author:
Done.

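A self-contained sketch of the UnsafeRow round trip the helper class below relies on; the `LongType` column here merely stands in for the unwrapped partial result type.

import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, LongType}

// Serialize: put the (already unwrapped) partial result into a single-column UnsafeRow
// and keep the bytes behind it.
val projection = UnsafeProjection.create(Array[DataType](LongType))
val mutableRow = new GenericInternalRow(1)
mutableRow.update(0, 42L)
val unsafeRow = projection(mutableRow)
val byteBuffer = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
unsafeRow.writeTo(byteBuffer)
val bytes = byteBuffer.array()  // unsafeRow.getBytes is a shorter equivalent, see the discussion below

// Deserialize: point a fresh UnsafeRow at those bytes and read the value back out.
val restored = new UnsafeRow(1)
restored.pointTo(bytes, bytes.length)
println(restored.get(0, LongType))  // 42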
// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
private class AggregationBufferSerDe {
Contributor:
Can we take this class out of HiveUDAFFunction?

Contributor Author:
We can, but it doesn't seem necessary. Making it a nested class also simplifies the implementation, since it has access to the fields of the outer class.

private val partialResultUnwrapper = unwrapperFor(partialResultInspector)

private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)

private val projection = UnsafeProjection.create(Array(partialResultDataType))
Contributor:
Can you try to run a Hive UDAF in the spark shell? IIUC, we can't create an unsafe projection inside a UDAF.

Contributor Author:
It does work as expected:

scala> sql("CREATE TEMPORARY FUNCTION hive_max AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax'")
res0: org.apache.spark.sql.DataFrame = []

scala> spark.range(100).createOrReplaceTempView("t")

scala> sql("SELECT hive_max(id) FROM t").explain()
== Physical Plan ==
SortAggregate(key=[], functions=[hive_max(hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax@144792d5), id#1L, false, 0, 0)])
+- Exchange SinglePartition
   +- SortAggregate(key=[], functions=[partial_hive_max(hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax@144792d5), id#1L, false, 0, 0)])
      +- *Range (0, 100, step=1, splits=Some(8))

scala> sql("SELECT hive_max(id) FROM t").show()
+-------------+
|hive_max( id)|
+-------------+
|           99|
+-------------+

Why do you think UnsafeProjection can't be used here?

Contributor:
I tried to use an unsafe projection in percentile_approx before, but it failed in the spark shell. Maybe it's a different problem, never mind.

private val mutableRow = new GenericInternalRow(1)

def serialize(buffer: AggregationBuffer): Array[Byte] = {
Contributor:
doc?

Contributor Author:
Done.

// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
// Then we can unwrap it to a Spark SQL value.
mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
val unsafeRow = projection(mutableRow)
val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
unsafeRow.writeTo(bytes)
bytes.array()
Contributor:
Should we just use unsafeRow.getBytes?

Contributor Author:
Aren't they equivalent in this case? UnsafeRow.getBytes also performs some extra checks that are not necessary here.

Contributor:
But you also create an unnecessary ByteBuffer... If they are equivalent, isn't unsafeRow.getBytes simpler?

Contributor:
Actually, they are different. If the buffer type is fixed-length, then the UnsafeRow is just a fixed-length byte array, and UnsafeRow.getBytes will just return that array instead of copying the memory.

}

def deserialize(bytes: Array[Byte]): AggregationBuffer = {
// `GenericUDAFEvaluator` doesn't provide any method capable of converting an object returned
// by `GenericUDAFEvaluator.terminatePartial()` back into an `AggregationBuffer`. The
// workaround here is to create an initial `AggregationBuffer` first and then merge the
// deserialized object into that buffer.
val buffer = partial2ModeEvaluator.getNewAggregationBuffer
val unsafeRow = new UnsafeRow(1)
unsafeRow.pointTo(bytes, bytes.length)
val partialResult = unsafeRow.get(0, partialResultDataType)
partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
buffer
}
}
}