Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput

def arguments: Seq[Expression]

def isVarargs: Boolean

def propagateNull: Boolean

// InvokeLike is stateful because of the evaluatedArgs Array
Expand Down Expand Up @@ -130,7 +132,13 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
}
val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)

(argCode, argValues.mkString(", "), resultIsNull)
val argString = if (isVarargs) {
"new Object[] {" + argValues.mkString(", ") + "}"
} else {
argValues.mkString(", ")
}

(argCode, argString, resultIsNull)
}

/**
Expand All @@ -157,7 +165,12 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
null
} else {
val ret = try {
method.invoke(obj, evaluatedArgs: _*)
if (isVarargs) {
method.invoke(obj, evaluatedArgs)
} else {
method.invoke(obj, evaluatedArgs: _*)
}

} catch {
// Re-throw the original exception.
case e: java.lang.reflect.InvocationTargetException if e.getCause != null =>
Expand Down Expand Up @@ -264,7 +277,8 @@ case class StaticInvoke(
inputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable: Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {
isDeterministic: Boolean = true,
isVarargs: Boolean = false) extends InvokeLike {

val objectName = staticObject.getName.stripSuffix("$")
val cls = if (staticObject.getName == objectName) {
Expand Down Expand Up @@ -359,6 +373,7 @@ case class StaticInvoke(
* non-null value.
* @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
* will not apply certain optimizations such as constant folding.
* @param isVarargs When true, the search method parameter type is Array [Object]
*/
case class Invoke(
targetObject: Expression,
Expand All @@ -368,7 +383,8 @@ case class Invoke(
methodInputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable : Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {
isDeterministic: Boolean = true,
isVarargs: Boolean = false) extends InvokeLike {

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

Expand All @@ -387,6 +403,8 @@ case class Invoke(
private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName)

@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) if isVarargs =>
Some(findMethod(cls, encodedFunctionName, Seq(classOf[Array[Any]])))
case ObjectType(cls) =>
Some(findMethod(cls, encodedFunctionName, argClasses))
case _ => None
Expand Down Expand Up @@ -491,7 +509,17 @@ object NewInstance {
arguments: Seq[Expression],
dataType: DataType,
propagateNull: Boolean = true): NewInstance =
new NewInstance(cls, arguments, inputTypes = Nil, propagateNull, dataType, None)
new NewInstance(cls, arguments, inputTypes = Nil, propagateNull, dataType, None, false)

def apply(
cls: Class[_],
arguments: Seq[Expression],
inputTypes: Seq[AbstractDataType],
propagateNull: Boolean,
dataType: DataType,
outerPointer: Option[() => AnyRef]): NewInstance =
new NewInstance(cls, arguments, inputTypes = inputTypes,
propagateNull, dataType, outerPointer, false)
}

/**
Expand Down Expand Up @@ -523,7 +551,8 @@ case class NewInstance(
inputTypes: Seq[AbstractDataType],
propagateNull: Boolean,
dataType: DataType,
outerPointer: Option[() => AnyRef]) extends InvokeLike {
outerPointer: Option[() => AnyRef],
isVarargs: Boolean) extends InvokeLike {
private val className = cls.getName

override def nullable: Boolean = needNullCheck
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ trait StringBinaryPredicateExpressionBuilderBase extends ExpressionBuilder {

object BinaryPredicate {
def unapply(expr: Expression): Option[StaticInvoke] = expr match {
case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _)
case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _, _)
if clz == classOf[ByteArrayMethods] => Some(s)
case _ => None
}
Expand Down
132 changes: 63 additions & 69 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.Obje

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.{RuntimeReplaceable, _}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -131,11 +132,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp

private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression
extends RuntimeReplaceable
with HiveInspectors
with Logging
with UserDefinedExpression {

override lazy val replacement: Expression = {
Invoke(
targetObject = Literal.fromObject(new HiveGenericUDFInvokeAdapter(
funcWrapper, dataType, children)),
functionName = "evaluate",
dataType = dataType,
arguments = children,
isDeterministic = deterministic,
isVarargs = true
)
}

override def nullable: Boolean = true

override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
Expand All @@ -144,7 +157,7 @@ private[hive] case class HiveGenericUDF(
isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]

@transient
lazy val function = funcWrapper.createFunction[GenericUDF]()
private lazy val function = funcWrapper.createFunction[GenericUDF]()

@transient
private lazy val argumentInspectors = children.map(toInspector)
Expand All @@ -154,38 +167,14 @@ private[hive] case class HiveGenericUDF(
function.initializeAndFoldConstants(argumentInspectors.toArray)
}

// Visible for codegen
@transient
lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)

@transient
private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic() && !udfType.stateful()
}

// Visible for codegen
@transient
lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]

override lazy val dataType: DataType = inspectorToDataType(returnInspector)

override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.

var i = 0
val length = children.length
while (i < length) {
val idx = i
deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
.set(children(idx).eval(input))
i += 1
}
unwrapper(function.evaluate(deferredObjects))
}

override def prettyName: String = name

override def toString: String = {
Expand All @@ -194,47 +183,52 @@ private[hive] case class HiveGenericUDF(

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val refTerm = ctx.addReferenceObj("this", this)
val childrenEvals = children.map(_.genCode(ctx))

val setDeferredObjects = childrenEvals.zipWithIndex.map {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

codegen is performance critical. Previously, we generate code to directly set values for deferredObjects, but in this PR we always create a Object[] to wrap the arguments, which can be bad for performance.

This is a signal that Invoke doesn't work very well for Hive UDF. It's still valuable to have something like HiveGenericUDFInvokeAdapter to share code between interpreted code path and codegen code path.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan
Should I continue to rewrite HiveSimpleUDF with Invoke
Or the pr #39865 is ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should refactor HiveGenericUDF first, then follow it to implement codegen of HiveSimpleUDF.

What the refactor should do is to add something like HiveGenericUDFInvokeAdapter in this PR, to keep all the states. Then HiveGenericUDF just manipulates this stateful object in both interpreted code path and codegen.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, Let me try to do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have submitted a new pr: #40394 to refactor HiveGenericUDF.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow it to implement codegen of HiveSimpleUDF: #40397

case (eval, i) =>
val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName
s"""
|if (${eval.isNull}) {
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
|} else {
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
|}
|""".stripMargin
}

val resultType = CodeGenerator.boxedType(dataType)
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
|${childrenEvals.map(_.code).mkString("\n")}
|${setDeferredObjects.mkString("\n")}
|$resultType $resultTerm = null;
|boolean ${ev.isNull} = false;
|try {
| $resultTerm = ($resultType) $refTerm.unwrapper().apply(
| $refTerm.function().evaluate($refTerm.deferredObjects()));
| ${ev.isNull} = $resultTerm == null;
|} catch (Throwable e) {
| throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
| "${funcWrapper.functionClassName}",
| "${children.map(_.dataType.catalogString).mkString(", ")}",
| "${dataType.catalogString}",
| e);
|}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
|""".stripMargin
)
override lazy val canonicalized: Expression = {
withNewChildren(children.map(_.canonicalized)).asInstanceOf[HiveGenericUDF]
}
}

class HiveGenericUDFInvokeAdapter(
funcWrapper: HiveFunctionWrapper, dataType: DataType, children: Seq[Expression])
extends HiveInspectors with Serializable {

@transient
private lazy val function = funcWrapper.createFunction[GenericUDF]()

@transient
private lazy val argumentInspectors = children.map(toInspector)

@transient
private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]

@transient
private lazy val returnInspector = {
function.initializeAndFoldConstants(argumentInspectors.toArray)
}

@transient
private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)

def evaluate(params: Array[Object]): Any = {
returnInspector // Make sure initialized.
params.zipWithIndex.map {
case (param, i) =>
deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(param)
}
try {
val ret = function.evaluate(deferredObjects)
unwrapper(ret)
} catch {
case e: Throwable =>
throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
s"${funcWrapper.functionClassName}",
s"${children.map(_.dataType.catalogString).mkString(", ")}",
s"${dataType.catalogString}",
e)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,17 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
checkAnswer(df, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f")))
}
}
withUserDefinedFunction("MultiArgsGenericUDF" -> false) {
sql(s"CREATE FUNCTION MultiArgsGenericUDF AS '${classOf[GenericUDFConcat].getName}'")
withTable("MultiArgsGenericUDFTable") {
sql("create table MultiArgsGenericUDFTable as " +
"select 'Deep-going studying' as x, 'Spark SQL' as y")
val df = sql("SELECT MultiArgsGenericUDF(x, ' ', y) from MultiArgsGenericUDFTable")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[WholeStageCodegenExec])
checkAnswer(df, Seq(Row("Deep-going studying Spark SQL")))
}
}
}

test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") {
Expand Down