[SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames #21546
Changes from all commits (34 commits, 9af4821 through ffb47cb).
pyspark/serializers.py

@@ -185,27 +185,31 @@ def loads(self, obj):
         raise NotImplementedError


-class ArrowSerializer(FramedSerializer):
+class ArrowStreamSerializer(Serializer):
Member: I'm wondering if we can reuse this for
Author: That was my thought too. It's pretty close, although we do some different handling
     """
-    Serializes bytes as Arrow data with the Arrow file format.
+    Serializes Arrow record batches as a stream.
     """

-    def dumps(self, batch):
+    def dump_stream(self, iterator, stream):
         import pyarrow as pa
-        import io
-        sink = io.BytesIO()
-        writer = pa.RecordBatchFileWriter(sink, batch.schema)
-        writer.write_batch(batch)
-        writer.close()
-        return sink.getvalue()
+        writer = None
+        try:
+            for batch in iterator:
+                if writer is None:
+                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+                writer.write_batch(batch)
+        finally:
+            if writer is not None:
+                writer.close()

-    def loads(self, obj):
+    def load_stream(self, stream):
         import pyarrow as pa
-        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
-        return reader.read_all()
+        reader = pa.open_stream(stream)
+        for batch in reader:
+            yield batch

     def __repr__(self):
-        return "ArrowSerializer"
+        return "ArrowStreamSerializer"


 def _create_batch(series, timezone):
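For reference, a standalone sketch of the round trip the new serializer performs, using only pyarrow. The sample data is made up; newer pyarrow exposes the reader as `pa.ipc.open_stream`, while the code above uses the older `pa.open_stream` alias.

```python
import io
import pyarrow as pa

# A couple of small record batches standing in for per-partition results.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "label"])
batches = [batch, batch]

# dump_stream: open the stream writer lazily on the first batch's schema.
sink = io.BytesIO()
writer = None
try:
    for b in batches:
        if writer is None:
            writer = pa.RecordBatchStreamWriter(sink, b.schema)
        writer.write_batch(b)
finally:
    if writer is not None:
        writer.close()

# load_stream: iterate the batches back out of the stream.
reader = pa.ipc.open_stream(sink.getvalue())
for b in reader:
    print(b.num_rows)  # 3, 3
```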
pyspark/sql/dataframe.py

@@ -29,7 +29,7 @@

 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2118,10 +2118,9 @@ def toPandas(self):
             from pyspark.sql.types import _check_dataframe_convert_date, \
                 _check_dataframe_localize_timestamps
             import pyarrow

-            tables = self._collectAsArrow()
-            if tables:
-                table = pyarrow.concat_tables(tables)
+            batches = self._collectAsArrow()
+            if len(batches) > 0:
+                table = pyarrow.Table.from_batches(batches)
                 pdf = table.to_pandas()
                 pdf = _check_dataframe_convert_date(pdf, self.schema)
                 return _check_dataframe_localize_timestamps(pdf, timezone)
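Since the driver now receives plain record batches rather than self-contained Arrow files, toPandas can assemble a single Table directly instead of concatenating per-partition Tables. A small pyarrow sketch of that call (sample data is made up):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array([0.5, 1.5])], names=["id", "value"])

# Record batches collected from the stream are stitched into one Table,
# then converted to a pandas DataFrame in a single step.
table = pa.Table.from_batches([batch, batch])
pdf = table.to_pandas()
print(pdf.shape)  # (4, 2)
```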
@@ -2170,14 +2169,14 @@ def toPandas(self):

     def _collectAsArrow(self):
         """
-        Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
-        and available.
+        Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
+        and available on driver and worker Python environments.

         .. note:: Experimental.
         """
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-            return list(_load_from_socket(sock_info, ArrowSerializer()))
+            return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
Member: We also need to update the description of
Author: Oh yeah, thanks!

 ##########################################################################################
 # Pandas compatibility
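None of this changes the user-facing API: the stream format is exercised simply by calling toPandas() with Arrow enabled. A hedged usage sketch, assuming the Spark 2.3/2.4-era config key (later releases rename it to spark.sql.execution.arrow.pyspark.enabled):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Enable Arrow-based collection; config key name assumed from Spark 2.3/2.4.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

df = spark.range(0, 1000).selectExpr("id", "id * 2 AS doubled")
# toPandas() now pulls Arrow record batches over the stream format described above.
pdf = df.toPandas()
print(pdf.head())
```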
org.apache.spark.sql.Dataset (Dataset.scala)

@@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
+import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
@@ -3273,13 +3273,49 @@ class Dataset[T] private[sql](
   }

   /**
-   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+   * Collect a Dataset as Arrow batches and serve stream to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Array[Any] = {
+    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      val iter: Iterator[Array[Byte]] =
-        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
-      PythonRDD.serveIterator(iter, "serve-Arrow")
+      PythonRDD.serveToStream("serve-Arrow") { out =>
+        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+        val arrowBatchRdd = toArrowBatchRdd(plan)
+        val numPartitions = arrowBatchRdd.partitions.length
+
+        // Store collection results for worst case of 1 to N-1 partitions
Member: Is it better
Author: It's not necessary to buffer the first partition because it can be sent to Python right away, so we only need an array of size N-1
+        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+        var lastIndex = -1  // index of last partition written
+
+        // Handler to eagerly write partitions to Python in order
+        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+          // If result is from next partition in order
+          if (index - 1 == lastIndex) {
+            batchWriter.writeBatches(arrowBatches.iterator)
+            lastIndex += 1
+            // Write stored partitions that come next in order
+            while (lastIndex < results.length && results(lastIndex) != null) {
+              batchWriter.writeBatches(results(lastIndex).iterator)
+              results(lastIndex) = null
+              lastIndex += 1
+            }
+            // After last batch, end the stream
+            if (lastIndex == results.length) {
+              batchWriter.end()
+            }
+          } else {
+            // Store partitions received out of order
+            results(index - 1) = arrowBatches
+          }
+        }
+
+        sparkSession.sparkContext.runJob(
+          arrowBatchRdd,
+          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
Member: Can we call
Author: I tried playing around with that a while ago and can't remember if there was some problem, but I'll give it another shot.
Author: Looking at this again,
Member: Oh, I see. In that case, we need to do
+          0 until numPartitions,
+          handlePartitionBatches)
Member: Instead of collecting partitions back all at once and holding out-of-order partitions in the driver while waiting for the partitions in order, is it better to incrementally run the job on partitions in order and send the streams to the Python side? That way we don't need to hold out-of-order partitions in the driver.

Member: +1 chunking if we could. I recall Bryan said for grouped UDF we need the entire set. Also not sure if on the Python side we have any assumption on how much of the partition is in each chunk (there shouldn't be?)

Author: I believe this is how
I did have another idea though: we could stream all partitions to Python out of order, then follow with another small batch of data that contains maps of partitionIndex to orderReceived. Then the partitions could be put into order on the Python side before making the Pandas DataFrame.

Author: This still keeps Arrow record batches chunked within each partition, which can help executor memory, but doesn't do anything for the driver side because we still need to collect the entire partition in the driver JVM.
No, Python doesn't care how many chunks the data is in; it's handled by pyarrow.

Contributor: This sounds good!

Contributor: I guess in the worst-case scenario the driver still needs to hold all batches in memory, for example if all the batches arrive at the same time. I wonder if there is a way to:
This way at least the computation is done in parallel; fetching the result sequentially is a trade-off of speed vs memory, something we or the user can choose, but I imagine fetching some 10G - 20G of data from executors sequentially shouldn't be too bad.
+      }
     }
   }
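The ordering logic above is the crux of the change: results from runJob can arrive in any order, but the Arrow stream must be written in partition order, so the handler writes a partition as soon as it is the next one expected and buffers anything that arrives early; partition 0 never needs buffering, hence the N-1 slots. A minimal Python sketch of the same bookkeeping (make_handler, write, and end are illustrative stand-ins, not Spark APIs):

```python
def make_handler(num_partitions, write, end):
    """Forward per-partition results to `write` in partition order."""
    # Partition 0 can always be written immediately, so only N-1 buffer slots.
    results = [None] * (num_partitions - 1)
    last_index = -1  # index of the last partition written

    def handle(index, batches):
        nonlocal last_index
        if index - 1 == last_index:
            write(batches)              # next partition in order: send it now
            last_index += 1
            # Flush any buffered partitions that are now next in order.
            while last_index < len(results) and results[last_index] is not None:
                write(results[last_index])
                results[last_index] = None
                last_index += 1
            if last_index == len(results):
                end()                   # all partitions written: end the stream
        else:
            results[index - 1] = batches  # arrived early: buffer until its turn

    return handle
```

Calling handle(2, ...) before handle(0, ...) and handle(1, ...) buffers partition 2 until the first two have been written, which mirrors the worst case the reviewers discuss: with unlucky scheduling the driver can end up holding N-1 partitions before anything is sent.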
@@ -3386,20 +3422,20 @@ class Dataset[T] private[sql](
     }
   }

-  /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
+  /** Convert to an RDD of serialized ArrowRecordBatches. */
+  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
     plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
-      ArrowConverters.toPayloadIterator(
+      ArrowConverters.toBatchIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }

   // This is only used in tests, for now.
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
-    toArrowPayload(queryExecution.executedPlan)
+  private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
+    toArrowBatchRdd(queryExecution.executedPlan)
   }
 }
org.apache.spark.sql.api.python.PythonSQLUtils (PythonSQLUtils.scala)

@@ -17,7 +17,6 @@

 package org.apache.spark.sql.api.python

-import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
   }

   /**
-   * Python Callable function to convert ArrowPayloads into a [[DataFrame]].
+   * Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
+   * using each serialized ArrowRecordBatch as a partition.
    *
-   * @param payloadRDD A JavaRDD of ArrowPayloads.
-   * @param schemaString JSON Formatted Schema for ArrowPayloads.
    * @param sqlContext The active [[SQLContext]].
-   * @return The converted [[DataFrame]].
+   * @param filename File to read the Arrow stream from.
+   * @param schemaString JSON Formatted Spark schema for Arrow batches.
+   * @return A new [[DataFrame]].
    */
-  def arrowPayloadToDataFrame(
-      payloadRDD: JavaRDD[Array[Byte]],
-      schemaString: String,
-      sqlContext: SQLContext): DataFrame = {
-    ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
+  def arrowReadStreamFromFile(
+      sqlContext: SQLContext,
+      filename: String,
+      schemaString: String): DataFrame = {
+    val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+    ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
   }
 }
Hi, sorry for the late review here, and more just a question for myself -- is this aspect tested at all? IIUC, it would be used in spark.createDataFrame, but the tests in session.py don't have arrow enabled, right? Not that I see a bug, mostly just wondering as I was looking at making my own changes here, and it would be nice if I knew there were some tests.

(If not, I can try to address this in some other work.)

Hey @squito , yes that's correct, this is in the path that
ArrowTests with createDataFrame tests. These tests are skipped if pyarrow is not installed, but for our Jenkins tests it is installed under the Python 3.5 env, so it gets tested there. It's a little subtle to see that they were run since the test output shows only when tests are skipped. You can see that for Python 2.7 ArrowTests show as skipped, but for 3.5 it does not.

I made https://issues.apache.org/jira/browse/SPARK-25272 which will give a more clear output that the ArrowTests were run.

To be honest, I worry about the test coverage of PySpark in general. Anybody in PySpark can lead the effort to propose a solution for improving the test coverage?

Although most parts in PySpark should be guaranteed by Spark Core and SQL, PySpark starts to have more and more PySpark-only features. I am not very sure how well they are tested.

Thanks @BryanCutler , sorry I didn't know where to look for those, they look much better than what I would have added!
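As the thread above notes, the arrowReadStreamFromFile path is exercised through spark.createDataFrame on a pandas DataFrame with Arrow enabled (covered by the ArrowTests createDataFrame tests). A hedged usage sketch, assuming the Spark 2.3/2.4-era config key:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")  # 2.3/2.4-era config key

pdf = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# With Arrow enabled, the pandas data is sliced into Arrow record batches on the
# Python side and handed to the JVM as an Arrow stream, which PythonSQLUtils
# turns into a DataFrame using each record batch as a partition.
df = spark.createDataFrame(pdf)
df.show()
```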