Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF(test)=4)]
[Row(UDF:javaStringLength(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF(test)=4)]
[Row(UDF:javaStringLength2(test)=4)]

"""
jdt = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1950,7 +1950,7 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf @ ScalaUDF(func, _, inputs, _, _, _) =>
case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
val parameterTypes = ScalaReflection.getParameterTypes(func)
assert(parameterTypes.length == inputs.length)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import org.apache.spark.sql.types.DataType

/**
* User-defined function.
* Note that the user-defined functions must be deterministic.
* @param function The user defined scala function to run.
* Note that if you use primitive parameters, you are not able to check if it is
* null or not, and the UDF will return null for you if the primitive input is
Expand All @@ -35,18 +34,23 @@ import org.apache.spark.sql.types.DataType
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
* type coercion. However, that would require more refactoring of the codebase.
* @param udfName The user-specified name of this UDF.
* @param udfName The user-specified name of this UDF.
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
* each time it is invoked with a particular input.
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true)
nullable: Boolean = true,
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression {

override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

override def toString: String =
s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})"

Expand Down
243 changes: 139 additions & 104 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.expressions
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions
import org.apache.spark.sql.types.DataType

/**
Expand All @@ -35,10 +34,6 @@ import org.apache.spark.sql.types.DataType
* df.select( predict(df("score")) )
* }}}
*
* @note The user-defined functions must be deterministic. Due to optimization,
* duplicate invocations may be eliminated or the function may even be invoked more times than
* it is present in the query.
*
* @since 1.3.0
*/
@InterfaceStability.Stable
Expand All @@ -49,6 +44,7 @@ case class UserDefinedFunction protected[sql] (

private var _nameOption: Option[String] = None
private var _nullable: Boolean = true
private var _deterministic: Boolean = true

/**
* Returns true when the UDF can return a nullable value.
Expand All @@ -57,6 +53,14 @@ case class UserDefinedFunction protected[sql] (
*/
def nullable: Boolean = _nullable

/**
* Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
* input.
*
* @since 2.3.0
*/
def deterministic: Boolean = _deterministic

/**
* Returns an expression that invokes the UDF, using the given arguments.
*
Expand All @@ -69,13 +73,15 @@ case class UserDefinedFunction protected[sql] (
exprs.map(_.expr),
inputTypes.getOrElse(Nil),
udfName = _nameOption,
nullable = _nullable))
nullable = _nullable,
udfDeterministic = _deterministic))
}

private def copyAll(): UserDefinedFunction = {
val udf = copy()
udf._nameOption = _nameOption
udf._nullable = _nullable
udf._deterministic = _deterministic
udf
}

Expand All @@ -84,22 +90,38 @@ case class UserDefinedFunction protected[sql] (
*
* @since 2.3.0
*/
def withName(name: String): this.type = {
this._nameOption = Option(name)
this
def withName(name: String): UserDefinedFunction = {
val udf = copyAll()
udf._nameOption = Option(name)
udf
}

/**
* Updates UserDefinedFunction to non-nullable.
*
* @since 2.3.0
*/
def asNonNullabe(): UserDefinedFunction = {
if (!nullable) {
this
} else {
val udf = copyAll()
udf._nullable = false
udf
}
}

/**
* Updates UserDefinedFunction with a given nullability.
* Updates UserDefinedFunction to nondeterministic.
*
* @since 2.3.0
*/
def withNullability(nullable: Boolean): UserDefinedFunction = {
if (nullable == _nullable) {
def asNondeterministic(): UserDefinedFunction = {
if (!_deterministic) {
this
} else {
val udf = copyAll()
udf._nullable = nullable
udf._deterministic = false
udf
}
}
Expand Down
Loading