2828from pyspark .sql .dataframe import DataFrame
2929from pyspark .sql .readwriter import DataFrameReader
3030from pyspark .sql .streaming import DataStreamReader
31- from pyspark .sql .types import IntegerType , Row , StringType
31+ from pyspark .sql .types import DoubleType , IntegerType , Row , StringType
3232from pyspark .sql .utils import install_exception_handler
3333
3434__all__ = ["SQLContext" , "HiveContext" , "UDFRegistration" ]
@@ -208,14 +208,19 @@ def registerFunction(self, name, f, returnType=StringType()):
208208
209209 @ignore_unicode_prefix
210210 @since (2.1 )
211- def registerJavaFunction (self , name , javaClassName , returnType = None ):
211+ def registerJavaFunction (self , name , javaClassName , returnType = None , deterministic = True ,
212+ distinctLike = False ):
212213 """Register a java UDF so it can be used in SQL statements.
213214
214215 In addition to a name and the function itself, the return type can be optionally specified.
215216 When the return type is not specified we would infer it via reflection.
216217 :param name: name of the UDF
217218 :param javaClassName: fully qualified name of java class
218219 :param returnType: a :class:`pyspark.sql.types.DataType` object
220+ :param deterministic: a flag indicating if the UDF is deterministic. A deterministic
221+ UDF returns the same result each time it is invoked with a particular input.
222+ :param distinctLike: a UDF is considered distinctLike if the UDF can be evaluated on just
223+ the distinct values of a column.
219224
220225 >>> sqlContext.registerJavaFunction("javaStringLength",
221226 ... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
@@ -225,12 +230,17 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
225230 ... "test.org.apache.spark.sql.JavaStringLength")
226231 >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
227232 [Row(UDF:javaStringLength2(test)=4)]
233+ >>> sqlContext.registerJavaFunction("javaRand",
234+ ... "test.org.apache.spark.sql.randUDFTest", DoubleType(), deterministic=False)
235+ >>> sqlContext.sql("SELECT javaRand(3)").collect()  # doctest: +SKIP
236+ [Row(UDF:javaRand(3)=3.12...)]
228237
229238 """
230239 jdt = None
231240 if returnType is not None :
232241 jdt = self .sparkSession ._jsparkSession .parseDataType (returnType .json ())
233- self .sparkSession ._jsparkSession .udf ().registerJava (name , javaClassName , jdt )
242+ self .sparkSession ._jsparkSession .udf ().registerJava (
243+ name , javaClassName , jdt , deterministic , distinctLike )
234244
235245 # TODO(andrew): delete this once we refactor things to take in SparkSession
236246 def _inferSchema (self , rdd , samplingRatio = None ):
0 commit comments