diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 9d8437b068d5..dffd38217af5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -817,6 +817,14 @@ private[hive] trait HiveInspectors {
     cache
   }
 
+  def wrap(
+      row: Array[Any],
+      wrappers: Array[(Any) => Any],
+      cache: Array[AnyRef],
+      dataTypes: Array[DataType]): Array[AnyRef] = {
+    wrap(row.toSeq, wrappers, cache, dataTypes)
+  }
+
   /**
    * @param dataType Catalyst data type
    * @return Hive java object inspector (recursively), not the Writable ObjectInspector
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 32ade60e20d0..b08d89bb6099 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -49,7 +49,6 @@ private[hive] case class HiveSimpleUDF(
     name: String,
     funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression]) extends Expression
   with HiveInspectors
-  with CodegenFallback
   with Logging
   with UserDefinedExpression {
@@ -61,7 +60,7 @@ private[hive] case class HiveSimpleUDF(
   lazy val function = funcWrapper.createFunction[UDF]()
 
   @transient
-  private lazy val method =
+  protected lazy val method =
     function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
 
   @transient
@@ -77,22 +76,22 @@ private[hive] case class HiveSimpleUDF(
 
   // Create parameter converters
   @transient
-  private lazy val conversionHelper = new ConversionHelper(method, arguments)
+  protected lazy val conversionHelper = new ConversionHelper(method, arguments)
 
   override lazy val dataType = javaTypeToDataType(method.getGenericReturnType)
 
   @transient
-  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+  protected lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
 
   @transient
   lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
     method.getGenericReturnType, ObjectInspectorOptions.JAVA))
 
   @transient
-  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+  protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
 
   @transient
-  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+  protected lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
 
   // TODO: Finish input output types.
   override def eval(input: InternalRow): Any = {
@@ -104,6 +103,67 @@ private[hive] case class HiveSimpleUDF(
     unwrapper(ret)
   }
 
+  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refTerm = ctx.addReferenceObj("this", this)
+    val evals = children.map(_.genCode(ctx))
+    val resultType = CodeGenerator.boxedType(dataType)
+    val resultTerm = ctx.freshName("result")
+    val inputsTerm = ctx.freshName("inputs")
+    val inputsWrapTerm = ctx.freshName("inputsWrap")
+
+    val initInputs =
+      s"""
+         |Object[] $inputsTerm = new Object[${evals.size}];
+         |""".stripMargin
+
+    val setInputs = evals.zipWithIndex.map {
+      case (eval, i) =>
+        s"""
+           |if (${eval.isNull}) {
+           |  $inputsTerm[$i] = null;
+           |} else {
+           |  $inputsTerm[$i] = ${eval.value};
+           |}
+           |""".stripMargin
+    }
+
+    val inputsWrap = {
+      s"""
+         |Object[] $inputsWrapTerm = $refTerm.wrap($inputsTerm, $refTerm.wrappers(),
+         |  $refTerm.cached(), $refTerm.inputDataTypes());
+         |""".stripMargin
+    }
+
+    ev.copy(code =
+      code"""
+         |${evals.map(_.code).mkString("\n")}
+         |$initInputs
+         |${setInputs.mkString("\n")}
+         |$inputsWrap
+         |$resultType $resultTerm = null;
+         |boolean ${ev.isNull} = false;
+         |try {
+         |  $resultTerm = ($resultType) $refTerm.unwrapper().apply(
+         |    org.apache.hadoop.hive.ql.exec.FunctionRegistry.invoke(
+         |      $refTerm.method(),
+         |      $refTerm.function(),
+         |      $refTerm.conversionHelper().convertIfNecessary($inputsWrapTerm)));
+         |  ${ev.isNull} = $resultTerm == null;
+         |} catch (Throwable e) {
+         |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+         |    "${funcWrapper.functionClassName}",
+         |    "${children.map(_.dataType.catalogString).mkString(", ")}",
+         |    "${dataType.catalogString}",
+         |    e);
+         |}
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+         |""".stripMargin
+    )
+  }
+
   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
@@ -194,6 +254,7 @@ private[hive] case class HiveGenericUDF(
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
 
+
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val refTerm = ctx.addReferenceObj("this", this)
     val childrenEvals = children.map(_.genCode(ctx))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index baa25843d48b..8fb9209f9cb4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.metadata.HiveException
 import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
@@ -743,6 +744,38 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[UDFStringString].getName}'")
+      withTable("HiveSimpleUDFTable") {
+        sql(s"create table HiveSimpleUDFTable as select 'Spark SQL' as v")
+        val df =
+          sql("SELECT CodeGenHiveSimpleUDF('Hello', v) from HiveSimpleUDFTable")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[WholeStageCodegenExec])
+        checkAnswer(df, Seq(Row("Hello Spark SQL")))
+      }
+    }
+  }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support w/ execution failure") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[SimpleUDFAssertTrue].getName}'")
+      withTable("HiveSimpleUDFTable") {
+        sql(s"create table HiveSimpleUDFTable as select false as v")
+        val df = sql("SELECT CodeGenHiveSimpleUDF(v) from HiveSimpleUDFTable")
+        checkError(
+          exception = intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException],
+          errorClass = "FAILED_EXECUTE_UDF",
+          parameters = Map(
+            "functionName" -> s"${classOf[SimpleUDFAssertTrue].getName}",
+            "signature" -> "boolean",
+            "result" -> "boolean"
+          )
+        )
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -844,3 +877,12 @@ class ListFiles extends UDF {
     if (fileArray != null) Arrays.asList(fileArray: _*) else new ArrayList[String]()
   }
 }
+
+class SimpleUDFAssertTrue extends UDF {
+  def evaluate(condition: Boolean): Boolean = {
+    if (!condition) {
+      throw new HiveException("ASSERT_TRUE(): assertion failed.")
+    }
+    condition
+  }
+}