Skip to content

Commit ebc24a9

Browse files
committed
[SPARK-20586][SQL] Add deterministic to ScalaUDF
### What changes were proposed in this pull request? Like [Hive UDFType](https://hive.apache.org/javadocs/r2.0.1/api/org/apache/hadoop/hive/ql/udf/UDFType.html), we should allow users to add the extra flags for ScalaUDF and JavaUDF too. _stateful_/_impliesOrder_ are not applicable to our Scala UDF. Thus, we only add the following flag. - deterministic: Certain optimizations should not be applied if UDF is not deterministic. A deterministic UDF returns the same result each time it is invoked with a particular input. This determinism just needs to hold within the context of a query. When the deterministic flag is not correctly set, the results could be wrong. For ScalaUDF in Dataset APIs, users can call the following extra API on `UserDefinedFunction` to make the corresponding change. - `asNondeterministic()`: Updates the UserDefinedFunction to non-deterministic. Also fixed the Java UDF name loss issue. Will submit a separate PR for `distinctLike` for UDAF ### How was this patch tested? Added test cases for ScalaUDF. Author: gatorsmile <[email protected]> Author: Wenchen Fan <[email protected]> Closes #17848 from gatorsmile/udfRegister.
1 parent 9b4da7b commit ebc24a9

File tree

7 files changed

+278
-164
lines changed

7 files changed

+278
-164
lines changed

python/pyspark/sql/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,11 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
220220
>>> sqlContext.registerJavaFunction("javaStringLength",
221221
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
222222
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
223-
[Row(UDF(test)=4)]
223+
[Row(UDF:javaStringLength(test)=4)]
224224
>>> sqlContext.registerJavaFunction("javaStringLength2",
225225
... "test.org.apache.spark.sql.JavaStringLength")
226226
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
227-
[Row(UDF(test)=4)]
227+
[Row(UDF:javaStringLength2(test)=4)]
228228
229229
"""
230230
jdt = None

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1950,7 +1950,7 @@ class Analyzer(
19501950

19511951
case p => p transformExpressionsUp {
19521952

1953-
case udf @ ScalaUDF(func, _, inputs, _, _, _) =>
1953+
case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
19541954
val parameterTypes = ScalaReflection.getParameterTypes(func)
19551955
assert(parameterTypes.length == inputs.length)
19561956

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import org.apache.spark.sql.types.DataType
2424

2525
/**
2626
* User-defined function.
27-
* Note that the user-defined functions must be deterministic.
2827
* @param function The user defined scala function to run.
2928
* Note that if you use primitive parameters, you are not able to check if it is
3029
* null or not, and the UDF will return null for you if the primitive input is
@@ -35,18 +34,23 @@ import org.apache.spark.sql.types.DataType
3534
* not want to perform coercion, simply use "Nil". Note that it would've been
3635
* better to use Option of Seq[DataType] so we can use "None" as the case for no
3736
* type coercion. However, that would require more refactoring of the codebase.
38-
* @param udfName The user-specified name of this UDF.
37+
* @param udfName The user-specified name of this UDF.
3938
* @param nullable True if the UDF can return null value.
39+
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
40+
* each time it is invoked with a particular input.
4041
*/
4142
case class ScalaUDF(
4243
function: AnyRef,
4344
dataType: DataType,
4445
children: Seq[Expression],
4546
inputTypes: Seq[DataType] = Nil,
4647
udfName: Option[String] = None,
47-
nullable: Boolean = true)
48+
nullable: Boolean = true,
49+
udfDeterministic: Boolean = true)
4850
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
4951

52+
override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
53+
5054
override def toString: String =
5155
s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})"
5256

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 139 additions & 104 deletions
Large diffs are not rendered by default.

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql.expressions
2020
import org.apache.spark.annotation.InterfaceStability
2121
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
2222
import org.apache.spark.sql.Column
23-
import org.apache.spark.sql.functions
2423
import org.apache.spark.sql.types.DataType
2524

2625
/**
@@ -35,10 +34,6 @@ import org.apache.spark.sql.types.DataType
3534
* df.select( predict(df("score")) )
3635
* }}}
3736
*
38-
* @note The user-defined functions must be deterministic. Due to optimization,
39-
* duplicate invocations may be eliminated or the function may even be invoked more times than
40-
* it is present in the query.
41-
*
4237
* @since 1.3.0
4338
*/
4439
@InterfaceStability.Stable
@@ -49,6 +44,7 @@ case class UserDefinedFunction protected[sql] (
4944

5045
private var _nameOption: Option[String] = None
5146
private var _nullable: Boolean = true
47+
private var _deterministic: Boolean = true
5248

5349
/**
5450
* Returns true when the UDF can return a nullable value.
@@ -57,6 +53,14 @@ case class UserDefinedFunction protected[sql] (
5753
*/
5854
def nullable: Boolean = _nullable
5955

56+
/**
57+
* Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
58+
* input.
59+
*
60+
* @since 2.3.0
61+
*/
62+
def deterministic: Boolean = _deterministic
63+
6064
/**
6165
* Returns an expression that invokes the UDF, using the given arguments.
6266
*
@@ -69,13 +73,15 @@ case class UserDefinedFunction protected[sql] (
6973
exprs.map(_.expr),
7074
inputTypes.getOrElse(Nil),
7175
udfName = _nameOption,
72-
nullable = _nullable))
76+
nullable = _nullable,
77+
udfDeterministic = _deterministic))
7378
}
7479

7580
private def copyAll(): UserDefinedFunction = {
7681
val udf = copy()
7782
udf._nameOption = _nameOption
7883
udf._nullable = _nullable
84+
udf._deterministic = _deterministic
7985
udf
8086
}
8187

@@ -84,22 +90,38 @@ case class UserDefinedFunction protected[sql] (
8490
*
8591
* @since 2.3.0
8692
*/
87-
def withName(name: String): this.type = {
88-
this._nameOption = Option(name)
89-
this
93+
def withName(name: String): UserDefinedFunction = {
94+
val udf = copyAll()
95+
udf._nameOption = Option(name)
96+
udf
97+
}
98+
99+
/**
100+
* Updates UserDefinedFunction to non-nullable.
101+
*
102+
* @since 2.3.0
103+
*/
104+
def asNonNullabe(): UserDefinedFunction = {
105+
if (!nullable) {
106+
this
107+
} else {
108+
val udf = copyAll()
109+
udf._nullable = false
110+
udf
111+
}
90112
}
91113

92114
/**
93-
* Updates UserDefinedFunction with a given nullability.
115+
* Updates UserDefinedFunction to nondeterministic.
94116
*
95117
* @since 2.3.0
96118
*/
97-
def withNullability(nullable: Boolean): UserDefinedFunction = {
98-
if (nullable == _nullable) {
119+
def asNondeterministic(): UserDefinedFunction = {
120+
if (!_deterministic) {
99121
this
100122
} else {
101123
val udf = copyAll()
102-
udf._nullable = nullable
124+
udf._deterministic = false
103125
udf
104126
}
105127
}

0 commit comments

Comments
 (0)