@@ -3127,9 +3127,9 @@ def setUpClass(cls):
31273127 StructField ("5_double_t" , DoubleType (), True ),
31283128 StructField ("6_date_t" , DateType (), True ),
31293129 StructField ("7_timestamp_t" , TimestampType (), True )])
3130- cls .data = [("a" , 1 , 10 , 0.2 , 2.0 , datetime (1969 , 1 , 1 ), datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3131- ("b" , 2 , 20 , 0.4 , 4.0 , datetime (2012 , 2 , 2 ), datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3132- ("c" , 3 , 30 , 0.8 , 6.0 , datetime (2100 , 3 , 3 ), datetime (2100 , 3 , 3 , 3 , 3 , 3 ))]
3130+ cls .data = [(u "a" , 1 , 10 , 0.2 , 2.0 , datetime (1969 , 1 , 1 ), datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3131+ (u "b" , 2 , 20 , 0.4 , 4.0 , datetime (2012 , 2 , 2 ), datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3132+ (u "c" , 3 , 30 , 0.8 , 6.0 , datetime (2100 , 3 , 3 ), datetime (2100 , 3 , 3 , 3 , 3 , 3 ))]
31333133
31343134 @classmethod
31353135 def tearDownClass (cls ):
@@ -3145,6 +3145,17 @@ def assertFramesEqual(self, df_with_arrow, df_without):
31453145 ("\n \n Without:\n %s\n %s" % (df_without , df_without .dtypes )))
31463146 self .assertTrue (df_without .equals (df_with_arrow ), msg = msg )
31473147
3148+ def create_pandas_data_frame (self ):
3149+ import pandas as pd
3150+ import numpy as np
3151+ data_dict = {}
3152+ for j , name in enumerate (self .schema .names ):
3153+ data_dict [name ] = [self .data [i ][j ] for i in range (len (self .data ))]
3154+ # need to convert these to numpy types first
3155+ data_dict ["2_int_t" ] = np .int32 (data_dict ["2_int_t" ])
3156+ data_dict ["4_float_t" ] = np .float32 (data_dict ["4_float_t" ])
3157+ return pd .DataFrame (data = data_dict )
3158+
31483159 def test_unsupported_datatype (self ):
31493160 schema = StructType ([StructField ("decimal" , DecimalType (), True )])
31503161 df = self .spark .createDataFrame ([(None ,)], schema = schema )
@@ -3161,21 +3172,15 @@ def test_null_conversion(self):
31613172 def test_toPandas_arrow_toggle (self ):
31623173 df = self .spark .createDataFrame (self .data , schema = self .schema )
31633174 self .spark .conf .set ("spark.sql.execution.arrow.enabled" , "false" )
3164- pdf = df .toPandas ()
3165- self .spark .conf .set ("spark.sql.execution.arrow.enabled" , "true" )
3175+ try :
3176+ pdf = df .toPandas ()
3177+ finally :
3178+ self .spark .conf .set ("spark.sql.execution.arrow.enabled" , "true" )
31663179 pdf_arrow = df .toPandas ()
31673180 self .assertFramesEqual (pdf_arrow , pdf )
31683181
31693182 def test_pandas_round_trip (self ):
3170- import pandas as pd
3171- import numpy as np
3172- data_dict = {}
3173- for j , name in enumerate (self .schema .names ):
3174- data_dict [name ] = [self .data [i ][j ] for i in range (len (self .data ))]
3175- # need to convert these to numpy types first
3176- data_dict ["2_int_t" ] = np .int32 (data_dict ["2_int_t" ])
3177- data_dict ["4_float_t" ] = np .float32 (data_dict ["4_float_t" ])
3178- pdf = pd .DataFrame (data = data_dict )
3183+ pdf = self .create_pandas_data_frame ()
31793184 df = self .spark .createDataFrame (self .data , schema = self .schema )
31803185 pdf_arrow = df .toPandas ()
31813186 self .assertFramesEqual (pdf_arrow , pdf )
@@ -3187,6 +3192,62 @@ def test_filtered_frame(self):
31873192 self .assertEqual (pdf .columns [0 ], "i" )
31883193 self .assertTrue (pdf .empty )
31893194
3195+ def test_createDataFrame_toggle (self ):
3196+ pdf = self .create_pandas_data_frame ()
3197+ self .spark .conf .set ("spark.sql.execution.arrow.enabled" , "false" )
3198+ try :
3199+ df_no_arrow = self .spark .createDataFrame (pdf )
3200+ finally :
3201+ self .spark .conf .set ("spark.sql.execution.arrow.enabled" , "true" )
3202+ df_arrow = self .spark .createDataFrame (pdf )
3203+ self .assertEquals (df_no_arrow .collect (), df_arrow .collect ())
3204+
3205+ def test_createDataFrame_with_schema (self ):
3206+ pdf = self .create_pandas_data_frame ()
3207+ df = self .spark .createDataFrame (pdf , schema = self .schema )
3208+ self .assertEquals (self .schema , df .schema )
3209+ pdf_arrow = df .toPandas ()
3210+ self .assertFramesEqual (pdf_arrow , pdf )
3211+
3212+ def test_createDataFrame_with_incorrect_schema (self ):
3213+ pdf = self .create_pandas_data_frame ()
3214+ wrong_schema = StructType (list (reversed (self .schema )))
3215+ with QuietTest (self .sc ):
3216+ with self .assertRaisesRegexp (TypeError , ".*field.*can.not.accept.*type" ):
3217+ self .spark .createDataFrame (pdf , schema = wrong_schema )
3218+
3219+ def test_createDataFrame_with_names (self ):
3220+ pdf = self .create_pandas_data_frame ()
3221+ # Test that schema as a list of column names gets applied
3222+ df = self .spark .createDataFrame (pdf , schema = list ('abcdefg' ))
3223+ self .assertEquals (df .schema .fieldNames (), list ('abcdefg' ))
3224+ # Test that schema as tuple of column names gets applied
3225+ df = self .spark .createDataFrame (pdf , schema = tuple ('abcdefg' ))
3226+ self .assertEquals (df .schema .fieldNames (), list ('abcdefg' ))
3227+
3228+ def test_createDataFrame_with_single_data_type (self ):
3229+ import pandas as pd
3230+ with QuietTest (self .sc ):
3231+ with self .assertRaisesRegexp (TypeError , ".*IntegerType.*tuple" ):
3232+ self .spark .createDataFrame (pd .DataFrame ({"a" : [1 ]}), schema = "int" )
3233+
3234+ def test_createDataFrame_does_not_modify_input (self ):
3235+ # Some series get converted for Spark to consume, this makes sure input is unchanged
3236+ pdf = self .create_pandas_data_frame ()
3237+ # Use a nanosecond value to make sure it is not truncated
3238+ pdf .ix [0 , '7_timestamp_t' ] = 1
3239+ # Integers with nulls will get NaNs filled with 0 and will be casted
3240+ pdf .ix [1 , '2_int_t' ] = None
3241+ pdf_copy = pdf .copy (deep = True )
3242+ self .spark .createDataFrame (pdf , schema = self .schema )
3243+ self .assertTrue (pdf .equals (pdf_copy ))
3244+
3245+ def test_schema_conversion_roundtrip (self ):
3246+ from pyspark .sql .types import from_arrow_schema , to_arrow_schema
3247+ arrow_schema = to_arrow_schema (self .schema )
3248+ schema_rt = from_arrow_schema (arrow_schema )
3249+ self .assertEquals (self .schema , schema_rt )
3250+
31903251
31913252@unittest .skipIf (not _have_pandas or not _have_arrow , "Pandas or Arrow not installed" )
31923253class VectorizedUDFTests (ReusedSQLTestCase ):
0 commit comments