From 47b88734b91a7f9a4335bc3c667640eb4600b8e1 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 5 Feb 2018 18:30:20 +0900
Subject: [PATCH 1/3] Fix pandas_udf with return type StringType() to handle str type properly.

---
 python/pyspark/serializers.py | 3 +++
 python/pyspark/sql/tests.py   | 8 ++++++++
 2 files changed, 11 insertions(+)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 88d6a191babca..1469a44dc1caf 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -230,6 +230,9 @@ def create_array(s, t):
             s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
             # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
             return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+        elif t is not None and pa.types.is_string(t) and sys.version < '3':
+            # TODO: need decode before converting to Arrow in Python 2
+            return pa.Array.from_pandas(s.str.decode('utf-8'), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)
 
     arrs = [create_array(s, t) for s, t in series]
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b27363023ae77..ed303c9141de3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3920,6 +3920,14 @@ def test_vectorized_udf_null_string(self):
         res = df.select(str_f(col('str')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_string_in_udf(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+        str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), StringType())
+        res = df.select(str_f(col('id')))
+        self.assertEquals(df.select(col('id').cast('string')).collect(), res.collect())
+
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10).select(

From 06ae568df2088652754c2df66d2f78c8fbdac48d Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 5 Feb 2018 23:07:00 +0900
Subject: [PATCH 2/3] Address comments.

---
 python/pyspark/sql/tests.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ed303c9141de3..1517757e15e53 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3924,9 +3924,10 @@ def test_vectorized_udf_string_in_udf(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
         df = self.spark.range(10)
-        str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), StringType())
-        res = df.select(str_f(col('id')))
-        self.assertEquals(df.select(col('id').cast('string')).collect(), res.collect())
+        str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
+        actual = df.select(str_f(col('id')))
+        expected = df.select(col('id').cast('string'))
+        self.assertEquals(expected.collect(), actual.collect())
 
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col

From b3d5209b26322329d7e4ba1fd1b1457f86b44a8a Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Tue, 6 Feb 2018 10:55:37 +0900
Subject: [PATCH 3/3] Address a comment.

---
 python/pyspark/serializers.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 1469a44dc1caf..45f8290e5c17d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -232,7 +232,8 @@ def create_array(s, t):
             return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
         elif t is not None and pa.types.is_string(t) and sys.version < '3':
             # TODO: need decode before converting to Arrow in Python 2
-            return pa.Array.from_pandas(s.str.decode('utf-8'), mask=mask, type=t)
+            return pa.Array.from_pandas(s.apply(
+                lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)
 
     arrs = [create_array(s, t) for s, t in series]