From 9edc5e939140e886959a4c48f59735036577578a Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 2 May 2017 14:04:32 +0200
Subject: [PATCH 1/4] Return UDF from udf.register

---
 python/pyspark/sql/catalog.py   | 10 +++++++---
 python/pyspark/sql/context.py   | 12 ++++++++----
 python/pyspark/sql/functions.py | 23 ++++++++++++++---------
 python/pyspark/sql/tests.py     | 10 ++++++++++
 4 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 41e68a45a615..815763d51e30 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -238,22 +238,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-        >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]
 
+        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
         udf = UserDefinedFunction(f, returnType, name)
         self._jsparkSession.udf().registerPython(name, udf._judf)
+        return udf._wrapped()
 
     @since(2.0)
     def isCached(self, tableName):
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index fdb7abbad4e5..6b8ce9b3dc9b 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return a wrapped :class:`UserDefinedFunction`
 
-        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]
 
+        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
        >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        self.sparkSession.catalog.registerFunction(name, f, returnType)
+        return self.sparkSession.catalog.registerFunction(name, f, returnType)
 
     @ignore_unicode_prefix
     @since(2.1)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 843ae3816f06..8b3487c3f108 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1917,6 +1917,19 @@ def __call__(self, *cols):
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
 
+    def _wrapped(self):
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+        @functools.wraps(self.func)
+        def wrapper(*args):
+            return self(*args)
+
+        wrapper.func = self.func
+        wrapper.returnType = self.returnType
+
+        return wrapper
+
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()):
     """
     def _udf(f, returnType=StringType()):
         udf_obj = UserDefinedFunction(f, returnType)
-
-        @functools.wraps(f)
-        def wrapper(*args):
-            return udf_obj(*args)
-
-        wrapper.func = udf_obj.func
-        wrapper.returnType = udf_obj.returnType
-
-        return wrapper
+        return udf_obj._wrapped()
 
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ce4abf8fb7e5..cd5d0cd4ffc8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect()
+        )
+
     def test_wholefile_json(self):
         people1 = self.spark.read.json("python/test_support/sql/people.json")
         people_array = self.spark.read.json("python/test_support/sql/people_array.json",
@@ -615,6 +624,7 @@ def f(x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
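Note: the substantive change is in this first patch. registerFunction and
udf.register now return the wrapped UDF instead of None, so a single
registration serves both the SQL and DataFrame APIs. A minimal usage sketch,
assuming an active SparkSession bound to `spark`; the `to_upper` name and
the sample data are illustrative and not part of the patch:

    from pyspark.sql.types import StringType

    # Registering makes the UDF callable from SQL *and* returns a wrapped
    # function usable as a DataFrame Column expression.
    to_upper = spark.udf.register("to_upper", lambda s: s.upper(), StringType())

    df = spark.createDataFrame([("alice",), ("bob",)], ["name"])

    # Same UDF, two call sites:
    df.selectExpr("to_upper(name) AS upper_name").show()    # SQL expression
    df.select(to_upper("name").alias("upper_name")).show()  # DataFrame API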
From e8615c20d004eee896cac81b794609015eb46e5d Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 2 May 2017 14:32:52 +0200
Subject: [PATCH 2/4] Add return annotation to catalog.registerFunction and
 fix style

---
 python/pyspark/sql/catalog.py | 1 +
 python/pyspark/sql/context.py | 2 +-
 python/pyspark/sql/tests.py   | 1 -
 3 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 815763d51e30..51df93aeb076 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -237,6 +237,7 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return a wrapped :class:`UserDefinedFunction`
 
         >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 6b8ce9b3dc9b..f4bcf0f8dc44 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -185,7 +185,7 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return a wrapped :class:`UserDefinedFunction` 
+        :return a wrapped :class:`UserDefinedFunction`
 
        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cd5d0cd4ffc8..6138c0ff4101 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -624,7 +624,6 @@ def f(x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
-
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

From 18c025332318f547902888db981e5c7c25c1ea72 Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 2 May 2017 15:08:29 +0200
Subject: [PATCH 3/4] Add missing :

---
 python/pyspark/sql/catalog.py | 2 +-
 python/pyspark/sql/context.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 51df93aeb076..fc5f2ed8acbe 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -237,7 +237,7 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return a wrapped :class:`UserDefinedFunction`
+        :return: a wrapped :class:`UserDefinedFunction`
 
         >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f4bcf0f8dc44..4a75fb2d5c41 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -185,7 +185,7 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
-        :return a wrapped :class:`UserDefinedFunction`
+        :return: a wrapped :class:`UserDefinedFunction`
 
         >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
From bce80c413832738034224782227f73d5b9151625 Mon Sep 17 00:00:00 2001
From: zero323
Date: Wed, 3 May 2017 14:45:56 +0200
Subject: [PATCH 4/4] Discard returned udfs when not used in tests

---
 python/pyspark/sql/catalog.py | 4 ++--
 python/pyspark/sql/context.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index fc5f2ed8acbe..5f25dce16196 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -247,12 +247,12 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(stringLengthString(text)=u'3')]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> strlen = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
+        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> strlen = spark.udf.register("stringLengthInt", len, IntegerType())
+        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4a75fb2d5c41..5197a9e00461 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -195,12 +195,12 @@ def registerFunction(self, name, f, returnType=StringType()):
         [Row(stringLengthString(text)=u'3')]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> strlen = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> strlen = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
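Note: because both registration paths now go through
UserDefinedFunction._wrapped(), the returned callable carries the same func
and returnType attributes as UDFs built with the udf() decorator, mirroring
the assertions in the existing decorator test touched above. A small
introspection sketch, assuming an active SparkSession bound to `spark`;
`plus_one` is a hypothetical name, not part of the patches:

    from pyspark.sql.types import IntegerType

    plus_one = spark.udf.register("plus_one", lambda x: x + 1, IntegerType())

    # The wrapper keeps a reference to the original Python callable and the
    # declared return type, which is handy for tests and introspection.
    assert plus_one.func(41) == 42
    assert plus_one.returnType == IntegerType()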