Skip to content

Commit 73aa23b

Browse files
rxingatorsmile
authored andcommitted
[SPARK-20674][SQL] Support registering UserDefinedFunction as named UDF
## What changes were proposed in this pull request? For some reason we don't have an API to register UserDefinedFunction as named UDF. It is a no brainer to add one, in addition to the existing register functions we have. ## How was this patch tested? Added a test case in UDFSuite for the new API. Author: Reynold Xin <[email protected]> Closes #17915 from rxin/SPARK-20674. (cherry picked from commit d099f41) Signed-off-by: Xiao Li <[email protected]>
1 parent 08e1b78 commit 73aa23b

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
7070
* @param name the name of the UDAF.
7171
* @param udaf the UDAF needs to be registered.
7272
* @return the registered UDAF.
73+
*
74+
* @since 1.5.0
7375
*/
74-
def register(
75-
name: String,
76-
udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
76+
def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
7777
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
7878
functionRegistry.registerFunction(name, builder)
7979
udaf
8080
}
8181

82+
/**
83+
* Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame
84+
* API (i.e. of type UserDefinedFunction).
85+
*
86+
* @param name the name of the UDF.
87+
* @param udf the UDF needs to be registered.
88+
* @return the registered UDF.
89+
*
90+
* @since 2.2.0
91+
*/
92+
def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
93+
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
94+
functionRegistry.registerFunction(name, builder)
95+
udf
96+
}
97+
8298
// scalastyle:off line.size.limit
8399

84100
/* register 0-22 were generated by this script

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
9393
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
9494
}
9595

96+
test("UDF defined using UserDefinedFunction") {
97+
import functions.udf
98+
val foo = udf((x: Int) => x + 1)
99+
spark.udf.register("foo", foo)
100+
assert(sql("select foo(5)").head().getInt(0) == 6)
101+
}
102+
96103
test("ZeroArgument UDF") {
97104
spark.udf.register("random0", () => { Math.random()})
98105
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)

0 commit comments

Comments
 (0)