Skip to content

Commit 43f5e40

Browse files
HyukjinKwongatorsmile
authored andcommitted
[SPARK-23352][PYTHON][BRANCH-2.3] Explicitly specify supported types in Pandas UDFs
## What changes were proposed in this pull request? This PR backports #20531: It explicitly specifies supported types in Pandas UDFs. The main change here is to add a deduplicated and explicit type checking in `returnType` ahead with documenting this; however, it happened to fix multiple things. 1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see: ```python from pyspark.sql.functions import pandas_udf pudf = pandas_udf(lambda x: x, "binary") df = spark.createDataFrame([[bytearray(1)]]) df.select(pudf("_1")).show() ``` ``` ... TypeError: Unsupported type in conversion to Arrow: BinaryType ``` We can document this behaviour for its guide. 2. Since we can check the return type ahead, we can fail fast before actual execution. ```python # we can fail fast at this stage because we know the schema ahead pandas_udf(lambda x: x, BinaryType()) ``` ## How was this patch tested? Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added. Author: hyukjinkwon <[email protected]> Closes #20588 from HyukjinKwon/PR_TOOL_PICK_PR_20531_BRANCH-2.3.
1 parent befb22d commit 43f5e40

File tree

5 files changed

+77
-44
lines changed

5 files changed

+77
-44
lines changed

docs/sql-programming-guide.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A
16761676
enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the
16771677
DataFrame to the driver program and should be done on a small subset of the data. Not all Spark
16781678
data types are currently supported and an error can be raised if a column has an unsupported type,
1679-
see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
1679+
see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
16801680
Spark will fall back to create the DataFrame without Arrow.
16811681

16821682
## Pandas UDFs (a.k.a. Vectorized UDFs)
@@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p
17341734

17351735
### Supported SQL Types
17361736

1737-
Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`,
1737+
Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`,
17381738
`ArrayType` of `TimestampType`, and nested `StructType`.
17391739

17401740
### Setting Arrow Batch Size

python/pyspark/sql/tests.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3736,10 +3736,10 @@ def foo(x):
37363736
self.assertEqual(foo.returnType, schema)
37373737
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
37383738

3739-
@pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR)
3739+
@pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
37403740
def foo(x):
37413741
return x
3742-
self.assertEqual(foo.returnType, schema)
3742+
self.assertEqual(foo.returnType, DoubleType())
37433743
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
37443744

37453745
@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
@@ -3776,7 +3776,7 @@ def zero_with_type():
37763776
@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
37773777
def foo(df):
37783778
return df
3779-
with self.assertRaisesRegexp(ValueError, 'Invalid returnType'):
3779+
with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
37803780
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
37813781
def foo(df):
37823782
return df
@@ -3825,15 +3825,16 @@ def random_udf(v):
38253825
return random_udf
38263826

38273827
def test_vectorized_udf_basic(self):
3828-
from pyspark.sql.functions import pandas_udf, col
3828+
from pyspark.sql.functions import pandas_udf, col, array
38293829
df = self.spark.range(10).select(
38303830
col('id').cast('string').alias('str'),
38313831
col('id').cast('int').alias('int'),
38323832
col('id').alias('long'),
38333833
col('id').cast('float').alias('float'),
38343834
col('id').cast('double').alias('double'),
38353835
col('id').cast('decimal').alias('decimal'),
3836-
col('id').cast('boolean').alias('bool'))
3836+
col('id').cast('boolean').alias('bool'),
3837+
array(col('id')).alias('array_long'))
38373838
f = lambda x: x
38383839
str_f = pandas_udf(f, StringType())
38393840
int_f = pandas_udf(f, IntegerType())
@@ -3842,10 +3843,11 @@ def test_vectorized_udf_basic(self):
38423843
double_f = pandas_udf(f, DoubleType())
38433844
decimal_f = pandas_udf(f, DecimalType())
38443845
bool_f = pandas_udf(f, BooleanType())
3846+
array_long_f = pandas_udf(f, ArrayType(LongType()))
38453847
res = df.select(str_f(col('str')), int_f(col('int')),
38463848
long_f(col('long')), float_f(col('float')),
38473849
double_f(col('double')), decimal_f('decimal'),
3848-
bool_f(col('bool')))
3850+
bool_f(col('bool')), array_long_f('array_long'))
38493851
self.assertEquals(df.collect(), res.collect())
38503852

38513853
def test_register_nondeterministic_vectorized_udf_basic(self):
@@ -4050,10 +4052,11 @@ def test_vectorized_udf_chained(self):
40504052
def test_vectorized_udf_wrong_return_type(self):
40514053
from pyspark.sql.functions import pandas_udf, col
40524054
df = self.spark.range(10)
4053-
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
40544055
with QuietTest(self.sc):
4055-
with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
4056-
df.select(f(col('id'))).collect()
4056+
with self.assertRaisesRegexp(
4057+
NotImplementedError,
4058+
'Invalid returnType.*scalar Pandas UDF.*MapType'):
4059+
pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
40574060

40584061
def test_vectorized_udf_return_scalar(self):
40594062
from pyspark.sql.functions import pandas_udf, col
@@ -4088,13 +4091,18 @@ def test_vectorized_udf_varargs(self):
40884091
self.assertEquals(df.collect(), res.collect())
40894092

40904093
def test_vectorized_udf_unsupported_types(self):
4091-
from pyspark.sql.functions import pandas_udf, col
4092-
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
4093-
df = self.spark.createDataFrame([(None,)], schema=schema)
4094-
f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
4094+
from pyspark.sql.functions import pandas_udf
40954095
with QuietTest(self.sc):
4096-
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
4097-
df.select(f(col('map'))).collect()
4096+
with self.assertRaisesRegexp(
4097+
NotImplementedError,
4098+
'Invalid returnType.*scalar Pandas UDF.*MapType'):
4099+
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
4100+
4101+
with QuietTest(self.sc):
4102+
with self.assertRaisesRegexp(
4103+
NotImplementedError,
4104+
'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
4105+
pandas_udf(lambda x: x, BinaryType())
40984106

40994107
def test_vectorized_udf_dates(self):
41004108
from pyspark.sql.functions import pandas_udf, col
@@ -4325,15 +4333,16 @@ def data(self):
43254333
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
43264334
.withColumn("v", explode(col('vs'))).drop('vs')
43274335

4328-
def test_simple(self):
4329-
from pyspark.sql.functions import pandas_udf, PandasUDFType
4330-
df = self.data
4336+
def test_supported_types(self):
4337+
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
4338+
df = self.data.withColumn("arr", array(col("id")))
43314339

43324340
foo_udf = pandas_udf(
43334341
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
43344342
StructType(
43354343
[StructField('id', LongType()),
43364344
StructField('v', IntegerType()),
4345+
StructField('arr', ArrayType(LongType())),
43374346
StructField('v1', DoubleType()),
43384347
StructField('v2', LongType())]),
43394348
PandasUDFType.GROUPED_MAP
@@ -4436,17 +4445,15 @@ def test_datatype_string(self):
44364445

44374446
def test_wrong_return_type(self):
44384447
from pyspark.sql.functions import pandas_udf, PandasUDFType
4439-
df = self.data
4440-
4441-
foo = pandas_udf(
4442-
lambda pdf: pdf,
4443-
'id long, v map<int, int>',
4444-
PandasUDFType.GROUPED_MAP
4445-
)
44464448

44474449
with QuietTest(self.sc):
4448-
with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
4449-
df.groupby('id').apply(foo).sort('id').toPandas()
4450+
with self.assertRaisesRegexp(
4451+
NotImplementedError,
4452+
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
4453+
pandas_udf(
4454+
lambda pdf: pdf,
4455+
'id long, v map<int, int>',
4456+
PandasUDFType.GROUPED_MAP)
44504457

44514458
def test_wrong_args(self):
44524459
from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
@@ -4465,23 +4472,30 @@ def test_wrong_args(self):
44654472
df.groupby('id').apply(
44664473
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
44674474
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
4468-
df.groupby('id').apply(
4469-
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
4475+
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
44704476
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
44714477
df.groupby('id').apply(
4472-
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]),
4473-
PandasUDFType.SCALAR))
4478+
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
44744479

44754480
def test_unsupported_types(self):
4476-
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
4481+
from pyspark.sql.functions import pandas_udf, PandasUDFType
44774482
schema = StructType(
44784483
[StructField("id", LongType(), True),
44794484
StructField("map", MapType(StringType(), IntegerType()), True)])
4480-
df = self.spark.createDataFrame([(1, None,)], schema=schema)
4481-
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP)
44824485
with QuietTest(self.sc):
4483-
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
4484-
df.groupby('id').apply(f).collect()
4486+
with self.assertRaisesRegexp(
4487+
NotImplementedError,
4488+
'Invalid returnType.*grouped map Pandas UDF.*MapType'):
4489+
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
4490+
4491+
schema = StructType(
4492+
[StructField("id", LongType(), True),
4493+
StructField("arr_ts", ArrayType(TimestampType()), True)])
4494+
with QuietTest(self.sc):
4495+
with self.assertRaisesRegexp(
4496+
NotImplementedError,
4497+
'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
4498+
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
44854499

44864500
# Regression test for SPARK-23314
44874501
def test_timestamp_dst(self):

python/pyspark/sql/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,8 @@ def to_arrow_type(dt):
16381638
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
16391639
arrow_type = pa.timestamp('us', tz='UTC')
16401640
elif type(dt) == ArrayType:
1641+
if type(dt.elementType) == TimestampType:
1642+
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
16411643
arrow_type = pa.list_(to_arrow_type(dt.elementType))
16421644
else:
16431645
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -1680,6 +1682,8 @@ def from_arrow_type(at):
16801682
elif types.is_timestamp(at):
16811683
spark_type = TimestampType()
16821684
elif types.is_list(at):
1685+
if types.is_timestamp(at.value_type):
1686+
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
16831687
spark_type = ArrayType(from_arrow_type(at.value_type))
16841688
else:
16851689
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))

python/pyspark/sql/udf.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from pyspark import SparkContext, since
2323
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
2424
from pyspark.sql.column import Column, _to_java_column, _to_seq
25-
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
25+
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
26+
to_arrow_type, to_arrow_schema
2627

2728
__all__ = ["UDFRegistration"]
2829

@@ -109,10 +110,24 @@ def returnType(self):
109110
else:
110111
self._returnType_placeholder = _parse_datatype_string(self._returnType)
111112

112-
if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
113-
and not isinstance(self._returnType_placeholder, StructType):
114-
raise ValueError("Invalid returnType: returnType must be a StructType for "
115-
"pandas_udf with function type GROUPED_MAP")
113+
if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
114+
try:
115+
to_arrow_type(self._returnType_placeholder)
116+
except TypeError:
117+
raise NotImplementedError(
118+
"Invalid returnType with scalar Pandas UDFs: %s is "
119+
"not supported" % str(self._returnType_placeholder))
120+
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
121+
if isinstance(self._returnType_placeholder, StructType):
122+
try:
123+
to_arrow_schema(self._returnType_placeholder)
124+
except TypeError:
125+
raise NotImplementedError(
126+
"Invalid returnType with grouped map Pandas UDFs: "
127+
"%s is not supported" % str(self._returnType_placeholder))
128+
else:
129+
raise TypeError("Invalid returnType for grouped map Pandas "
130+
"UDFs: returnType must be a StructType.")
116131

117132
return self._returnType_placeholder
118133

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ object SQLConf {
10521052
"for use with pyspark.sql.DataFrame.toPandas, and " +
10531053
"pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
10541054
"The following data types are unsupported: " +
1055-
"MapType, ArrayType of TimestampType, and nested StructType.")
1055+
"BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
10561056
.booleanConf
10571057
.createWithDefault(false)
10581058

0 commit comments

Comments
 (0)