@@ -3372,6 +3372,31 @@ def test_schema_conversion_roundtrip(self):
         schema_rt = from_arrow_schema(arrow_schema)
         self.assertEquals(self.schema, schema_rt)
 
+    def test_createDataFrame_with_array_type(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        result = df.collect()
+        result_arrow = df_arrow.collect()
+        expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
+    def test_toPandas_with_array_type(self):
+        expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
+        array_schema = StructType([StructField("a", ArrayType(IntegerType())),
+                                   StructField("b", ArrayType(StringType()))])
+        df = self.spark.createDataFrame(expected, schema=array_schema)
+        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class PandasUDFTests(ReusedSQLTestCase):
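For readers outside the test suite, here is a minimal standalone sketch of the behavior the two tests above cover: converting array columns between pandas and Spark with the Arrow path switched off and on. The `_createDataFrame_toggle` and `_toPandas_arrow_toggle` helpers belong to the test class; the explicit toggling of `spark.sql.execution.arrow.enabled` below is an assumption about what they do, and the session setup is illustrative only.

```python
# Standalone sketch (not part of the patch) of Arrow-backed conversion of array columns.
# Assumes Spark 2.3+ with pandas and pyarrow installed; names and setup are illustrative.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("arrow-array-demo").getOrCreate()
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})

# createDataFrame with the Arrow path disabled, then enabled; results should match.
spark.conf.set("spark.sql.execution.arrow.enabled", "false")
df_no_arrow = spark.createDataFrame(pdf)
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
df_arrow = spark.createDataFrame(pdf)
assert df_no_arrow.collect() == df_arrow.collect()

# toPandas() round trip with Arrow enabled: array columns come back as list-valued cells.
pdf_back = df_arrow.toPandas()
print(pdf_back)
```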
@@ -3651,6 +3676,24 @@ def test_vectorized_udf_datatype_string(self):
                         bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_array_type(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [([1, 2],), ([3, 4],)]
+        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+        df = self.spark.createDataFrame(data, schema=array_schema)
+        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+        result = df.select(array_f(col('array')))
+        self.assertEquals(df.collect(), result.collect())
+
+    def test_vectorized_udf_null_array(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
+        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+        df = self.spark.createDataFrame(data, schema=array_schema)
+        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+        result = df.select(array_f(col('array')))
+        self.assertEquals(df.collect(), result.collect())
+
     def test_vectorized_udf_complex(self):
         from pyspark.sql.functions import pandas_udf, col, expr
         df = self.spark.range(10).select(
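As a companion to the vectorized-UDF tests above, a minimal standalone sketch of a scalar `pandas_udf` over an `ArrayType` column: each element of the pandas Series handed to the UDF is a list (or None for null rows). The doubling logic, null handling, and session setup are illustrative assumptions, not part of the patch.

```python
# Standalone sketch (not part of the patch): scalar pandas_udf with ArrayType input/output.
# Assumes Spark 2.3+ with pandas and pyarrow installed.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField

spark = SparkSession.builder.master("local[1]").appName("arrow-array-udf-demo").getOrCreate()
schema = StructType([StructField("array", ArrayType(IntegerType()))])
df = spark.createDataFrame([([1, 2],), (None,), ([3, 4],)], schema=schema)

# Double every element of each array; null rows are passed through unchanged.
double_arr = pandas_udf(
    lambda s: s.apply(lambda arr: None if arr is None else [v * 2 for v in arr]),
    ArrayType(IntegerType()))

df.select(double_arr(col("array")).alias("doubled")).show()
```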
@@ -3705,7 +3748,7 @@ def test_vectorized_udf_chained(self):
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
-        f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType()))
+        f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
                 df.select(f(col('id'))).collect()
@@ -4009,7 +4052,7 @@ def test_wrong_return_type(self):
 
         foo = pandas_udf(
             lambda pdf: pdf,
-            'id long, v array<int>',
+            'id long, v map<int, int>',
             PandasUDFType.GROUP_MAP
         )
 