diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 929beb660ad6..7c14f6420f67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -48,6 +48,8 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
 
   def arguments: Seq[Expression]
 
+  def isVarargs: Boolean
+
   def propagateNull: Boolean
 
   // InvokeLike is stateful because of the evaluatedArgs Array
@@ -130,7 +132,13 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
     }
     val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)
 
-    (argCode, argValues.mkString(", "), resultIsNull)
+    val argString = if (isVarargs) {
+      "new Object[] {" + argValues.mkString(", ") + "}"
+    } else {
+      argValues.mkString(", ")
+    }
+
+    (argCode, argString, resultIsNull)
   }
 
   /**
@@ -157,7 +165,12 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
       null
     } else {
      val ret = try {
-        method.invoke(obj, evaluatedArgs: _*)
+        if (isVarargs) {
+          method.invoke(obj, evaluatedArgs)
+        } else {
+          method.invoke(obj, evaluatedArgs: _*)
+        }
+
       } catch {
         // Re-throw the original exception.
         case e: java.lang.reflect.InvocationTargetException if e.getCause != null =>
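Note on the interpreted path above: `java.lang.reflect.Method.invoke(Object, Object...)` is itself a varargs method, so the two branches differ only in how Scala packs the argument array. A standalone sketch of the distinction (the `Adapter` class here is illustrative, not part of the patch):

```scala
class Adapter {
  // Single parameter of type Array[Object], the shape the new isVarargs flag targets.
  def evaluate(params: Array[Object]): Object = params.mkString(" ")
}

val adapter = new Adapter
val method = classOf[Adapter].getMethod("evaluate", classOf[Array[Object]])
val args: Array[Object] = Array[Object]("a", "b")

// `args: _*` spreads the array into two reflective arguments, so reflection
// looks for evaluate(Object, Object) and fails with IllegalArgumentException:
// method.invoke(adapter, args: _*)

// Passing the array unexpanded makes it the single reflective argument,
// which matches evaluate(Array[Object]):
method.invoke(adapter, args) // "a b"
```

The codegen branch encodes the same choice textually: wrapping the argument expressions in `new Object[] { ... }` produces one array-typed argument instead of N scalar ones.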
@@ -264,7 +277,8 @@ case class StaticInvoke(
     inputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
     returnNullable: Boolean = true,
-    isDeterministic: Boolean = true) extends InvokeLike {
+    isDeterministic: Boolean = true,
+    isVarargs: Boolean = false) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
   val cls = if (staticObject.getName == objectName) {
@@ -359,6 +373,7 @@ case class StaticInvoke(
  *                       non-null value.
  * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
  *                        will not apply certain optimizations such as constant folding.
+ * @param isVarargs When true, the method is resolved with a single Array[Object] parameter.
  */
 case class Invoke(
     targetObject: Expression,
@@ -368,7 +383,8 @@ case class Invoke(
     methodInputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
     returnNullable : Boolean = true,
-    isDeterministic: Boolean = true) extends InvokeLike {
+    isDeterministic: Boolean = true,
+    isVarargs: Boolean = false) extends InvokeLike {
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
 
@@ -387,6 +403,8 @@ case class Invoke(
   private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName)
 
   @transient lazy val method = targetObject.dataType match {
+    case ObjectType(cls) if isVarargs =>
+      Some(findMethod(cls, encodedFunctionName, Seq(classOf[Array[Any]])))
     case ObjectType(cls) => Some(findMethod(cls, encodedFunctionName, argClasses))
     case _ => None
   }
@@ -491,7 +509,17 @@ object NewInstance {
       arguments: Seq[Expression],
       dataType: DataType,
       propagateNull: Boolean = true): NewInstance =
-    new NewInstance(cls, arguments, inputTypes = Nil, propagateNull, dataType, None)
+    new NewInstance(cls, arguments, inputTypes = Nil, propagateNull, dataType, None, false)
+
+  def apply(
+      cls: Class[_],
+      arguments: Seq[Expression],
+      inputTypes: Seq[AbstractDataType],
+      propagateNull: Boolean,
+      dataType: DataType,
+      outerPointer: Option[() => AnyRef]): NewInstance =
+    new NewInstance(cls, arguments, inputTypes = inputTypes,
+      propagateNull, dataType, outerPointer, false)
 }
 
 /**
@@ -523,7 +551,8 @@ case class NewInstance(
     inputTypes: Seq[AbstractDataType],
     propagateNull: Boolean,
     dataType: DataType,
-    outerPointer: Option[() => AnyRef]) extends InvokeLike {
+    outerPointer: Option[() => AnyRef],
+    isVarargs: Boolean) extends InvokeLike {
   private val className = cls.getName
 
   override def nullable: Boolean = needNullCheck
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index f664244107b5..b5ae45e08f4e 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -519,7 +519,7 @@ trait StringBinaryPredicateExpressionBuilderBase extends ExpressionBuilder {
 
 object BinaryPredicate {
   def unapply(expr: Expression): Option[StaticInvoke] = expr match {
-    case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _)
+    case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _, _)
       if clz == classOf[ByteArrayMethods] => Some(s)
     case _ => None
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index d5cff31ed642..3fa13c71ec2d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -33,10 +33,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.Obje
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{RuntimeReplaceable, _}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator,
-  CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._
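The pattern arity bump in `BinaryPredicate` above is purely mechanical: `StaticInvoke` gained a ninth constructor parameter, so the extractor needs one more `_`. The more interesting piece is the new `method` lookup in `Invoke`: when `isVarargs` is set, the per-argument classes are ignored and the method is resolved against a single array parameter. This works because `Array[Any]` and `Array[Object]` erase to the same JVM type, `[Ljava.lang.Object;`. A hedged sketch of how a caller builds such an `Invoke` (`adapter` and `childExprs` are placeholders for any object exposing `def evaluate(params: Array[Object]): Any` and its child expressions):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.StringType

val expr = Invoke(
  targetObject = Literal.fromObject(adapter), // boxed instance carried into the plan
  functionName = "evaluate",
  dataType = StringType,
  arguments = childExprs, // N child expressions, evaluated into one Object[]
  isVarargs = true)
```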
@@ -131,11 +132,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
 
 private[hive] case class HiveGenericUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends Expression
+  extends RuntimeReplaceable
   with HiveInspectors
   with Logging
   with UserDefinedExpression {
 
+  override lazy val replacement: Expression = {
+    Invoke(
+      targetObject = Literal.fromObject(new HiveGenericUDFInvokeAdapter(
+        funcWrapper, dataType, children)),
+      functionName = "evaluate",
+      dataType = dataType,
+      arguments = children,
+      isDeterministic = deterministic,
+      isVarargs = true
+    )
+  }
+
   override def nullable: Boolean = true
 
   override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
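For readers new to `RuntimeReplaceable`: the optimizer's `ReplaceExpressions` rule swaps such an expression for its `replacement` before execution, so `HiveGenericUDF` no longer needs its own `eval` or `doGenCode` (both are deleted below) and inherits whole-stage codegen from `Invoke` for free. A toy illustration of the same contract, sketched against the Spark 3.4-era API (the expression itself is hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.{Coalesce, Expression, Length, Literal, RuntimeReplaceable}
import org.apache.spark.sql.catalyst.trees.UnaryLike

// Behaves like LENGTH(child) but returns 0 instead of null, delegating all
// evaluation and codegen to existing expressions via `replacement`.
case class NullSafeLength(child: Expression) extends RuntimeReplaceable
    with UnaryLike[Expression] {
  override lazy val replacement: Expression = Coalesce(Seq(Length(child), Literal(0)))
  override protected def withNewChildInternal(newChild: Expression): NullSafeLength =
    copy(child = newChild)
}
```

By default `dataType` and `nullable` are derived from `replacement`; `HiveGenericUDF` keeps its own overrides because the Hive `ObjectInspector`, not the replacement, is the source of truth for them.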
@@ -144,7 +157,7 @@ private[hive] case class HiveGenericUDF(
     isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
 
   @transient
-  lazy val function = funcWrapper.createFunction[GenericUDF]()
+  private lazy val function = funcWrapper.createFunction[GenericUDF]()
 
   @transient
   private lazy val argumentInspectors = children.map(toInspector)
@@ -154,38 +167,14 @@ private[hive] case class HiveGenericUDF(
     function.initializeAndFoldConstants(argumentInspectors.toArray)
   }
 
-  // Visible for codegen
-  @transient
-  lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
-
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
-  // Visible for codegen
-  @transient
-  lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
-    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
-  }.toArray[DeferredObject]
-
   override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 
-  override def eval(input: InternalRow): Any = {
-    returnInspector // Make sure initialized.
-
-    var i = 0
-    val length = children.length
-    while (i < length) {
-      val idx = i
-      deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
-        .set(children(idx).eval(input))
-      i += 1
-    }
-    unwrapper(function.evaluate(deferredObjects))
-  }
-
   override def prettyName: String = name
 
   override def toString: String = {
@@ -194,47 +183,52 @@ private[hive] case class HiveGenericUDF(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val refTerm = ctx.addReferenceObj("this", this)
-    val childrenEvals = children.map(_.genCode(ctx))
-
-    val setDeferredObjects = childrenEvals.zipWithIndex.map {
-      case (eval, i) =>
-        val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName
-        s"""
-           |if (${eval.isNull}) {
-           |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
-           |} else {
-           |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
-           |}
-           |""".stripMargin
-    }
-    val resultType = CodeGenerator.boxedType(dataType)
-    val resultTerm = ctx.freshName("result")
-    ev.copy(code =
-      code"""
-         |${childrenEvals.map(_.code).mkString("\n")}
-         |${setDeferredObjects.mkString("\n")}
-         |$resultType $resultTerm = null;
-         |boolean ${ev.isNull} = false;
-         |try {
-         |  $resultTerm = ($resultType) $refTerm.unwrapper().apply(
-         |    $refTerm.function().evaluate($refTerm.deferredObjects()));
-         |  ${ev.isNull} = $resultTerm == null;
-         |} catch (Throwable e) {
-         |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
-         |    "${funcWrapper.functionClassName}",
-         |    "${children.map(_.dataType.catalogString).mkString(", ")}",
-         |    "${dataType.catalogString}",
-         |    e);
-         |}
-         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-         |if (!${ev.isNull}) {
-         |  ${ev.value} = $resultTerm;
-         |}
-         |""".stripMargin
-    )
+  override lazy val canonicalized: Expression = {
+    withNewChildren(children.map(_.canonicalized)).asInstanceOf[HiveGenericUDF]
+  }
+}
+
+class HiveGenericUDFInvokeAdapter(
+    funcWrapper: HiveFunctionWrapper, dataType: DataType, children: Seq[Expression])
+  extends HiveInspectors with Serializable {
+
+  @transient
+  private lazy val function = funcWrapper.createFunction[GenericUDF]()
+
+  @transient
+  private lazy val argumentInspectors = children.map(toInspector)
+
+  @transient
+  private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
+    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
+  }.toArray[DeferredObject]
+
+  @transient
+  private lazy val returnInspector = {
+    function.initializeAndFoldConstants(argumentInspectors.toArray)
+  }
+
+  @transient
+  private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
+
+  def evaluate(params: Array[Object]): Any = {
+    returnInspector // Make sure initialized.
+    params.zipWithIndex.foreach {
+      case (param, i) =>
+        deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(param)
+    }
+    try {
+      val ret = function.evaluate(deferredObjects)
+      unwrapper(ret)
+    } catch {
+      case e: Throwable =>
+        throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+          funcWrapper.functionClassName,
+          children.map(_.dataType.catalogString).mkString(", "),
+          dataType.catalogString,
+          e)
+    }
   }
 }
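With the adapter in place, both execution paths funnel through a single boxed call: codegen emits the equivalent of `adapterRef.evaluate(new Object[] { arg0, ..., argN })`, and the interpreted path makes the same call reflectively. A hedged sketch of driving the adapter directly (only meaningful from within `org.apache.spark.sql.hive`, since `HiveFunctionWrapper` is `private[hive]`; inspector behavior is simplified):

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// One string-typed child fixes the argument count and inspector types.
val adapter = new HiveGenericUDFInvokeAdapter(
  HiveFunctionWrapper(classOf[GenericUDFUpper].getName),
  StringType,
  Seq(AttributeReference("c", StringType)()))

// Arguments arrive in Catalyst's internal representation (UTF8String for
// strings); the adapter wraps them for Hive and unwraps the result.
val out = adapter.evaluate(Array[Object](UTF8String.fromString("spark")))
// Expected: UTF8String "SPARK"
```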
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index baa25843d48b..5dcd1b54e8e4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -724,6 +724,17 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
         checkAnswer(df, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f")))
       }
     }
+    withUserDefinedFunction("MultiArgsGenericUDF" -> false) {
+      sql(s"CREATE FUNCTION MultiArgsGenericUDF AS '${classOf[GenericUDFConcat].getName}'")
+      withTable("MultiArgsGenericUDFTable") {
+        sql("CREATE TABLE MultiArgsGenericUDFTable AS " +
+          "SELECT 'Deep diving into' AS x, 'Spark SQL' AS y")
+        val df = sql("SELECT MultiArgsGenericUDF(x, ' ', y) FROM MultiArgsGenericUDFTable")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[WholeStageCodegenExec])
+        checkAnswer(df, Seq(Row("Deep diving into Spark SQL")))
+      }
+    }
   }
 
   test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") {
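A quick way to see the effect of this change interactively, once the function above is registered in a Hive-enabled session, is Spark's codegen debugger (a real helper in `org.apache.spark.sql.execution.debug`):

```scala
import org.apache.spark.sql.execution.debug._

sql("SELECT MultiArgsGenericUDF(x, ' ', y) FROM MultiArgsGenericUDFTable").debugCodegen()
// Prints each WholeStageCodegen subtree followed by its generated Java source;
// the UDF should appear as one evaluate(new Object[] { ... }) call on the
// adapter reference rather than a CodegenFallback boundary.
```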