Skip to content

Commit 1c9f95c

Browse files
BryanCutler authored and HyukjinKwon committed
[SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType
## What changes were proposed in this pull request?

This change adds `ArrayType` support for working with Arrow in pyspark when creating a DataFrame, calling `toPandas()`, and using vectorized `pandas_udf`.

## How was this patch tested?

Added new Python unit tests using Array data.

Author: Bryan Cutler <[email protected]>

Closes #20114 from BryanCutler/arrow-ArrayType-support-SPARK-22530.
1 parent c284c4e commit 1c9f95c

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

python/pyspark/sql/tests.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3372,6 +3372,31 @@ def test_schema_conversion_roundtrip(self):
33723372
schema_rt = from_arrow_schema(arrow_schema)
33733373
self.assertEquals(self.schema, schema_rt)
33743374

3375+
def test_createDataFrame_with_array_type(self):
    """Verify createDataFrame gives the same rows for array columns in both toggled modes.

    Builds a pandas DataFrame with an integer-array column and a string-array
    column, converts it through ``self._createDataFrame_toggle`` (presumably
    once without and once with Arrow enabled -- confirm against the helper),
    and checks that both collected results agree with each other and with the
    values read back from the pandas frame.
    """
    import pandas as pd
    pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
    df, df_arrow = self._createDataFrame_toggle(pdf)
    result = df.collect()
    result_arrow = df_arrow.collect()
    expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
    # Check the row counts up front: the positional comparison below would
    # otherwise ignore any surplus rows in either result.
    self.assertEqual(len(expected), len(result))
    self.assertEqual(len(expected), len(result_arrow))
    for expected_row, row, row_arrow in zip(expected, result, result_arrow):
        # Same reasoning per row: make sure no column is silently skipped.
        self.assertEqual(len(expected_row), len(row))
        self.assertEqual(len(expected_row), len(row_arrow))
        for expected_val, val, val_arrow in zip(expected_row, row, row_arrow):
            # Separate assertEqual calls report the differing values on failure.
            self.assertEqual(expected_val, val_arrow)
            self.assertEqual(val, val_arrow)
3386+
3387+
def test_toPandas_with_array_type(self):
    """Verify toPandas gives the same values for array columns in both toggled modes.

    Creates a Spark DataFrame with an integer-array column and a string-array
    column, converts it through ``self._toPandas_arrow_toggle`` (presumably
    once without and once with Arrow enabled -- confirm against the helper),
    and checks that both pandas results agree with each other and with the
    input rows.
    """
    expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
    array_schema = StructType([StructField("a", ArrayType(IntegerType())),
                               StructField("b", ArrayType(StringType()))])
    df = self.spark.createDataFrame(expected, schema=array_schema)
    pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
    # Normalise every cell to a plain list so the two results compare element-wise.
    result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
    result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
    # Check the row counts up front: the positional comparison below would
    # otherwise ignore any surplus rows in either result.
    self.assertEqual(len(expected), len(result))
    self.assertEqual(len(expected), len(result_arrow))
    for expected_row, row, row_arrow in zip(expected, result, result_arrow):
        # Same reasoning per row: make sure no column is silently skipped.
        self.assertEqual(len(expected_row), len(row))
        self.assertEqual(len(expected_row), len(row_arrow))
        for expected_val, val, val_arrow in zip(expected_row, row, row_arrow):
            # Separate assertEqual calls report the differing values on failure.
            self.assertEqual(expected_val, val_arrow)
            self.assertEqual(val, val_arrow)
3399+
33753400

33763401
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
33773402
class PandasUDFTests(ReusedSQLTestCase):
@@ -3651,6 +3676,24 @@ def test_vectorized_udf_datatype_string(self):
36513676
bool_f(col('bool')))
36523677
self.assertEquals(df.collect(), res.collect())
36533678

3679+
def test_vectorized_udf_array_type(self):
    """Pass a non-null integer-array column through an identity pandas_udf.

    The collected output of the UDF projection must equal the collected input.
    """
    from pyspark.sql.functions import pandas_udf, col
    rows = [([1, 2],), ([3, 4],)]
    schema = StructType([StructField("array", ArrayType(IntegerType()))])
    source_df = self.spark.createDataFrame(rows, schema=schema)
    # Identity UDF declared with the same array<int> return type as the input column.
    identity = pandas_udf(lambda x: x, ArrayType(IntegerType()))
    projected_df = source_df.select(identity(col('array')))
    expected_rows = source_df.collect()
    actual_rows = projected_df.collect()
    self.assertEquals(expected_rows, actual_rows)
3687+
3688+
def test_vectorized_udf_null_array(self):
    """Pass an integer-array column containing None rows through an identity pandas_udf.

    The collected output of the UDF projection must equal the collected input,
    including the rows whose array value is None.
    """
    from pyspark.sql.functions import pandas_udf, col
    # None entries sit between and after the populated arrays.
    rows = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
    schema = StructType([StructField("array", ArrayType(IntegerType()))])
    source_df = self.spark.createDataFrame(rows, schema=schema)
    identity = pandas_udf(lambda x: x, ArrayType(IntegerType()))
    projected_df = source_df.select(identity(col('array')))
    expected_rows = source_df.collect()
    actual_rows = projected_df.collect()
    self.assertEquals(expected_rows, actual_rows)
3696+
36543697
def test_vectorized_udf_complex(self):
36553698
from pyspark.sql.functions import pandas_udf, col, expr
36563699
df = self.spark.range(10).select(
@@ -3705,7 +3748,7 @@ def test_vectorized_udf_chained(self):
37053748
def test_vectorized_udf_wrong_return_type(self):
37063749
from pyspark.sql.functions import pandas_udf, col
37073750
df = self.spark.range(10)
3708-
f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType()))
3751+
f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
37093752
with QuietTest(self.sc):
37103753
with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
37113754
df.select(f(col('id'))).collect()
@@ -4009,7 +4052,7 @@ def test_wrong_return_type(self):
40094052

40104053
foo = pandas_udf(
40114054
lambda pdf: pdf,
4012-
'id long, v array<int>',
4055+
'id long, v map<int, int>',
40134056
PandasUDFType.GROUP_MAP
40144057
)
40154058

python/pyspark/sql/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,8 @@ def to_arrow_type(dt):
16251625
elif type(dt) == TimestampType:
16261626
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
16271627
arrow_type = pa.timestamp('us', tz='UTC')
1628+
elif type(dt) == ArrayType:
1629+
arrow_type = pa.list_(to_arrow_type(dt.elementType))
16281630
else:
16291631
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
16301632
return arrow_type
@@ -1665,6 +1667,8 @@ def from_arrow_type(at):
16651667
spark_type = DateType()
16661668
elif types.is_timestamp(at):
16671669
spark_type = TimestampType()
1670+
elif types.is_list(at):
1671+
spark_type = ArrayType(from_arrow_type(at.value_type))
16681672
else:
16691673
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
16701674
return spark_type

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ private abstract static class ArrowVectorAccessor {
326326
this.vector = vector;
327327
}
328328

329-
final boolean isNullAt(int rowId) {
329+
// TODO: should be final after removing ArrayAccessor workaround
330+
boolean isNullAt(int rowId) {
330331
return vector.isNull(rowId);
331332
}
332333

@@ -589,6 +590,16 @@ private static class ArrayAccessor extends ArrowVectorAccessor {
589590
this.accessor = vector;
590591
}
591592

593+
@Override
594+
final boolean isNullAt(int rowId) {
595+
// TODO: Workaround if vector has all non-null values, see ARROW-1948
596+
if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) {
597+
return false;
598+
} else {
599+
return super.isNullAt(rowId);
600+
}
601+
}
602+
592603
@Override
593604
final int getArrayLength(int rowId) {
594605
return accessor.getInnerValueCountAt(rowId);

0 commit comments

Comments
 (0)