@@ -3736,10 +3736,10 @@ def foo(x):
37363736 self .assertEqual (foo .returnType , schema )
37373737 self .assertEqual (foo .evalType , PythonEvalType .SQL_GROUPED_MAP_PANDAS_UDF )
37383738
3739- @pandas_udf (returnType = 'v double' , functionType = PandasUDFType .SCALAR )
3739+ @pandas_udf (returnType = 'double' , functionType = PandasUDFType .SCALAR )
37403740 def foo (x ):
37413741 return x
3742- self .assertEqual (foo .returnType , schema )
3742+ self .assertEqual (foo .returnType , DoubleType () )
37433743 self .assertEqual (foo .evalType , PythonEvalType .SQL_SCALAR_PANDAS_UDF )
37443744
37453745 @pandas_udf (returnType = schema , functionType = PandasUDFType .GROUPED_MAP )
@@ -3776,7 +3776,7 @@ def zero_with_type():
37763776 @pandas_udf (returnType = PandasUDFType .GROUPED_MAP )
37773777 def foo (df ):
37783778 return df
3779- with self .assertRaisesRegexp (ValueError , 'Invalid returnType' ):
3779+ with self .assertRaisesRegexp (TypeError , 'Invalid returnType' ):
37803780 @pandas_udf (returnType = 'double' , functionType = PandasUDFType .GROUPED_MAP )
37813781 def foo (df ):
37823782 return df
@@ -3825,15 +3825,16 @@ def random_udf(v):
38253825 return random_udf
38263826
38273827 def test_vectorized_udf_basic (self ):
3828- from pyspark .sql .functions import pandas_udf , col
3828+ from pyspark .sql .functions import pandas_udf , col , array
38293829 df = self .spark .range (10 ).select (
38303830 col ('id' ).cast ('string' ).alias ('str' ),
38313831 col ('id' ).cast ('int' ).alias ('int' ),
38323832 col ('id' ).alias ('long' ),
38333833 col ('id' ).cast ('float' ).alias ('float' ),
38343834 col ('id' ).cast ('double' ).alias ('double' ),
38353835 col ('id' ).cast ('decimal' ).alias ('decimal' ),
3836- col ('id' ).cast ('boolean' ).alias ('bool' ))
3836+ col ('id' ).cast ('boolean' ).alias ('bool' ),
3837+ array (col ('id' )).alias ('array_long' ))
38373838 f = lambda x : x
38383839 str_f = pandas_udf (f , StringType ())
38393840 int_f = pandas_udf (f , IntegerType ())
@@ -3842,10 +3843,11 @@ def test_vectorized_udf_basic(self):
38423843 double_f = pandas_udf (f , DoubleType ())
38433844 decimal_f = pandas_udf (f , DecimalType ())
38443845 bool_f = pandas_udf (f , BooleanType ())
3846+ array_long_f = pandas_udf (f , ArrayType (LongType ()))
38453847 res = df .select (str_f (col ('str' )), int_f (col ('int' )),
38463848 long_f (col ('long' )), float_f (col ('float' )),
38473849 double_f (col ('double' )), decimal_f ('decimal' ),
3848- bool_f (col ('bool' )))
3850+ bool_f (col ('bool' )), array_long_f ( 'array_long' ) )
38493851 self .assertEquals (df .collect (), res .collect ())
38503852
38513853 def test_register_nondeterministic_vectorized_udf_basic (self ):
@@ -4050,10 +4052,11 @@ def test_vectorized_udf_chained(self):
40504052 def test_vectorized_udf_wrong_return_type (self ):
40514053 from pyspark .sql .functions import pandas_udf , col
40524054 df = self .spark .range (10 )
4053- f = pandas_udf (lambda x : x * 1.0 , MapType (LongType (), LongType ()))
40544055 with QuietTest (self .sc ):
4055- with self .assertRaisesRegexp (Exception , 'Unsupported.*type.*conversion' ):
4056- df .select (f (col ('id' ))).collect ()
4056+ with self .assertRaisesRegexp (
4057+ NotImplementedError ,
4058+ 'Invalid returnType.*scalar Pandas UDF.*MapType' ):
4059+ pandas_udf (lambda x : x * 1.0 , MapType (LongType (), LongType ()))
40574060
40584061 def test_vectorized_udf_return_scalar (self ):
40594062 from pyspark .sql .functions import pandas_udf , col
@@ -4088,13 +4091,18 @@ def test_vectorized_udf_varargs(self):
40884091 self .assertEquals (df .collect (), res .collect ())
40894092
40904093 def test_vectorized_udf_unsupported_types (self ):
4091- from pyspark .sql .functions import pandas_udf , col
4092- schema = StructType ([StructField ("map" , MapType (StringType (), IntegerType ()), True )])
4093- df = self .spark .createDataFrame ([(None ,)], schema = schema )
4094- f = pandas_udf (lambda x : x , MapType (StringType (), IntegerType ()))
4094+ from pyspark .sql .functions import pandas_udf
40954095 with QuietTest (self .sc ):
4096- with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
4097- df .select (f (col ('map' ))).collect ()
4096+ with self .assertRaisesRegexp (
4097+ NotImplementedError ,
4098+ 'Invalid returnType.*scalar Pandas UDF.*MapType' ):
4099+ pandas_udf (lambda x : x , MapType (StringType (), IntegerType ()))
4100+
4101+ with QuietTest (self .sc ):
4102+ with self .assertRaisesRegexp (
4103+ NotImplementedError ,
4104+ 'Invalid returnType.*scalar Pandas UDF.*BinaryType' ):
4105+ pandas_udf (lambda x : x , BinaryType ())
40984106
40994107 def test_vectorized_udf_dates (self ):
41004108 from pyspark .sql .functions import pandas_udf , col
@@ -4325,15 +4333,16 @@ def data(self):
43254333 .withColumn ("vs" , array ([lit (i ) for i in range (20 , 30 )])) \
43264334 .withColumn ("v" , explode (col ('vs' ))).drop ('vs' )
43274335
4328- def test_simple (self ):
4329- from pyspark .sql .functions import pandas_udf , PandasUDFType
4330- df = self .data
4336+ def test_supported_types (self ):
4337+ from pyspark .sql .functions import pandas_udf , PandasUDFType , array , col
4338+ df = self .data . withColumn ( "arr" , array ( col ( "id" )))
43314339
43324340 foo_udf = pandas_udf (
43334341 lambda pdf : pdf .assign (v1 = pdf .v * pdf .id * 1.0 , v2 = pdf .v + pdf .id ),
43344342 StructType (
43354343 [StructField ('id' , LongType ()),
43364344 StructField ('v' , IntegerType ()),
4345+ StructField ('arr' , ArrayType (LongType ())),
43374346 StructField ('v1' , DoubleType ()),
43384347 StructField ('v2' , LongType ())]),
43394348 PandasUDFType .GROUPED_MAP
@@ -4436,17 +4445,15 @@ def test_datatype_string(self):
44364445
44374446 def test_wrong_return_type (self ):
44384447 from pyspark .sql .functions import pandas_udf , PandasUDFType
4439- df = self .data
4440-
4441- foo = pandas_udf (
4442- lambda pdf : pdf ,
4443- 'id long, v map<int, int>' ,
4444- PandasUDFType .GROUPED_MAP
4445- )
44464448
44474449 with QuietTest (self .sc ):
4448- with self .assertRaisesRegexp (Exception , 'Unsupported.*type.*conversion' ):
4449- df .groupby ('id' ).apply (foo ).sort ('id' ).toPandas ()
4450+ with self .assertRaisesRegexp (
4451+ NotImplementedError ,
4452+ 'Invalid returnType.*grouped map Pandas UDF.*MapType' ):
4453+ pandas_udf (
4454+ lambda pdf : pdf ,
4455+ 'id long, v map<int, int>' ,
4456+ PandasUDFType .GROUPED_MAP )
44504457
44514458 def test_wrong_args (self ):
44524459 from pyspark .sql .functions import udf , pandas_udf , sum , PandasUDFType
@@ -4465,23 +4472,30 @@ def test_wrong_args(self):
44654472 df .groupby ('id' ).apply (
44664473 pandas_udf (lambda : 1 , StructType ([StructField ("d" , DoubleType ())])))
44674474 with self .assertRaisesRegexp (ValueError , 'Invalid udf' ):
4468- df .groupby ('id' ).apply (
4469- pandas_udf (lambda x , y : x , StructType ([StructField ("d" , DoubleType ())])))
4475+ df .groupby ('id' ).apply (pandas_udf (lambda x , y : x , DoubleType ()))
44704476 with self .assertRaisesRegexp (ValueError , 'Invalid udf.*GROUPED_MAP' ):
44714477 df .groupby ('id' ).apply (
4472- pandas_udf (lambda x , y : x , StructType ([StructField ("d" , DoubleType ())]),
4473- PandasUDFType .SCALAR ))
4478+ pandas_udf (lambda x , y : x , DoubleType (), PandasUDFType .SCALAR ))
44744479
44754480 def test_unsupported_types (self ):
4476- from pyspark .sql .functions import pandas_udf , col , PandasUDFType
4481+ from pyspark .sql .functions import pandas_udf , PandasUDFType
44774482 schema = StructType (
44784483 [StructField ("id" , LongType (), True ),
44794484 StructField ("map" , MapType (StringType (), IntegerType ()), True )])
4480- df = self .spark .createDataFrame ([(1 , None ,)], schema = schema )
4481- f = pandas_udf (lambda x : x , df .schema , PandasUDFType .GROUPED_MAP )
44824485 with QuietTest (self .sc ):
4483- with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
4484- df .groupby ('id' ).apply (f ).collect ()
4486+ with self .assertRaisesRegexp (
4487+ NotImplementedError ,
4488+ 'Invalid returnType.*grouped map Pandas UDF.*MapType' ):
4489+ pandas_udf (lambda x : x , schema , PandasUDFType .GROUPED_MAP )
4490+
4491+ schema = StructType (
4492+ [StructField ("id" , LongType (), True ),
4493+ StructField ("arr_ts" , ArrayType (TimestampType ()), True )])
4494+ with QuietTest (self .sc ):
4495+ with self .assertRaisesRegexp (
4496+ NotImplementedError ,
4497+ 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType' ):
4498+ pandas_udf (lambda x : x , schema , PandasUDFType .GROUPED_MAP )
44854499
44864500 # Regression test for SPARK-23314
44874501 def test_timestamp_dst (self ):
0 commit comments