Skip to content

Commit d206b7c

Browse files
committed
Print a warning message instead of throwing an exception.
1 parent 2fa15bd commit d206b7c

File tree

2 files changed

+17
-43
lines changed

2 files changed

+17
-43
lines changed

python/pyspark/sql/tests.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4637,44 +4637,6 @@ def foofoo(x, y):
46374637
).collect
46384638
)
46394639

4640-
def test_pandas_udf_when_input_has_none(self):
4641-
import math
4642-
from pyspark.sql.functions import pandas_udf
4643-
import pandas as pd
4644-
4645-
values = [1.0] * 10 + [None] * 10 + [2.0] * 10
4646-
pdf = pd.DataFrame({'A': values})
4647-
df = self.spark.createDataFrame(pdf).repartition(1)
4648-
4649-
@pandas_udf(returnType=DoubleType())
4650-
def gt_2_double(column):
4651-
return (column >= 2).where(column.notnull())
4652-
4653-
# This pandas udf returns Pandas.Series of dtype as float64.
4654-
# If we define the pandas udf with incorrect data type BooleanType,
4655-
# we should see an exception.
4656-
@pandas_udf(returnType=BooleanType())
4657-
def gt_2_boolean(column):
4658-
return (column >= 2).where(column.notnull())
4659-
4660-
udf_double = df.select(['A']).withColumn('udf', gt_2_double('A'))
4661-
udf_boolean = df.select(['A']).withColumn('udf', gt_2_boolean('A'))
4662-
4663-
result = udf_double.collect()
4664-
result_part1 = [x[1] for x in result if x[0] == 1.0]
4665-
self.assertEqual(set(result_part1), set([0.0]))
4666-
result_part2 = [x[1] for x in result if x[0] == 2.0]
4667-
self.assertEqual(set(result_part2), set([1.0]))
4668-
result_part3 = [x[1] for x in result if math.isnan(x[0])]
4669-
self.assertEqual(set(result_part3), set([None]))
4670-
4671-
with QuietTest(self.sc):
4672-
with self.assertRaisesRegexp(Exception, "Return Pandas.Series of the user-defined " +
4673-
"function's dtype is float64 which doesn't " +
4674-
"match the arrow type bool of defined type " +
4675-
"BooleanType"):
4676-
udf_boolean.collect()
4677-
46784640

46794641
@unittest.skipIf(
46804642
not _have_pandas or not _have_pyarrow,

python/pyspark/worker.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,23 @@ def verify_result_length(*a):
9595

9696
# Ensure return type of Pandas.Series matches the arrow return type of the user-defined
9797
# function. Otherwise, we may produce incorrect serialized data.
98-
arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
99-
if arrow_return_type != arrow_type_of_result:
100-
raise TypeError("Return Pandas.Series of the user-defined function's dtype is %s "
101-
"which doesn't match the arrow type %s "
102-
"of defined type %s" % (result.dtype, arrow_return_type, return_type))
98+
# Note: for timestamp type, we only need to ensure both types are timestamp because the
99+
# serializer will do conversion.
100+
try:
101+
arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
102+
both_are_timestamp = pa.types.is_timestamp(arrow_type_of_result) and \
103+
pa.types.is_timestamp(arrow_return_type)
104+
if not both_are_timestamp and arrow_return_type != arrow_type_of_result:
105+
print("WARN: Arrow type %s of return Pandas.Series of the user-defined function's "
106+
"dtype %s doesn't match the arrow type %s "
107+
"of defined return type %s" % (arrow_type_of_result, result.dtype,
108+
arrow_return_type, return_type),
109+
file=sys.stderr)
110+
except:
111+
print("WARN: Can't infer arrow type of Pandas.Series's dtype: %s, which might not match "
112+
"the arrow type %s of defined return type %s" % (result.dtype, arrow_return_type,
113+
return_type),
114+
file=sys.stderr)
103115

104116
return result
105117

0 commit comments

Comments
 (0)