31 commits
cd3d51e  createDataFrame working but with fixed schema in python  (BryanCutler, Sep 29, 2017)
c73c7c6  added schema conversion  (BryanCutler, Oct 4, 2017)
e9c6de7  add from_arrow_schema, test and cleanup  (BryanCutler, Oct 9, 2017)
06b033f  fix style  (BryanCutler, Oct 9, 2017)
9d667c6  fixed xrange for Python 3  (BryanCutler, Oct 10, 2017)
31851f8  Merge remote-tracking branch 'upstream/master' into arrow-createDataF…  (BryanCutler, Oct 10, 2017)
ca474db  moved python jvm call to PythonSQLUtils, added tearDownClass to tests  (BryanCutler, Oct 10, 2017)
c7ddee6  forgot to rename conf  (BryanCutler, Oct 10, 2017)
b00a924  fixed typo  (BryanCutler, Oct 13, 2017)
e36a176  using schema if passed in to createDataFrame, added unit test to veri…  (BryanCutler, Oct 14, 2017)
fc3a554  Merge remote-tracking branch 'upstream/master' into arrow-createDataF…  (BryanCutler, Oct 14, 2017)
f42e351  updated function name to_arrow_type  (BryanCutler, Oct 14, 2017)
76e87dc  revert DataFrame schema arg, added test for wrong schema, fixed typos  (BryanCutler, Oct 16, 2017)
81ddfa9  moved common code between parallelize to _serialize_to_jvm  (BryanCutler, Oct 18, 2017)
5e8e11f  when schema provided, attempt to cast series and fallback if not matc…  (BryanCutler, Oct 18, 2017)
3052f30  added support for schema as list of names  (BryanCutler, Oct 18, 2017)
9f7b1c0  Simplify `_createFromPandasWithArrow()`.  (ueshin, Oct 19, 2017)
dc03657  changed to use izip  (BryanCutler, Oct 24, 2017)
f421e2d  added check for case of specifying schema with like 'int'  (BryanCutler, Oct 24, 2017)
0de3126  changed single type to fallback and error  (BryanCutler, Oct 24, 2017)
c41cf33  Merge remote-tracking branch 'upstream/master' into arrow-createDataF…  (BryanCutler, Oct 27, 2017)
b6df7bf  add support for date and timestamp for from_arrow_type  (BryanCutler, Oct 27, 2017)
cfb1c3d  using _create_batch to make arrow batches also without explicit schem…  (BryanCutler, Oct 30, 2017)
b362b9a  Merge remote-tracking branch 'upstream/master' into arrow-createDataF…  (BryanCutler, Nov 7, 2017)
1c244d1  some minor cleanup of _convert_from_pandas  (BryanCutler, Nov 7, 2017)
99ce1e4  minor cleanup of _create_from_pandas_with_arrow  (BryanCutler, Nov 7, 2017)
7d9cc3e  avoid double copies of series with nulls  (BryanCutler, Nov 9, 2017)
126f2e7  added test to make sure input is unchanged  (BryanCutler, Nov 9, 2017)
421d0be  removed copy=True option, did not improve anything  (BryanCutler, Nov 9, 2017)
0ad736b  fix pydoc  (BryanCutler, Nov 9, 2017)
6c72e37  added schema tuple support, refactored creating schema  (BryanCutler, Nov 10, 2017)
28 changes: 17 additions & 11 deletions python/pyspark/context.py
@@ -475,24 +475,30 @@ def f(split, iterator):
return xrange(getStart(split), getStart(split + 1), step)

return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().

# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
jrdd = self._serialize_to_jvm(c, numSlices, serializer)
return RDD(jrdd, self, serializer)

def _serialize_to_jvm(self, data, parallelism, serializer):
"""
Calling the Java parallelize() method with an ArrayList is too slow,
because it sends O(n) Py4J commands. As an alternative, serialized
objects are written to a file and loaded through textFile().
"""
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
try:
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
serializer.dump_stream(data, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
return readRDDFromFile(self._jsc, tempFile.name, parallelism)
finally:
# readRDDFromFile eagerly reads the file so we can delete right after.
os.unlink(tempFile.name)
return RDD(jrdd, self, serializer)

def pickleFile(self, name, minPartitions=None):
"""
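The context.py change above factors the temp-file handoff out of `parallelize()` into `_serialize_to_jvm()`, so the Arrow path added in session.py can reuse it: `parallelize()` passes a `BatchedSerializer`, while `_create_from_pandas_with_arrow()` passes an `ArrowSerializer` over the record batches. A minimal sketch of the pattern in isolation; `dump_stream` and `load_file` stand in for the serializer and the Py4J call to `PythonRDD.readRDDFromFile`, so names and signatures here are illustrative rather than Spark's API:

```python
import os
from tempfile import NamedTemporaryFile

def serialize_to_jvm_sketch(data, temp_dir, dump_stream, load_file):
    """Write `data` to a temp file via `dump_stream`, hand the path to
    `load_file` (which must consume the file eagerly), then delete it."""
    temp_file = NamedTemporaryFile(delete=False, dir=temp_dir)
    try:
        dump_stream(data, temp_file)      # e.g. serializer.dump_stream(...)
        temp_file.close()
        return load_file(temp_file.name)  # e.g. PythonRDD.readRDDFromFile via Py4J
    finally:
        # The loader reads the whole file before returning, so unlinking here is safe.
        os.unlink(temp_file.name)
```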
1 change: 1 addition & 0 deletions python/pyspark/java_gateway.py
@@ -121,6 +121,7 @@ def killChild():
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

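The added `java_import` exposes the new Scala helper class under a short name on the gateway's JVM view; it is what lets `_create_from_pandas_with_arrow()` (later in this diff) call `self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(...)`. A hedged sketch of the lookup it enables, assuming a Spark build that contains this PR:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# With org.apache.spark.sql.api.python.* imported into the gateway view,
# the helper resolves without a fully qualified name:
arrow_payload_to_df = spark.sparkContext._jvm.PythonSQLUtils.arrowPayloadToDataFrame
```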
10 changes: 9 additions & 1 deletion python/pyspark/serializers.py
@@ -214,6 +214,13 @@ def __repr__(self):


def _create_batch(series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.

:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""

from pyspark.sql.types import _check_series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -229,7 +236,8 @@ def cast_series(s, t):
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _check_series_convert_timestamps_internal(s.fillna(0))\
.values.astype('datetime64[us]', copy=False)
elif t == pa.date32():
# NOTE: can not compare None with pyarrow.DataType(), fixed with Arrow >= 0.7.1
elif t is not None and t == pa.date32():
# TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
return s.dt.date
elif t is None or s.dtype == t.to_pandas_dtype():
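To make the `(series, arrow_type)` contract above concrete, here is a rough usage sketch of the private helper (not a public API; it assumes pandas and a pyarrow version contemporary with this PR, and passing `None` as the type leaves the series' own dtype in charge):

```python
import pandas as pd
import pyarrow as pa
from pyspark.serializers import _create_batch

s_str = pd.Series(["a", "b", "c"])
s_int = pd.Series([1, 2, 3])

# First column keeps its inferred type; second is coerced to 32-bit integers.
batch = _create_batch([(s_str, None), (s_int, pa.int32())])

print(batch.num_rows)  # 3
print(batch.schema)    # one string column and one int32 column (names assigned positionally)
```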
88 changes: 71 additions & 17 deletions python/pyspark/sql/session.py
@@ -25,7 +25,7 @@
basestring = unicode = str
xrange = range
else:
from itertools import imap as map
from itertools import izip as zip, imap as map

from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
@@ -417,12 +417,12 @@ def _createFromLocal(self, data, schema):
data = [schema.toInternal(row) for row in data]
return self._sc.parallelize(data), schema

def _get_numpy_record_dtypes(self, rec):
def _get_numpy_record_dtype(self, rec):
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
the dtypes of records so they can be properly loaded into Spark.
:param rec: a numpy record to check dtypes
:return corrected dtypes for a numpy.record or None if no correction needed
the dtypes of fields in a record so they can be properly loaded into Spark.
:param rec: a numpy record to check field dtypes
:return corrected dtype for a numpy.record or None if no correction needed
"""
import numpy as np
cur_dtypes = rec.dtype
@@ -438,28 +438,70 @@ def _get_numpy_record_dtypes(self, rec):
curr_type = 'datetime64[us]'
has_rec_fix = True
record_type_list.append((str(col_names[i]), curr_type))
return record_type_list if has_rec_fix else None
return np.dtype(record_type_list) if has_rec_fix else None

def _convert_from_pandas(self, pdf, schema):
def _convert_from_pandas(self, pdf):
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
:return tuple of list of records and schema
:return list of records
"""
# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) for x in pdf.columns]

# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)

# Check if any columns need to be fixed for Spark to infer properly
if len(np_records) > 0:
record_type_list = self._get_numpy_record_dtypes(np_records[0])
if record_type_list is not None:
return [r.astype(record_type_list).tolist() for r in np_records], schema
record_dtype = self._get_numpy_record_dtype(np_records[0])
if record_dtype is not None:
return [r.astype(record_dtype).tolist() for r in np_records]

# Convert list of numpy records to python lists
return [r.tolist() for r in np_records], schema
return [r.tolist() for r in np_records]

def _create_from_pandas_with_arrow(self, pdf, schema):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in the Pandas-to-Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [to_arrow_type(TimestampType())
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
for t in pdf.dtypes]

# Slice the DataFrame to be batched
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

# Create Arrow record batches
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
for pdf_slice in pdf_slices]

# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
if isinstance(schema, (list, tuple)):
struct = from_arrow_schema(batches[0].schema)
[Review comment by a project Member] @BryanCutler, I think here we'd meet the same issue, SPARK-15244 in this code path. Mind opening a followup with a simple test if it is true?

[Reply by BryanCutler (Member, Author)] Sure, will do
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct

# Create the Spark DataFrame directly from the Arrow data and schema
jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
jrdd, schema.json(), self._wrapped._jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df

@since(2.0)
@ignore_unicode_prefix
@@ -557,7 +599,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
data, schema = self._convert_from_pandas(data, schema)

# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) for x in data.columns]

if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema)
except Exception as e:
warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
# Fall back to creating the DataFrame without Arrow if an exception was raised
data = self._convert_from_pandas(data)

if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
@@ -576,7 +630,7 @@ def prepare(obj):
verify_func(obj)
return obj,
else:
if isinstance(schema, list):
if isinstance(schema, (list, tuple)):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
prepare = lambda obj: obj

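Taken together, the session.py changes mean `createDataFrame` now routes a pandas.DataFrame through Arrow record batches when `spark.sql.execution.arrow.enabled` is set, and falls back to the old conversion with a warning if the Arrow path fails. A usage sketch of the three schema forms exercised by the tests below (data and names here are illustrative):

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"name": ["a", "b"], "count": [1, 2]}, columns=["name", "count"])

# No schema: field names and types come from the Arrow schema of the first batch.
df1 = spark.createDataFrame(pdf)

# List/tuple of names: Arrow-inferred types are kept, only the names are replaced.
df2 = spark.createDataFrame(pdf, schema=("n", "c"))

# Full StructType: series are coerced to the corresponding Arrow types, with a
# fallback (and warning) to the non-Arrow path if the cast is not possible.
schema = StructType([StructField("name", StringType(), True),
                     StructField("count", LongType(), True)])
df3 = spark.createDataFrame(pdf, schema=schema)
df3.show()
```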
89 changes: 75 additions & 14 deletions python/pyspark/sql/tests.py
@@ -3120,9 +3120,9 @@ def setUpClass(cls):
StructField("5_double_t", DoubleType(), True),
StructField("6_date_t", DateType(), True),
StructField("7_timestamp_t", TimestampType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

@classmethod
def tearDownClass(cls):
@@ -3138,6 +3138,17 @@ def assertFramesEqual(self, df_with_arrow, df_without):
("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

def create_pandas_data_frame(self):
import pandas as pd
import numpy as np
data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
# need to convert these to numpy types first
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)

def test_unsupported_datatype(self):
schema = StructType([StructField("decimal", DecimalType(), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -3154,21 +3165,15 @@ def test_null_conversion(self):
def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
pdf = df.toPandas()
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
try:
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)

def test_pandas_round_trip(self):
import pandas as pd
import numpy as np
data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
# need to convert these to numpy types first
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
pdf = pd.DataFrame(data=data_dict)
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)
@@ -3180,6 +3185,62 @@ def test_filtered_frame(self):
self.assertEqual(pdf.columns[0], "i")
self.assertTrue(pdf.empty)

def test_createDataFrame_toggle(self):
pdf = self.create_pandas_data_frame()
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
df_no_arrow = self.spark.createDataFrame(pdf)
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
df_arrow = self.spark.createDataFrame(pdf)
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())

def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(pdf, schema=self.schema)
self.assertEquals(self.schema, df.schema)
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)

def test_createDataFrame_with_incorrect_schema(self):
pdf = self.create_pandas_data_frame()
wrong_schema = StructType(list(reversed(self.schema)))
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
self.spark.createDataFrame(pdf, schema=wrong_schema)

def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
# Test that schema as a list of column names gets applied
df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
# Test that schema as tuple of column names gets applied
df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
self.assertEquals(df.schema.fieldNames(), list('abcdefg'))

def test_createDataFrame_with_single_data_type(self):
import pandas as pd
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

def test_createDataFrame_does_not_modify_input(self):
# Some series get converted for Spark to consume, this makes sure input is unchanged
pdf = self.create_pandas_data_frame()
# Use a nanosecond value to make sure it is not truncated
pdf.ix[0, '7_timestamp_t'] = 1
# Integers with nulls will get NaNs filled with 0 and will be casted
pdf.ix[1, '2_int_t'] = None
pdf_copy = pdf.copy(deep=True)
self.spark.createDataFrame(pdf, schema=self.schema)
self.assertTrue(pdf.equals(pdf_copy))

def test_schema_conversion_roundtrip(self):
from pyspark.sql.types import from_arrow_schema, to_arrow_schema
arrow_schema = to_arrow_schema(self.schema)
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedSQLTestCase):
49 changes: 49 additions & 0 deletions python/pyspark/sql/types.py
@@ -1629,6 +1629,55 @@ def to_arrow_type(dt):
return arrow_type


def to_arrow_schema(schema):
""" Convert a schema from Spark to Arrow
"""
import pyarrow as pa
fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
for field in schema]
return pa.schema(fields)


def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
# TODO: newer pyarrow has is_boolean(at) functions that would be better to check type
import pyarrow as pa
if at == pa.bool_():
spark_type = BooleanType()
elif at == pa.int8():
spark_type = ByteType()
elif at == pa.int16():
spark_type = ShortType()
elif at == pa.int32():
spark_type = IntegerType()
elif at == pa.int64():
spark_type = LongType()
elif at == pa.float32():
spark_type = FloatType()
elif at == pa.float64():
spark_type = DoubleType()
elif type(at) == pa.DecimalType:
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif at == pa.string():
spark_type = StringType()
elif at == pa.date32():
spark_type = DateType()
elif type(at) == pa.TimestampType:
spark_type = TimestampType()
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type


def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])


def _check_dataframe_localize_timestamps(pdf):
"""
Convert timezone aware timestamps to timezone-naive in local time
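As a concrete illustration of the two converters added above (internal helpers, shown under the assumption that pyarrow is installed; the type mapping is the one spelled out in `to_arrow_type`/`from_arrow_type`):

```python
from pyspark.sql.types import (StructType, StructField, IntegerType, StringType,
                               TimestampType, to_arrow_schema, from_arrow_schema)

spark_schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("name", StringType(), nullable=True),
    StructField("ts", TimestampType(), nullable=True),
])

arrow_schema = to_arrow_schema(spark_schema)
print(arrow_schema)  # id -> int32, name -> string, ts -> microsecond timestamp

# Round-trips back to an equal StructType, as test_schema_conversion_roundtrip verifies.
assert from_arrow_schema(arrow_schema) == spark_schema
```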