From 8280f424138293e4fd411a6ef11136caec987ef7 Mon Sep 17 00:00:00 2001
From: zero323
Date: Sat, 7 Jan 2017 15:30:01 +0100
Subject: [PATCH 1/3] Add udf decorator

This PR adds `udf` decorator syntax as proposed in
[SPARK-19160](https://issues.apache.org/jira/browse/SPARK-19160). This allows
users to define UDFs with a simplified syntax:

```
from pyspark.sql.decorators import udf

@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
```

without the need to define the function and the UDF separately.

Tested with the existing unit tests to ensure backward compatibility, and
with additional unit tests covering the new functionality.
---
 python/pyspark/sql/tests.py | 27 ++++-----------------------
 1 file changed, 4 insertions(+), 23 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d8b7b3137c1c..3cae5bcd69bd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -266,9 +266,6 @@ def test_explode(self):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")

-        with self.assertRaises(ValueError):
-            data.select(explode(data.mapfield).alias("a", "b", metadata={'max': 99})).count()
-
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
@@ -543,32 +540,23 @@ def substr(x, start, end):
             if x is not None:
                 return x[start:end]

-        @udf("long")
-        def trunc(x):
-            return int(x)
-
-        @udf(returnType="double")
-        def as_double(x):
-            return float(x)
-
         df = (
             self.spark
                 .createDataFrame(
-                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+                    [(1, "Foo", "foobar")], ("one", "Foo", "foobar"))
                 .select(
                     add_one("one"), add_two("one"),
                     to_upper("Foo"), to_lower("Foo"),
-                    substr("foobar", lit(0), lit(3)),
-                    trunc("float"), as_double("one")))
+                    substr("foobar", lit(0), lit(3))))

         self.assertListEqual(
             [tpe for _, tpe in df.dtypes],
-            ["int", "double", "string", "string", "string", "bigint", "double"]
+            ["int", "double", "string", "string", "string"]
         )

         self.assertListEqual(
             list(df.first()),
-            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+            [2, 3.0, "FOO", "foo", "foo"]
         )

     def test_basic_functions(self):
@@ -955,13 +943,6 @@ def test_column_select(self):
         self.assertEqual(self.testData, df.select(df.key, df.value).collect())
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

-    def test_column_alias_metadata(self):
-        df = self.df
-        df_with_meta = df.select(df.key.alias('pk', metadata={'label': 'Primary Key'}))
-        self.assertEqual(df_with_meta.schema['pk'].metadata['label'], 'Primary Key')
-        with self.assertRaises(AssertionError):
-            df.select(df.key.alias('pk', metdata={'label': 'Primary Key'}))
-
     def test_freqItems(self):
         vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
         df = self.sc.parallelize(vals).toDF()

From af4a11c4c025391a7eb26b97d07c58da8146740f Mon Sep 17 00:00:00 2001
From: zero323
Date: Mon, 13 Feb 2017 19:59:52 +0100
Subject: [PATCH 2/3] Add tests for udf decorator with str returnType

---
 python/pyspark/sql/tests.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3cae5bcd69bd..b2a2f5c8927f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -540,23 +540,32 @@ def substr(x, start, end):
             if x is not None:
                 return x[start:end]

+        @udf("long")
+        def trunc(x):
+            return int(x)
+
+        @udf(returnType="double")
+        def as_double(x):
+            return float(x)
+
         df = (
             self.spark
                 .createDataFrame(
-                    [(1, "Foo", "foobar")], ("one", "Foo", "foobar"))
+                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
                 .select(
                     add_one("one"), add_two("one"),
                     to_upper("Foo"), to_lower("Foo"),
-                    substr("foobar", lit(0), lit(3))))
+                    substr("foobar", lit(0), lit(3)),
+                    trunc("float"), as_double("one")))

         self.assertListEqual(
             [tpe for _, tpe in df.dtypes],
-            ["int", "double", "string", "string", "string"]
+            ["int", "double", "string", "string", "string", "bigint", "double"]
         )

         self.assertListEqual(
             list(df.first()),
-            [2, 3.0, "FOO", "foo", "foo"]
+            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
         )

     def test_basic_functions(self):

From 64bba41fe062dc39ad8708fa4dd825e609254814 Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 14 Feb 2017 22:43:39 +0100
Subject: [PATCH 3/3] Wrap UserDefinedFunction object with a function and
 preserve docstring

---
 python/pyspark/sql/functions.py | 11 ++++++++++-
 python/pyspark/sql/tests.py     | 15 +++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d2617203140f..426a4a8c93a6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1940,7 +1940,16 @@ def udf(f=None, returnType=StringType()):
     +----------+--------------+------------+
     """
     def _udf(f, returnType=StringType()):
-        return UserDefinedFunction(f, returnType)
+        udf_obj = UserDefinedFunction(f, returnType)
+
+        @functools.wraps(f)
+        def wrapper(*args):
+            return udf_obj(*args)
+
+        wrapper.func = udf_obj.func
+        wrapper.returnType = udf_obj.returnType
+
+        return wrapper

     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b2a2f5c8927f..319c88f0b3fd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -568,6 +568,21 @@ def as_double(x):
             [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
         )

+    def test_udf_wrapper(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import IntegerType
+
+        def f(x):
+            """Identity"""
+            return x
+
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
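
For illustration, here is a minimal end-to-end sketch of the decorator forms this series enables, assuming the patched `pyspark.sql.functions.udf` from patch 3 and an already-created `SparkSession` bound to `spark`. The UDF names, the session setup, and the example data are hypothetical, not part of the patches; `add_one` and `trunc` mirror the commit message and the tests from patch 2:

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf  # the patched udf from this series
from pyspark.sql.types import IntegerType

# Assumption: a local session purely for demonstration.
spark = SparkSession.builder.master("local[1]").getOrCreate()

# Decorator with an explicit DataType, as in the commit message.
@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1

# Bare decorator: returnType falls back to the default StringType().
@udf
def as_text(x):
    """Renders x as a string"""
    return str(x)

# String returnType, as exercised by the tests added in patch 2.
@udf("long")
def trunc(x):
    return int(x)

# Patch 3 wraps the UserDefinedFunction with functools.wraps, so the
# decorated function keeps its docstring and exposes func/returnType.
assert add_one.__doc__ == "Adds one"
assert add_one.returnType == IntegerType()

df = spark.createDataFrame([(1,), (2,)], ("x",))
df.select(add_one("x"), as_text("x"), trunc("x")).show()
```

The bare `@udf` form works because the dispatch shown in patch 3 treats a callable first argument as the wrapped function, while `None`, a `str`, or a `DataType` takes the decorator-factory path (`@udf()`, `@udf("long")`, `@udf(IntegerType())`).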