
Commit 209b936

BryanCutler authored and HyukjinKwon committed
[SPARK-20791][PYSPARK] Use Arrow to create Spark DataFrame from Pandas
## What changes were proposed in this pull request?

This change uses Arrow to optimize the creation of a Spark DataFrame from a Pandas DataFrame. The input DataFrame is sliced according to the default parallelism. The optimization is enabled with the existing conf "spark.sql.execution.arrow.enabled" and is disabled by default.

## How was this patch tested?

Added new unit tests that create a DataFrame with and without the optimization enabled, then compare the results.

Author: Bryan Cutler <[email protected]>
Author: Takuya UESHIN <[email protected]>

Closes #19459 from BryanCutler/arrow-createDataFrame-from_pandas-SPARK-20791.
1 parent 3d90b2c commit 209b936
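For context, a minimal usage sketch of the optimization this commit introduces (the column names, data values, and local master are illustrative, not part of the patch):

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Opt in to the Arrow path; it is disabled by default.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})

# With the conf enabled and a non-empty input, createDataFrame slices the pandas
# DataFrame by the default parallelism, converts each slice to an Arrow record
# batch, and parallelizes the batches on the JVM. On any error it warns and
# falls back to the existing non-Arrow conversion.
df = spark.createDataFrame(pdf)
df.show()
```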

File tree

8 files changed, +254 -43 lines changed

python/pyspark/context.py

Lines changed: 17 additions & 11 deletions
@@ -475,24 +475,30 @@ def f(split, iterator):
             return xrange(getStart(split), getStart(split + 1), step)
 
         return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
+
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)  # Make it a list so we can compute its length
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
+        return RDD(jrdd, self, serializer)
+
+    def _serialize_to_jvm(self, data, parallelism, serializer):
+        """
+        Calling the Java parallelize() method with an ArrayList is too slow,
+        because it sends O(n) Py4J commands. As an alternative, serialized
+        objects are written to a file and loaded through textFile().
+        """
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         try:
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(c):
-                c = list(c)  # Make it a list so we can compute its length
-            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-            serializer.dump_stream(c, tempFile)
+            serializer.dump_stream(data, tempFile)
             tempFile.close()
             readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
         finally:
             # readRDDFromFile eagerily reads the file so we can delete right after.
             os.unlink(tempFile.name)
-        return RDD(jrdd, self, serializer)
 
     def pickleFile(self, name, minPartitions=None):
         """

python/pyspark/java_gateway.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def killChild():
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
     java_import(gateway.jvm, "scala.Tuple2")
 
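The added import exposes the `org.apache.spark.sql.api.python` package to the Py4J gateway; this is what lets session.py (below) call `PythonSQLUtils.arrowPayloadToDataFrame` via `self._jvm`. A rough access sketch, assuming an active SparkSession:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# After the java_import above, classes in org.apache.spark.sql.api.python resolve
# directly on the JVM view, e.g. the helper used by _create_from_pandas_with_arrow.
PythonSQLUtils = spark._jvm.PythonSQLUtils
```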

python/pyspark/serializers.py

Lines changed: 9 additions & 1 deletion
@@ -214,6 +214,13 @@ def __repr__(self):
 
 
 def _create_batch(series):
+    """
+    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+
+    :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+    :return: Arrow RecordBatch
+    """
+
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def cast_series(s, t):
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
             return _check_series_convert_timestamps_internal(s.fillna(0))\
                 .values.astype('datetime64[us]', copy=False)
-        elif t == pa.date32():
+        # NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
+        elif t is not None and t == pa.date32():
             # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
             return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():
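`_create_batch` is an internal helper, but its expected input shape follows from the docstring above: a list of `(pandas.Series, arrow_type)` pairs where the type may be None. A hedged sketch of a direct call (pyarrow and pandas required; the sample data is made up):

```python
import pandas as pd
import pyarrow as pa

from pyspark.serializers import _create_batch

# Pair each column with an optional Arrow type to coerce to; None means "infer".
batch = _create_batch([(pd.Series([1, 2, 3]), pa.int32()),
                       (pd.Series(["x", "y", "z"]), None)])

print(batch.num_rows)  # 3
print(batch.schema)    # columns get placeholder names; real names come from the Spark schema
```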

python/pyspark/sql/session.py

Lines changed: 71 additions & 17 deletions
@@ -25,7 +25,7 @@
     basestring = unicode = str
     xrange = range
 else:
-    from itertools import imap as map
+    from itertools import izip as zip, imap as map
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ def _createFromLocal(self, data, schema):
             data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
-    def _get_numpy_record_dtypes(self, rec):
+    def _get_numpy_record_dtype(self, rec):
         """
         Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
-        the dtypes of records so they can be properly loaded into Spark.
-        :param rec: a numpy record to check dtypes
-        :return corrected dtypes for a numpy.record or None if no correction needed
+        the dtypes of fields in a record so they can be properly loaded into Spark.
+        :param rec: a numpy record to check field dtypes
+        :return corrected dtype for a numpy.record or None if no correction needed
         """
         import numpy as np
         cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ def _get_numpy_record_dtypes(self, rec):
                 curr_type = 'datetime64[us]'
                 has_rec_fix = True
             record_type_list.append((str(col_names[i]), curr_type))
-        return record_type_list if has_rec_fix else None
+        return np.dtype(record_type_list) if has_rec_fix else None
 
-    def _convert_from_pandas(self, pdf, schema):
+    def _convert_from_pandas(self, pdf):
         """
         Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
-        :return tuple of list of records and schema
+        :return list of records
         """
-        # If no schema supplied by user then get the names of columns only
-        if schema is None:
-            schema = [str(x) for x in pdf.columns]
 
         # Convert pandas.DataFrame to list of numpy records
         np_records = pdf.to_records(index=False)
 
         # Check if any columns need to be fixed for Spark to infer properly
         if len(np_records) > 0:
-            record_type_list = self._get_numpy_record_dtypes(np_records[0])
-            if record_type_list is not None:
-                return [r.astype(record_type_list).tolist() for r in np_records], schema
+            record_dtype = self._get_numpy_record_dtype(np_records[0])
+            if record_dtype is not None:
+                return [r.astype(record_dtype).tolist() for r in np_records]
 
         # Convert list of numpy records to python lists
-        return [r.tolist() for r in np_records], schema
+        return [r.tolist() for r in np_records]
+
+    def _create_from_pandas_with_arrow(self, pdf, schema):
+        """
+        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+        data types will be used to coerce the data in Pandas to Arrow conversion.
+        """
+        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+
+        # Determine arrow types to coerce data when creating batches
+        if isinstance(schema, StructType):
+            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+        elif isinstance(schema, DataType):
+            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+        else:
+            # Any timestamps must be coerced to be compatible with Spark
+            arrow_types = [to_arrow_type(TimestampType())
+                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
+                           for t in pdf.dtypes]
+
+        # Slice the DataFrame to be batched
+        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
+
+        # Create Arrow record batches
+        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+                   for pdf_slice in pdf_slices]
+
+        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
+        if isinstance(schema, (list, tuple)):
+            struct = from_arrow_schema(batches[0].schema)
+            for i, name in enumerate(schema):
+                struct.fields[i].name = name
+                struct.names[i] = name
+            schema = struct
+
+        # Create the Spark DataFrame directly from the Arrow data and schema
+        jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
+        jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
+            jrdd, schema.json(), self._wrapped._jsqlContext)
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
+        return df
 
     @since(2.0)
     @ignore_unicode_prefix
@@ -557,7 +599,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
-            data, schema = self._convert_from_pandas(data, schema)
+
+            # If no schema supplied by user then get the names of columns only
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+
+            if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
+                    and len(data) > 0:
+                try:
+                    return self._create_from_pandas_with_arrow(data, schema)
+                except Exception as e:
+                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+                    # Fallback to create DataFrame without arrow if raise some exception
+            data = self._convert_from_pandas(data)
 
         if isinstance(schema, StructType):
             verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ def prepare(obj):
                 verify_func(obj)
                 return obj,
         else:
-            if isinstance(schema, list):
+            if isinstance(schema, (list, tuple)):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
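The `-(-len(pdf) // ...)` expression in `_create_from_pandas_with_arrow` is integer ceiling division, so every row lands in one of at most `defaultParallelism` slices; a quick worked example with made-up numbers:

```python
# Ceiling division without math.ceil: -(-a // b).
rows, parallelism = 10, 4
step = -(-rows // parallelism)  # -(-10 // 4) = -(-3) = 3 rows per slice

slices = [(start, min(start + step, rows)) for start in range(0, rows, step)]
print(step)    # 3
print(slices)  # [(0, 3), (3, 6), (6, 9), (9, 10)] -> 4 slices, last one shorter
```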

python/pyspark/sql/tests.py

Lines changed: 75 additions & 14 deletions
@@ -3127,9 +3127,9 @@ def setUpClass(cls):
             StructField("5_double_t", DoubleType(), True),
             StructField("6_date_t", DateType(), True),
             StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3145,6 +3145,17 @@ def assertFramesEqual(self, df_with_arrow, df_without):
                ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
+    def create_pandas_data_frame(self):
+        import pandas as pd
+        import numpy as np
+        data_dict = {}
+        for j, name in enumerate(self.schema.names):
+            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+        # need to convert these to numpy types first
+        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+        return pd.DataFrame(data=data_dict)
+
     def test_unsupported_datatype(self):
         schema = StructType([StructField("decimal", DecimalType(), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3161,21 +3172,15 @@ def test_null_conversion(self):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        pdf = df.toPandas()
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        try:
+            pdf = df.toPandas()
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
 
     def test_pandas_round_trip(self):
-        import pandas as pd
-        import numpy as np
-        data_dict = {}
-        for j, name in enumerate(self.schema.names):
-            data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
-        # need to convert these to numpy types first
-        data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
-        data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
-        pdf = pd.DataFrame(data=data_dict)
+        pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
         self.assertFramesEqual(pdf_arrow, pdf)
@@ -3187,6 +3192,62 @@ def test_filtered_frame(self):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_createDataFrame_toggle(self):
+        pdf = self.create_pandas_data_frame()
+        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
+        try:
+            df_no_arrow = self.spark.createDataFrame(pdf)
+        finally:
+            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        df_arrow = self.spark.createDataFrame(pdf)
+        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+
+    def test_createDataFrame_with_schema(self):
+        pdf = self.create_pandas_data_frame()
+        df = self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertEquals(self.schema, df.schema)
+        pdf_arrow = df.toPandas()
+        self.assertFramesEqual(pdf_arrow, pdf)
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        pdf = self.create_pandas_data_frame()
+        wrong_schema = StructType(list(reversed(self.schema)))
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+    def test_createDataFrame_with_names(self):
+        pdf = self.create_pandas_data_frame()
+        # Test that schema as a list of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        # Test that schema as tuple of column names gets applied
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+
+    def test_createDataFrame_with_single_data_type(self):
+        import pandas as pd
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+
+    def test_createDataFrame_does_not_modify_input(self):
+        # Some series get converted for Spark to consume, this makes sure input is unchanged
+        pdf = self.create_pandas_data_frame()
+        # Use a nanosecond value to make sure it is not truncated
+        pdf.ix[0, '7_timestamp_t'] = 1
+        # Integers with nulls will get NaNs filled with 0 and will be casted
+        pdf.ix[1, '2_int_t'] = None
+        pdf_copy = pdf.copy(deep=True)
+        self.spark.createDataFrame(pdf, schema=self.schema)
+        self.assertTrue(pdf.equals(pdf_copy))
+
+    def test_schema_conversion_roundtrip(self):
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        arrow_schema = to_arrow_schema(self.schema)
+        schema_rt = from_arrow_schema(arrow_schema)
+        self.assertEquals(self.schema, schema_rt)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class VectorizedUDFTests(ReusedSQLTestCase):

python/pyspark/sql/types.py

Lines changed: 49 additions & 0 deletions
@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
     return arrow_type
 
 
+def to_arrow_schema(schema):
+    """ Convert a schema from Spark to Arrow
+    """
+    import pyarrow as pa
+    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+              for field in schema]
+    return pa.schema(fields)
+
+
+def from_arrow_type(at):
+    """ Convert pyarrow type to Spark data type.
+    """
+    # TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
+    import pyarrow as pa
+    if at == pa.bool_():
+        spark_type = BooleanType()
+    elif at == pa.int8():
+        spark_type = ByteType()
+    elif at == pa.int16():
+        spark_type = ShortType()
+    elif at == pa.int32():
+        spark_type = IntegerType()
+    elif at == pa.int64():
+        spark_type = LongType()
+    elif at == pa.float32():
+        spark_type = FloatType()
+    elif at == pa.float64():
+        spark_type = DoubleType()
+    elif type(at) == pa.DecimalType:
+        spark_type = DecimalType(precision=at.precision, scale=at.scale)
+    elif at == pa.string():
+        spark_type = StringType()
+    elif at == pa.date32():
+        spark_type = DateType()
+    elif type(at) == pa.TimestampType:
+        spark_type = TimestampType()
+    else:
+        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+    return spark_type
+
+
+def from_arrow_schema(arrow_schema):
+    """ Convert schema from Arrow to Spark.
+    """
+    return StructType(
+        [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+         for field in arrow_schema])
+
+
 def _check_dataframe_localize_timestamps(pdf):
     """
     Convert timezone aware timestamps to timezone-naive in local time
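These helpers are what the new `test_schema_conversion_roundtrip` exercises; a small sketch of the round trip (the example schema is illustrative, and the functions are internal to `pyspark.sql.types`):

```python
from pyspark.sql.types import (StructType, StructField, StringType, LongType,
                               from_arrow_schema, to_arrow_schema)

spark_schema = StructType([StructField("name", StringType(), True),
                           StructField("count", LongType(), True)])

arrow_schema = to_arrow_schema(spark_schema)     # pyarrow.Schema with matching fields
round_tripped = from_arrow_schema(arrow_schema)  # back to a Spark StructType

assert spark_schema == round_tripped  # names, types, and nullability survive the round trip
```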
