@@ -3086,18 +3086,38 @@ class ArrowTests(ReusedPySparkTestCase):
30863086
30873087 @classmethod
30883088 def setUpClass (cls ):
3089+ from datetime import datetime
30893090 ReusedPySparkTestCase .setUpClass ()
3091+
3092+ # Synchronize default timezone between Python and Java
3093+ cls .tz_prev = os .environ .get ("TZ" , None ) # save current tz if set
3094+ tz = "America/Los_Angeles"
3095+ os .environ ["TZ" ] = tz
3096+ time .tzset ()
3097+
30903098 cls .spark = SparkSession (cls .sc )
3099+ cls .spark .conf .set ("spark.sql.session.timeZone" , tz )
30913100 cls .spark .conf .set ("spark.sql.execution.arrow.enabled" , "true" )
30923101 cls .schema = StructType ([
30933102 StructField ("1_str_t" , StringType (), True ),
30943103 StructField ("2_int_t" , IntegerType (), True ),
30953104 StructField ("3_long_t" , LongType (), True ),
30963105 StructField ("4_float_t" , FloatType (), True ),
3097- StructField ("5_double_t" , DoubleType (), True )])
3098- cls .data = [("a" , 1 , 10 , 0.2 , 2.0 ),
3099- ("b" , 2 , 20 , 0.4 , 4.0 ),
3100- ("c" , 3 , 30 , 0.8 , 6.0 )]
3106+ StructField ("5_double_t" , DoubleType (), True ),
3107+ StructField ("6_date_t" , DateType (), True ),
3108+ StructField ("7_timestamp_t" , TimestampType (), True )])
3109+ cls .data = [("a" , 1 , 10 , 0.2 , 2.0 , datetime (1969 , 1 , 1 ), datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3110+ ("b" , 2 , 20 , 0.4 , 4.0 , datetime (2012 , 2 , 2 ), datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3111+ ("c" , 3 , 30 , 0.8 , 6.0 , datetime (2100 , 3 , 3 ), datetime (2100 , 3 , 3 , 3 , 3 , 3 ))]
3112+
3113+ @classmethod
3114+ def tearDownClass (cls ):
3115+ del os .environ ["TZ" ]
3116+ if cls .tz_prev is not None :
3117+ os .environ ["TZ" ] = cls .tz_prev
3118+ time .tzset ()
3119+ ReusedPySparkTestCase .tearDownClass ()
3120+ cls .spark .stop ()
31013121
31023122 def assertFramesEqual (self , df_with_arrow , df_without ):
31033123 msg = ("DataFrame from Arrow is not equal" +
@@ -3106,8 +3126,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
31063126 self .assertTrue (df_without .equals (df_with_arrow ), msg = msg )
31073127
31083128 def test_unsupported_datatype (self ):
3109- schema = StructType ([StructField ("dt " , DateType (), True )])
3110- df = self .spark .createDataFrame ([(datetime . date ( 1970 , 1 , 1 ) ,)], schema = schema )
3129+ schema = StructType ([StructField ("decimal " , DecimalType (), True )])
3130+ df = self .spark .createDataFrame ([(None ,)], schema = schema )
31113131 with QuietTest (self .sc ):
31123132 self .assertRaises (Exception , lambda : df .toPandas ())
31133133
@@ -3385,13 +3405,77 @@ def test_vectorized_udf_varargs(self):
33853405
33863406 def test_vectorized_udf_unsupported_types (self ):
33873407 from pyspark .sql .functions import pandas_udf , col
3388- schema = StructType ([StructField ("dt" , DateType (), True )])
3389- df = self .spark .createDataFrame ([(datetime . date ( 1970 , 1 , 1 ) ,)], schema = schema )
3390- f = pandas_udf (lambda x : x , DateType ())
3408+ schema = StructType ([StructField ("dt" , DecimalType (), True )])
3409+ df = self .spark .createDataFrame ([(None ,)], schema = schema )
3410+ f = pandas_udf (lambda x : x , DecimalType ())
33913411 with QuietTest (self .sc ):
33923412 with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
33933413 df .select (f (col ('dt' ))).collect ()
33943414
3415+ def test_vectorized_udf_null_date (self ):
3416+ from pyspark .sql .functions import pandas_udf , col
3417+ from datetime import date
3418+ schema = StructType ().add ("date" , DateType ())
3419+ data = [(date (1969 , 1 , 1 ),),
3420+ (date (2012 , 2 , 2 ),),
3421+ (None ,),
3422+ (date (2100 , 4 , 4 ),)]
3423+ df = self .spark .createDataFrame (data , schema = schema )
3424+ date_f = pandas_udf (lambda t : t , returnType = DateType ())
3425+ res = df .select (date_f (col ("date" )))
3426+ self .assertEquals (df .collect (), res .collect ())
3427+
3428+ def test_vectorized_udf_timestamps (self ):
3429+ from pyspark .sql .functions import pandas_udf , col
3430+ from datetime import datetime
3431+ schema = StructType ([
3432+ StructField ("idx" , LongType (), True ),
3433+ StructField ("timestamp" , TimestampType (), True )])
3434+ data = [(0 , datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3435+ (1 , datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3436+ (2 , None ),
3437+ (3 , datetime (2100 , 4 , 4 , 4 , 4 , 4 ))]
3438+ df = self .spark .createDataFrame (data , schema = schema )
3439+
3440+ # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
3441+ f_timestamp_copy = pandas_udf (lambda t : t , returnType = TimestampType ())
3442+ df = df .withColumn ("timestamp_copy" , f_timestamp_copy (col ("timestamp" )))
3443+
3444+ @pandas_udf (returnType = BooleanType ())
3445+ def check_data (idx , timestamp , timestamp_copy ):
3446+ is_equal = timestamp .isnull () # use this array to check values are equal
3447+ for i in range (len (idx )):
3448+ # Check that timestamps are as expected in the UDF
3449+ is_equal [i ] = (is_equal [i ] and data [idx [i ]][1 ] is None ) or \
3450+ timestamp [i ].to_pydatetime () == data [idx [i ]][1 ]
3451+ return is_equal
3452+
3453+ result = df .withColumn ("is_equal" , check_data (col ("idx" ), col ("timestamp" ),
3454+ col ("timestamp_copy" ))).collect ()
3455+ # Check that collection values are correct
3456+ self .assertEquals (len (data ), len (result ))
3457+ for i in range (len (result )):
3458+ self .assertEquals (data [i ][1 ], result [i ][1 ]) # "timestamp" col
3459+ self .assertTrue (result [i ][3 ]) # "is_equal" data in udf was as expected
3460+
3461+ def test_vectorized_udf_return_timestamp_tz (self ):
3462+ from pyspark .sql .functions import pandas_udf , col
3463+ import pandas as pd
3464+ df = self .spark .range (10 )
3465+
3466+ @pandas_udf (returnType = TimestampType ())
3467+ def gen_timestamps (id ):
3468+ ts = [pd .Timestamp (i , unit = 'D' , tz = 'America/Los_Angeles' ) for i in id ]
3469+ return pd .Series (ts )
3470+
3471+ result = df .withColumn ("ts" , gen_timestamps (col ("id" ))).collect ()
3472+ spark_ts_t = TimestampType ()
3473+ for r in result :
3474+ i , ts = r
3475+ ts_tz = pd .Timestamp (i , unit = 'D' , tz = 'America/Los_Angeles' ).to_pydatetime ()
3476+ expected = spark_ts_t .fromInternal (spark_ts_t .toInternal (ts_tz ))
3477+ self .assertEquals (expected , ts )
3478+
33953479
33963480@unittest .skipIf (not _have_pandas or not _have_arrow , "Pandas or Arrow not installed" )
33973481class GroupbyApplyTests (ReusedPySparkTestCase ):
@@ -3550,8 +3634,8 @@ def test_wrong_args(self):
35503634 def test_unsupported_types (self ):
35513635 from pyspark .sql .functions import pandas_udf , col
35523636 schema = StructType (
3553- [StructField ("id" , LongType (), True ), StructField ("dt" , DateType (), True )])
3554- df = self .spark .createDataFrame ([(1 , datetime . date ( 1970 , 1 , 1 ) ,)], schema = schema )
3637+ [StructField ("id" , LongType (), True ), StructField ("dt" , DecimalType (), True )])
3638+ df = self .spark .createDataFrame ([(1 , None ,)], schema = schema )
35553639 f = pandas_udf (lambda x : x , df .schema )
35563640 with QuietTest (self .sc ):
35573641 with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
0 commit comments