Skip to content

Commit 43c2d3d

Browse files
panbingkun authored and cloud-fan committed
[SPARK-42771][SQL] Refactor HiveGenericUDF
### What changes were proposed in this pull request?
This PR refactors `HiveGenericUDF`.

### Why are the changes needed?
It follows #39949 and makes the code more concise.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA.

Closes #40394 from panbingkun/refactor_HiveGenericUDF.

Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 049aa38 commit 43c2d3d

File tree

1 file changed

+51
-47
lines changed

1 file changed

+51
-47
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 51 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -130,57 +130,32 @@ private[hive] case class HiveGenericUDF(
130130
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
131131
extends Expression
132132
with HiveInspectors
133-
with Logging
134133
with UserDefinedExpression {
135134

136135
override def nullable: Boolean = true
137136

138-
override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
139-
140-
override def foldable: Boolean =
141-
isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
142-
143-
@transient
144-
lazy val function = funcWrapper.createFunction[GenericUDF]()
137+
override lazy val deterministic: Boolean =
138+
isUDFDeterministic && children.forall(_.deterministic)
145139

146-
@transient
147-
private lazy val argumentInspectors = children.map(toInspector)
140+
override def foldable: Boolean = isUDFDeterministic &&
141+
evaluator.returnInspector.isInstanceOf[ConstantObjectInspector]
148142

149-
@transient
150-
private lazy val returnInspector = {
151-
function.initializeAndFoldConstants(argumentInspectors.toArray)
152-
}
143+
override lazy val dataType: DataType = inspectorToDataType(evaluator.returnInspector)
153144

154-
// Visible for codegen
155145
@transient
156-
lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
146+
private lazy val evaluator = new HiveGenericUDFEvaluator(funcWrapper, children)
157147

158148
@transient
159-
private lazy val isUDFDeterministic = {
160-
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
149+
private val isUDFDeterministic = {
150+
val udfType = evaluator.function.getClass.getAnnotation(classOf[HiveUDFType])
161151
udfType != null && udfType.deterministic() && !udfType.stateful()
162152
}
163153

164-
// Visible for codegen
165-
@transient
166-
lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
167-
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
168-
}.toArray[DeferredObject]
169-
170-
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
171-
172154
override def eval(input: InternalRow): Any = {
173-
returnInspector // Make sure initialized.
174-
175-
var i = 0
176-
val length = children.length
177-
while (i < length) {
178-
val idx = i
179-
deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
180-
.set(children(idx).eval(input))
181-
i += 1
155+
children.zipWithIndex.map {
156+
case (child, idx) => evaluator.setArg(idx, child.eval(input))
182157
}
183-
unwrapper(function.evaluate(deferredObjects))
158+
evaluator.evaluate()
184159
}
185160

186161
override def prettyName: String = name
@@ -191,18 +166,18 @@ private[hive] case class HiveGenericUDF(
191166

192167
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
193168
copy(children = newChildren)
194-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
195-
val refTerm = ctx.addReferenceObj("this", this)
196-
val childrenEvals = children.map(_.genCode(ctx))
197169

198-
val setDeferredObjects = childrenEvals.zipWithIndex.map {
170+
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
171+
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
172+
val evals = children.map(_.genCode(ctx))
173+
174+
val setValues = evals.zipWithIndex.map {
199175
case (eval, i) =>
200-
val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName
201176
s"""
202177
|if (${eval.isNull}) {
203-
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
178+
| $refEvaluator.setArg($i, null);
204179
|} else {
205-
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
180+
| $refEvaluator.setArg($i, ${eval.value});
206181
|}
207182
|""".stripMargin
208183
}
@@ -211,13 +186,12 @@ private[hive] case class HiveGenericUDF(
211186
val resultTerm = ctx.freshName("result")
212187
ev.copy(code =
213188
code"""
214-
|${childrenEvals.map(_.code).mkString("\n")}
215-
|${setDeferredObjects.mkString("\n")}
189+
|${evals.map(_.code).mkString("\n")}
190+
|${setValues.mkString("\n")}
216191
|$resultType $resultTerm = null;
217192
|boolean ${ev.isNull} = false;
218193
|try {
219-
| $resultTerm = ($resultType) $refTerm.unwrapper().apply(
220-
| $refTerm.function().evaluate($refTerm.deferredObjects()));
194+
| $resultTerm = ($resultType) $refEvaluator.evaluate();
221195
| ${ev.isNull} = $resultTerm == null;
222196
|} catch (Throwable e) {
223197
| throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
@@ -235,6 +209,36 @@ private[hive] case class HiveGenericUDF(
235209
}
236210
}
237211

212+
class HiveGenericUDFEvaluator(
213+
funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
214+
extends HiveInspectors
215+
with Serializable {
216+
217+
@transient
218+
lazy val function = funcWrapper.createFunction[GenericUDF]()
219+
220+
@transient
221+
private lazy val argumentInspectors = children.map(toInspector)
222+
223+
@transient
224+
lazy val returnInspector = {
225+
function.initializeAndFoldConstants(argumentInspectors.toArray)
226+
}
227+
228+
@transient
229+
private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
230+
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
231+
}.toArray[DeferredObject]
232+
233+
@transient
234+
private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
235+
236+
def setArg(index: Int, arg: Any): Unit =
237+
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
238+
239+
def evaluate(): Any = unwrapper(function.evaluate(deferredObjects))
240+
}
241+
238242
/**
239243
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
240244
* `Generator`. Note that the semantics of Generators do not allow

0 commit comments

Comments
 (0)