Commit 07bccca

Add support for dtypes as returnType
1 parent f109afb · commit 07bccca

4 files changed: +63 -7 lines changed


python/pyspark/sql/functions.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -28,7 +28,7 @@
 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType, DataType, _parse_datatype_string
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string, from_pandas_dtypes
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.dataframe import DataFrame
@@ -2207,6 +2207,10 @@ def pandas_udf(f=None, returnType=StringType()):
     | 8| JOHN DOE| 22|
     +----------+--------------+------------+
     """
+    import pandas as pd
+    if isinstance(returnType, pd.Series):
+        returnType = from_pandas_dtypes(returnType)
+
     return _create_udf(f, returnType=returnType, vectorized=True)
```
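What this buys callers: `pandas_udf` now also accepts a pandas `dtypes` Series (e.g. the result of `some_pdf.dtypes`) as `returnType`, converting it to a Spark schema via `from_pandas_dtypes`. A minimal sketch of the pattern, assuming an active SparkSession named `spark` and toy data that are not part of this commit:

```python
# Illustrative sketch only: `spark`, `double_v`, and the sample data are
# assumptions; the commit itself only changes how returnType is interpreted.
from pyspark.sql.functions import pandas_udf, col

def double_v(pdf):
    # plain pandas transform, applied per group by groupby().apply()
    return pdf.assign(v2=pdf.v * 2.0)

df = spark.range(10).withColumn('v', col('id') * 3)
sample = df.limit(5).toPandas()        # small driver-side sample

# New in this commit: pass a dtypes Series instead of a Spark DataType
double_udf = pandas_udf(double_v, double_v(sample).dtypes)
result = df.groupby('id').apply(double_udf)
```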

python/pyspark/sql/tests.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -3395,9 +3395,16 @@ def assertFramesEqual(self, expected, result):
                 ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
         self.assertTrue(expected.equals(result), msg=msg)
 
-    def test_groupby_apply(self):
+    @property
+    def data(self):
         from pyspark.sql.functions import pandas_udf, array, explode, col, lit
-        df = self.spark.range(10).toDF('id').withColumn("vs", array([lit(i) for i in range(20, 30)])).withColumn("v", explode(col('vs'))).drop('vs')
+        return self.spark.range(10).toDF('id') \
+            .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
+            .withColumn("v", explode(col('vs'))).drop('vs')
+
+    def test_groupby_apply_simple(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
 
         def foo(df):
             ret = df
@@ -3417,6 +3424,26 @@ def foo(df):
         expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True)
         self.assertFramesEqual(expected, result)
 
+    def test_groupby_apply_dtypes(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        def foo(df):
+            ret = df
+            ret = ret.assign(v3=df.v * 5.0 + 1)
+            return ret
+
+        sample_df = df.filter(df.id == 1).toPandas()
+
+        foo_udf = pandas_udf(
+            foo,
+            foo(sample_df).dtypes
+        )
+
+        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(foo).reset_index(drop=True)
+        self.assertFramesEqual(expected, result)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *
```
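The new test relies on a pandas detail worth making explicit: applying the UDF's body to a small driver-side sample yields a DataFrame whose `.dtypes` Series carries both the column names (as the index) and the numpy dtypes (as the values), which is exactly the shape `from_pandas_dtypes` consumes. A pandas-only sketch with assumed toy data:

```python
import pandas as pd

# Stand-in for one group of the test DataFrame (values are assumptions)
sample_df = pd.DataFrame({'id': [1], 'v': [20]})
out = sample_df.assign(v3=sample_df.v * 5.0 + 1)   # mirrors foo() above
print(out.dtypes)
# id      int64
# v       int64
# v3    float64
# dtype: object
```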

python/pyspark/sql/types.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client):
 register_input_converter(DateConverter())
 
 
-def toArrowType(dt):
+def to_arrow_type(dt):
     """ Convert Spark data type to pyarrow type
     """
     import pyarrow as pa
@@ -1623,6 +1623,31 @@ def toArrowType(dt):
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
 
+def from_pandas_type(dt):
+    """ Convert pandas data type to Spark data type
+    """
+    import pandas as pd
+    import numpy as np
+    if dt == np.int32:
+        return IntegerType()
+    elif dt == np.int64:
+        return LongType()
+    elif dt == np.float32:
+        return FloatType()
+    elif dt == np.float64:
+        return DoubleType()
+    elif dt == np.object:
+        return StringType()
+    elif dt == np.dtype('datetime64[ns]') or type(dt) == pd.api.types.DatetimeTZDtype:
+        return TimestampType()
+    else:
+        raise ValueError("Unsupported numpy type in conversion to Spark: {}".format(dt))
+
+def from_pandas_dtypes(dtypes):
+    """ Convert pandas DataFrame dtypes to Spark schema
+    """
+    return StructType([StructField(dtypes.axes[0][i], from_pandas_type(dtypes[i]))
+                       for i in range(len(dtypes))])
 
 
 def _test():
     import doctest
```
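To make the mapping concrete, here is a hedged illustration of the two new helpers; the sample DataFrame is an assumption, not part of the commit:

```python
import pandas as pd
from pyspark.sql.types import from_pandas_dtypes

pdf = pd.DataFrame({
    'name': ['a', 'b'],                                 # object  -> StringType
    'score': [0.5, 1.5],                                # float64 -> DoubleType
    'ts': pd.to_datetime(['2017-01-01', '2017-01-02'])  # datetime64[ns] -> TimestampType
})
schema = from_pandas_dtypes(pdf.dtypes)
# StructField names come from dtypes.axes[0] (the column index);
# every field gets the StructField default of nullable=True.
```

Note that the `np.object` branch is a heuristic: pandas reports any mixed or non-numeric column as `object`, so non-string payloads would be labeled StringType by this mapping.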

python/pyspark/worker.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -32,7 +32,7 @@
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import toArrowType
+from pyspark.sql.types import to_arrow_type
 from pyspark import shuffle
 from pyspark.sql.types import StructType, IntegerType, LongType, FloatType, DoubleType
@@ -76,7 +76,7 @@ def wrap_udf(f, return_type):
 
 def wrap_pandas_udf(f, return_type):
     if isinstance(return_type, StructType):
-        arrow_return_types = list(toArrowType(field.dataType) for field in return_type)
+        arrow_return_types = list(to_arrow_type(field.dataType) for field in return_type)
 
         def fn(*a):
             import pandas as pd
@@ -89,7 +89,7 @@ def fn(*a):
         return fn
 
     else:
-        arrow_return_type = toArrowType(return_type)
+        arrow_return_type = to_arrow_type(return_type)
 
         def verify_result_length(*a):
             result = f(*a)
```
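The worker-side change is purely the snake_case rename; the Arrow conversion behavior is untouched. A quick sanity check of the renamed helper, assuming pyarrow is installed alongside this build:

```python
import pyarrow as pa
from pyspark.sql.types import to_arrow_type, LongType, DoubleType

# pyarrow DataType supports equality, so these hold if the mapping is intact
assert to_arrow_type(LongType()) == pa.int64()
assert to_arrow_type(DoubleType()) == pa.float64()
```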
