@@ -4637,44 +4637,6 @@ def foofoo(x, y):
46374637 ).collect
46384638 )
46394639
4640- def test_pandas_udf_when_input_has_none (self ):
4641- import math
4642- from pyspark .sql .functions import pandas_udf
4643- import pandas as pd
4644-
4645- values = [1.0 ] * 10 + [None ] * 10 + [2.0 ] * 10
4646- pdf = pd .DataFrame ({'A' : values })
4647- df = self .spark .createDataFrame (pdf ).repartition (1 )
4648-
4649- @pandas_udf (returnType = DoubleType ())
4650- def gt_2_double (column ):
4651- return (column >= 2 ).where (column .notnull ())
4652-
4653- # This pandas udf returns Pandas.Series of dtype as float64.
4654- # If we define the pandas udf with incorrect data type BooleanType,
4655- # we should see an exception.
4656- @pandas_udf (returnType = BooleanType ())
4657- def gt_2_boolean (column ):
4658- return (column >= 2 ).where (column .notnull ())
4659-
4660- udf_double = df .select (['A' ]).withColumn ('udf' , gt_2_double ('A' ))
4661- udf_boolean = df .select (['A' ]).withColumn ('udf' , gt_2_boolean ('A' ))
4662-
4663- result = udf_double .collect ()
4664- result_part1 = [x [1 ] for x in result if x [0 ] == 1.0 ]
4665- self .assertEqual (set (result_part1 ), set ([0.0 ]))
4666- result_part2 = [x [1 ] for x in result if x [0 ] == 2.0 ]
4667- self .assertEqual (set (result_part2 ), set ([1.0 ]))
4668- result_part3 = [x [1 ] for x in result if math .isnan (x [0 ])]
4669- self .assertEqual (set (result_part3 ), set ([None ]))
4670-
4671- with QuietTest (self .sc ):
4672- with self .assertRaisesRegexp (Exception , "Return Pandas.Series of the user-defined " +
4673- "function's dtype is float64 which doesn't " +
4674- "match the arrow type bool of defined type " +
4675- "BooleanType" ):
4676- udf_boolean .collect ()
4677-
46784640
46794641@unittest .skipIf (
46804642 not _have_pandas or not _have_pyarrow ,
0 commit comments