From 2fa15bda48ba64a102f114dc9119cb3c310200c4 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 26 Sep 2018 09:01:40 +0000
Subject: [PATCH 1/5] Ensure return type of Pandas.Series matches the arrow
 return type of pandas udf.

---
 python/pyspark/sql/tests.py | 38 +++++++++++++++++++++++++++++++++++++
 python/pyspark/worker.py    | 10 ++++++++++
 2 files changed, 48 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 74642d46d1cd..f498b13d0955 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4637,6 +4637,44 @@ def foofoo(x, y):
             ).collect
         )
 
+    def test_pandas_udf_when_input_has_none(self):
+        import math
+        from pyspark.sql.functions import pandas_udf
+        import pandas as pd
+
+        values = [1.0] * 10 + [None] * 10 + [2.0] * 10
+        pdf = pd.DataFrame({'A': values})
+        df = self.spark.createDataFrame(pdf).repartition(1)
+
+        @pandas_udf(returnType=DoubleType())
+        def gt_2_double(column):
+            return (column >= 2).where(column.notnull())
+
+        # This pandas udf returns a Pandas.Series of dtype float64.
+        # If we define the pandas udf with the incorrect data type BooleanType,
+        # we should see an exception.
+        @pandas_udf(returnType=BooleanType())
+        def gt_2_boolean(column):
+            return (column >= 2).where(column.notnull())
+
+        udf_double = df.select(['A']).withColumn('udf', gt_2_double('A'))
+        udf_boolean = df.select(['A']).withColumn('udf', gt_2_boolean('A'))
+
+        result = udf_double.collect()
+        result_part1 = [x[1] for x in result if x[0] == 1.0]
+        self.assertEqual(set(result_part1), set([0.0]))
+        result_part2 = [x[1] for x in result if x[0] == 2.0]
+        self.assertEqual(set(result_part2), set([1.0]))
+        result_part3 = [x[1] for x in result if math.isnan(x[0])]
+        self.assertEqual(set(result_part3), set([None]))
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, "Return Pandas.Series of the user-defined " +
+                                                    "function's dtype is float64 which doesn't " +
+                                                    "match the arrow type bool of defined type " +
+                                                    "BooleanType"):
+                udf_boolean.collect()
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8c59f1f999f1..e0c6bb8d27d4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -84,6 +84,7 @@ def wrap_scalar_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
     def verify_result_length(*a):
+        import pyarrow as pa
         result = f(*a)
         if not hasattr(result, "__len__"):
             raise TypeError("Return type of the user-defined function should be "
@@ -91,6 +92,15 @@ def verify_result_length(*a):
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))
+
+        # Ensure return type of Pandas.Series matches the arrow return type of the user-defined
+        # function. Otherwise, we may produce incorrect serialized data.
+        arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
+        if arrow_return_type != arrow_type_of_result:
+            raise TypeError("Return Pandas.Series of the user-defined function's dtype is %s "
+                            "which doesn't match the arrow type %s "
+                            "of defined type %s" % (result.dtype, arrow_return_type, return_type))
+
         return result
 
     return lambda *a: (verify_result_length(*a), arrow_return_type)
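The comparison this patch adds can be reproduced outside a Spark worker. Below is a minimal sketch using only pandas and pyarrow; the variable names mirror those in verify_result_length, and the Series stands in for a UDF result like the float64 one in the new test:

    import pandas as pd
    import pyarrow as pa

    # A UDF result whose dtype is float64, as in the test above.
    result = pd.Series([0.0, 1.0, None])
    arrow_type_of_result = pa.from_numpy_dtype(result.dtype)

    print(arrow_type_of_result)                # double
    # Comparing against the arrow type of BooleanType shows the mismatch
    # that makes gt_2_boolean raise TypeError.
    print(arrow_type_of_result == pa.bool_())  # False
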
From d206b7cf78f898e622f539a15e45515fcbd9e54a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 2 Oct 2018 05:29:44 +0000
Subject: [PATCH 2/5] Print warning message instead of throwing exception.

---
 python/pyspark/sql/tests.py | 38 -------------------------------------
 python/pyspark/worker.py    | 22 ++++++++++++++++-----
 2 files changed, 17 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f498b13d0955..74642d46d1cd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4637,44 +4637,6 @@ def foofoo(x, y):
             ).collect
         )
 
-    def test_pandas_udf_when_input_has_none(self):
-        import math
-        from pyspark.sql.functions import pandas_udf
-        import pandas as pd
-
-        values = [1.0] * 10 + [None] * 10 + [2.0] * 10
-        pdf = pd.DataFrame({'A': values})
-        df = self.spark.createDataFrame(pdf).repartition(1)
-
-        @pandas_udf(returnType=DoubleType())
-        def gt_2_double(column):
-            return (column >= 2).where(column.notnull())
-
-        # This pandas udf returns a Pandas.Series of dtype float64.
-        # If we define the pandas udf with the incorrect data type BooleanType,
-        # we should see an exception.
-        @pandas_udf(returnType=BooleanType())
-        def gt_2_boolean(column):
-            return (column >= 2).where(column.notnull())
-
-        udf_double = df.select(['A']).withColumn('udf', gt_2_double('A'))
-        udf_boolean = df.select(['A']).withColumn('udf', gt_2_boolean('A'))
-
-        result = udf_double.collect()
-        result_part1 = [x[1] for x in result if x[0] == 1.0]
-        self.assertEqual(set(result_part1), set([0.0]))
-        result_part2 = [x[1] for x in result if x[0] == 2.0]
-        self.assertEqual(set(result_part2), set([1.0]))
-        result_part3 = [x[1] for x in result if math.isnan(x[0])]
-        self.assertEqual(set(result_part3), set([None]))
-
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, "Return Pandas.Series of the user-defined " +
-                                                    "function's dtype is float64 which doesn't " +
-                                                    "match the arrow type bool of defined type " +
-                                                    "BooleanType"):
-                udf_boolean.collect()
-
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e0c6bb8d27d4..509424357989 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -95,11 +95,23 @@ def verify_result_length(*a):
 
         # Ensure return type of Pandas.Series matches the arrow return type of the user-defined
        # function. Otherwise, we may produce incorrect serialized data.
-        arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
-        if arrow_return_type != arrow_type_of_result:
-            raise TypeError("Return Pandas.Series of the user-defined function's dtype is %s "
-                            "which doesn't match the arrow type %s "
-                            "of defined type %s" % (result.dtype, arrow_return_type, return_type))
+        # Note: for timestamp type, we only need to ensure both types are timestamp because the
+        # serializer will do conversion.
+        try:
+            arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
+            both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \
+                pa.types.is_timestamp(arrow_return_type)
+            if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
+                print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's "
+                      "dtype %s doesn't match the arrow type %s "
+                      "of defined return type %s" % (arrow_type_of_result, result.dtype,
+                                                     arrow_return_type, return_type),
+                      file=sys.stderr)
+        except:
+            print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not match "
+                  "the arrow type %s of defined return type %s" % (result.dtype, arrow_return_type,
+                                                                   return_type),
+                  file=sys.stderr)
 
         return result
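The timestamp carve-out added here can also be exercised in isolation. A minimal sketch, assuming only pandas and pyarrow are installed; `declared` is a hypothetical stand-in for the arrow_return_type a timestamp-typed UDF would carry:

    import pandas as pd
    import pyarrow as pa

    result = pd.Series(pd.to_datetime(['2018-10-02']))  # dtype datetime64[ns]
    inferred = pa.from_numpy_dtype(result.dtype)        # timestamp[ns]
    declared = pa.timestamp('us')                       # hypothetical arrow_return_type

    # The two types are unequal (ns vs us), but both are timestamps, so the
    # warning is suppressed and the serializer is left to convert the unit.
    print(inferred == declared)                                                 # False
    print(pa.types.is_timestamp(inferred) and pa.types.is_timestamp(declared))  # True
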
From c084e745007d455a6ea99e10cc403b55ead6278d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 2 Oct 2018 08:56:29 +0000
Subject: [PATCH 3/5] Fix python style.

---
 python/pyspark/worker.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 509424357989..13327cc3d6d0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -100,7 +100,7 @@ def verify_result_length(*a):
         try:
             arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
             both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \
-                pa.types.is_timestamp(arrow_return_type)
+                                 pa.types.is_timestamp(arrow_return_type)
             if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
                 print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's "
                       "dtype %s doesn't match the arrow type %s "
@@ -108,9 +108,10 @@ def verify_result_length(*a):
                                                      arrow_return_type, return_type),
                       file=sys.stderr)
         except:
-            print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not match "
-                  "the arrow type %s of defined return type %s" % (result.dtype, arrow_return_type,
-                                                                   return_type),
+            print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not "
+                  "match the arrow type %s of defined return type %s" % (result.dtype,
+                                                                         arrow_return_type,
+                                                                         return_type),
                   file=sys.stderr)
 
         return result

From a756c0b40f74f35027d65a5c143bfa4b9f5f89fb Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 4 Oct 2018 10:33:57 +0000
Subject: [PATCH 4/5] Add document.

---
 python/pyspark/sql/functions.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1c3d9725b285..26dfdd6c975f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2909,6 +2909,11 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         can fail on special rows, the workaround is to incorporate the condition into the functions.
 
     .. note:: The user-defined functions do not take keyword arguments on the calling side.
+
+    .. note:: The data type of the returned `pandas.Series` from the user-defined functions
+        should match the defined returnType. When there is a mismatch between them, it is not
+        guaranteed that the conversion done by Spark SQL during serialization is correct and
+        users might get unexpected results.
     """
     # decorator @pandas_udf(returnType, functionType)
     is_decorator = f is None or isinstance(f, (str, DataType))
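To make the new note concrete, here is a hypothetical mismatch of the kind it warns about; whether the values are cast (possibly lossily) or an error is raised depends on the Spark and Arrow versions in use:

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import LongType

    # Hypothetical example: `v / 2` returns a pandas.Series of dtype float64,
    # but the declared returnType is LongType. Spark may convert the floats
    # during Arrow serialization, so the results should be checked by the user.
    @pandas_udf(returnType=LongType())
    def half(v):
        return v / 2
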
From 6c6f8a1abfc8150c0acf0b0c43fd8430d9f8e5c4 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 6 Oct 2018 09:54:42 +0000
Subject: [PATCH 5/5] Remove warning message and modify code comments.

---
 python/pyspark/sql/functions.py |  7 ++++---
 python/pyspark/worker.py        | 23 -----------------------
 2 files changed, 4 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 26dfdd6c975f..163a2577ce82 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2911,9 +2911,10 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     .. note:: The user-defined functions do not take keyword arguments on the calling side.
 
     .. note:: The data type of the returned `pandas.Series` from the user-defined functions
-        should match the defined returnType. When there is a mismatch between them, it is not
-        guaranteed that the conversion done by Spark SQL during serialization is correct and
-        users might get unexpected results.
+        should match the defined returnType (see :meth:`types.to_arrow_type` and
+        :meth:`types.from_arrow_type`). When there is a mismatch between them, Spark might do
+        conversion on the returned data. The conversion is not guaranteed to be correct and
+        results should be checked for accuracy by users.
     """
     # decorator @pandas_udf(returnType, functionType)
     is_decorator = f is None or isinstance(f, (str, DataType))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 13327cc3d6d0..8c59f1f999f1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -84,7 +84,6 @@ def wrap_scalar_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
     def verify_result_length(*a):
-        import pyarrow as pa
         result = f(*a)
         if not hasattr(result, "__len__"):
             raise TypeError("Return type of the user-defined function should be "
@@ -92,28 +91,6 @@ def verify_result_length(*a):
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))
-
-        # Ensure return type of Pandas.Series matches the arrow return type of the user-defined
-        # function. Otherwise, we may produce incorrect serialized data.
-        # Note: for timestamp type, we only need to ensure both types are timestamp because the
-        # serializer will do conversion.
-        try:
-            arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
-            both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \
-                                 pa.types.is_timestamp(arrow_return_type)
-            if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
-                print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's "
-                      "dtype %s doesn't match the arrow type %s "
-                      "of defined return type %s" % (arrow_type_of_result, result.dtype,
-                                                     arrow_return_type, return_type),
-                      file=sys.stderr)
-        except:
-            print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not "
-                  "match the arrow type %s of defined return type %s" % (result.dtype,
-                                                                         arrow_return_type,
-                                                                         return_type),
-                  file=sys.stderr)
-
         return result
 
     return lambda *a: (verify_result_length(*a), arrow_return_type)
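With the worker-side warning removed, only the documentation carries the guidance, but users can still compare the types themselves before relying on a UDF. A sketch, assuming to_arrow_type is importable from pyspark.sql.types as the docstring's :meth:`types.to_arrow_type` cross-reference suggests:

    import pandas as pd
    import pyarrow as pa
    from pyspark.sql.types import BooleanType, to_arrow_type

    declared = to_arrow_type(BooleanType())       # bool
    result = pd.Series([1.0, 0.0, None])          # what the UDF body actually returns
    inferred = pa.from_numpy_dtype(result.dtype)  # double

    # A mismatch here means Spark will attempt a conversion during
    # serialization, and the results should be checked for accuracy.
    print(declared == inferred)  # False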