Skip to content

Commit e08fc08

Browse files
committed
fix.
1 parent d276b44 commit e08fc08

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

python/pyspark/sql/context.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pyspark.sql.dataframe import DataFrame
2929
from pyspark.sql.readwriter import DataFrameReader
3030
from pyspark.sql.streaming import DataStreamReader
31-
from pyspark.sql.types import IntegerType, Row, StringType
31+
from pyspark.sql.types import DoubleType, IntegerType, Row, StringType
3232
from pyspark.sql.utils import install_exception_handler
3333

3434
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
@@ -208,14 +208,19 @@ def registerFunction(self, name, f, returnType=StringType()):
208208

209209
@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None, deterministic=True,
                         distinctLike=False):
    """Register a java UDF so it can be used in SQL statements.

    In addition to a name and the function itself, the return type can be optionally
    specified. When the return type is not specified we would infer it via reflection.

    :param name: name of the UDF
    :param javaClassName: fully qualified name of java class
    :param returnType: a :class:`pyspark.sql.types.DataType` object
    :param deterministic: a flag indicating if the UDF is deterministic. A deterministic
        UDF returns the same result each time it is invoked with a particular input.
    :param distinctLike: a UDF is considered distinctLike if the UDF can be evaluated on
        just the distinct values of a column.

    >>> sqlContext.registerJavaFunction("javaStringLength",
    ...     "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
    [Row(UDF:javaStringLength(test)=4)]
    >>> sqlContext.registerJavaFunction("javaStringLength2",
    ...     "test.org.apache.spark.sql.JavaStringLength")
    >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
    [Row(UDF:javaStringLength2(test)=4)]
    >>> sqlContext.registerJavaFunction("javaRand",
    ...     "test.org.apache.spark.sql.randUDFTest", DoubleType(), deterministic=False)

    The UDF below is non-deterministic, so its output cannot be pinned in a doctest.
    The column label reflects the actual argument (3), not a string literal.

    >>> sqlContext.sql("SELECT javaRand(3)").collect()  # doctest: +SKIP
    [Row(UDF:javaRand(3)=3.12...)]
    """
    # Only parse the return type into a JVM DataType when one was supplied;
    # otherwise pass None and let the JVM side infer it via reflection.
    jdt = None
    if returnType is not None:
        jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
    self.sparkSession._jsparkSession.udf().registerJava(
        name, javaClassName, jdt, deterministic, distinctLike)
234244

235245
# TODO(andrew): delete this once we refactor things to take in SparkSession
236246
def _inferSchema(self, rdd, samplingRatio=None):

0 commit comments

Comments
 (0)