34 commits
9af4821
change ArrowConverters to stream format
BryanCutler Jan 10, 2018
d617f0d
Change ArrowSerializer to use stream format
BryanCutler Jan 10, 2018
f10d5d9
toPandas is working with RecordBatch payloads, using custom handler t…
BryanCutler Jan 12, 2018
03653c6
cleanup and removed ArrowPayload, createDataFrame now working
BryanCutler Feb 10, 2018
1b93246
toPandas and createDataFrame working but tests fail with date cols
BryanCutler Mar 9, 2018
ce22d8a
removed usage of seekableByteChannel
BryanCutler Mar 27, 2018
dede0bd
for toPandas, set old collection result to null and add comments
BryanCutler Mar 28, 2018
9e29b09
cleanup, not yet passing ArrowConvertersSuite
BryanCutler Mar 28, 2018
ceb8d38
fix to read Arrow stream with multiple batches, cleanup, add docs, sc…
BryanCutler Mar 29, 2018
f42e4ea
use base OutputStream for serveToStream instead of DataOutputStream
BryanCutler Mar 29, 2018
951843d
accidentally removed date type checks, passing pyspark tests
BryanCutler Mar 29, 2018
af03c6b
Changed to only use Arrow batch bytes as payload, had to hack Arrow M…
BryanCutler Jun 12, 2018
b047c16
added todo comment
BryanCutler Jun 12, 2018
a77b89e
change getBatchesFromStream to return iterator
BryanCutler Jun 12, 2018
81c8209
need to end stream on toPandas after all batches sent to python, and …
BryanCutler Jun 12, 2018
5f46a02
forgot to remove old comment
BryanCutler Jun 12, 2018
7694b8f
fixed up comments
BryanCutler Jun 12, 2018
a5a1fbe
fixed up some wording
BryanCutler Jun 12, 2018
555605a
Updated MessageChannelReader to reflect Arrow changes
BryanCutler Jun 14, 2018
54d6979
move arrowStreamToDataFrame to arrowReadStreamFromFile as not being c…
BryanCutler Jun 14, 2018
4af58f9
rename ArrowSerializer to ArrowStreamSerializer
BryanCutler Jun 14, 2018
c6d24f2
try using static utilty functions instead of MessageChannelReader sub…
BryanCutler Jun 22, 2018
b971e42
fixed wording of _collectAsArrow
BryanCutler Jun 22, 2018
876c066
forgot to inline test data
BryanCutler Jul 24, 2018
a25248e
changed toPandas to send out of order batches, followed by batch orde…
BryanCutler Jun 27, 2018
daa9074
cleanup from review
BryanCutler Aug 21, 2018
66b59c7
change naming per requests
BryanCutler Aug 21, 2018
ed248f9
Revert "changed toPandas to send out of order batches, followed by ba…
BryanCutler Aug 21, 2018
4ae0c11
Merge remote-tracking branch 'upstream/master' into arrow-toPandas-st…
BryanCutler Aug 21, 2018
92b8e26
cleanup after Java Arrow 0.10.0, fixup and simplify getBatchesFromStream
BryanCutler Aug 21, 2018
89d7836
Merge remote-tracking branch 'upstream/master' into arrow-toPandas-st…
BryanCutler Aug 22, 2018
5549644
used tryWithResource, improved comments
BryanCutler Aug 23, 2018
2fe46f8
Merge remote-tracking branch 'upstream/master' into arrow-toPandas-st…
BryanCutler Aug 27, 2018
ffb47cb
Merge remote-tracking branch 'upstream/master' into arrow-toPandas-st…
BryanCutler Aug 28, 2018
24 changes: 22 additions & 2 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging {
* data collected from this job, and the secret for authentication.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
serveToStream(threadName) { out =>
writeIteratorToStream(items, new DataOutputStream(out))
}
}

/**
* Create a socket server and background thread to execute the writeFunc
* with the given OutputStream.
*
* The socket server can only accept one connection, and will close if no connection
* is made within 15 seconds.
*
* Once a connection comes in, it will execute the block of code and pass in
* the socket output stream.
*
* The thread will terminate after the block of code is executed or if any
* exception occurs.
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
@@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging {
val sock = serverSocket.accept()
authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
val out = new BufferedOutputStream(sock.getOutputStream)
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
writeFunc(out)
} {
out.close()
sock.close()
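For context, serveToStream generalizes the old serveIterator: it opens a one-shot local socket server on an ephemeral port and hands the connection's output stream to the caller's write function. A minimal Python sketch of the same pattern (plain sockets, no Spark authentication; the names here are hypothetical, not PySpark APIs):

import socket
import threading

def serve_to_stream(write_func):
    # One-shot server: bind to an ephemeral localhost port, accept a single
    # connection, let write_func write to it, then close everything.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("localhost", 0))
    server.listen(1)
    server.settimeout(15)  # give up if no client connects within 15 seconds

    def run():
        try:
            conn, _ = server.accept()
            with conn.makefile("wb") as out:
                write_func(out)  # caller only sees an output stream
        except socket.timeout:
            pass  # no client connected within the timeout
        finally:
            server.close()

    threading.Thread(target=run, daemon=True).start()
    return server.getsockname()[1]  # port for the client to connect to

# A client can now connect to ("localhost", port) and read the bytes.
port = serve_to_stream(lambda out: out.write(b"hello"))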
11 changes: 7 additions & 4 deletions python/pyspark/context.py
@@ -494,10 +494,14 @@ def f(split, iterator):
c = list(c) # Make it a list so we can compute its length
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
jrdd = self._serialize_to_jvm(c, numSlices, serializer)

def reader_func(temp_filename):
return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)

jrdd = self._serialize_to_jvm(c, serializer, reader_func)
return RDD(jrdd, self, serializer)

def _serialize_to_jvm(self, data, parallelism, serializer):
def _serialize_to_jvm(self, data, serializer, reader_func):
Contributor:

hi, sorry for the late review here, and more just a question for myself -- is this aspect tested at all? IIUC, it would be used in spark.createDataFrame, but the tests in session.py don't have arrow enabled, right?

not that I see a bug, mostly just wondering as I was looking at making my own changes here, and it would be nice if I knew there were some tests

Contributor:

(if not, I can try to address this in some other work)

Member Author:

Hey @squito , yes that's correct, this is in the path that the ArrowTests with createDataFrame exercise. These tests are skipped if pyarrow is not installed, but for our Jenkins tests it is installed under the Python 3.5 env so it gets tested there.

It's a little subtle to see that they were run since the test output shows only when tests are skipped. You can see that for Python 2.7 ArrowTests show as skipped, but for 3.5 it does not.

Member Author:

I made https://issues.apache.org/jira/browse/SPARK-25272 which will give a more clear output that the ArrowTests were run.

Member:

To be honest, I worry about the test coverage of PySpark in general. Anybody in PySpark can lead the effort to propose a solution for improving the test coverage?

Member:

Although most parts of PySpark should be guaranteed by Spark Core and SQL, PySpark is starting to have more and more PySpark-only functionality. I am not very sure how well it is tested.

Contributor:

Thanks @BryanCutler , sorry I didn't know where to look for those, they look much better than what I would have added!

"""
Calling the Java parallelize() method with an ArrayList is too slow,
because it sends O(n) Py4J commands. As an alternative, serialized
@@ -507,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer):
try:
serializer.dump_stream(data, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
return readRDDFromFile(self._jsc, tempFile.name, parallelism)
return reader_func(tempFile.name)
finally:
# readRDDFromFile eagerly reads the file so we can delete right after.
os.unlink(tempFile.name)
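The refactored _serialize_to_jvm keeps the temp-file trick (avoiding O(n) Py4J calls) but lets each caller decide how the JVM reads the file back. A self-contained toy of the same shape, with a pickle-based serializer and a stand-in reader_func instead of the real Py4J calls (all names hypothetical):

import os
import pickle
import tempfile

class PickleStreamSerializer(object):
    # Toy stand-in for PySpark's serializer interface used by _serialize_to_jvm.
    def dump_stream(self, iterator, stream):
        for obj in iterator:
            pickle.dump(obj, stream)

def serialize_via_tempfile(data, serializer, reader_func):
    # Spill the serialized stream to a temp file, let reader_func consume it
    # eagerly (in Spark this is a Py4J call such as readRDDFromFile), then delete.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        serializer.dump_stream(iter(data), tmp)
        tmp.close()
        return reader_func(tmp.name)
    finally:
        os.unlink(tmp.name)

size = serialize_via_tempfile(range(1000), PickleStreamSerializer(),
                              lambda name: os.path.getsize(name))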
30 changes: 17 additions & 13 deletions python/pyspark/serializers.py
@@ -185,27 +185,31 @@ def loads(self, obj):
raise NotImplementedError


class ArrowSerializer(FramedSerializer):
class ArrowStreamSerializer(Serializer):
Member:

I'm wondering if we can reuse this for ArrowStreamPandasSerializer?

Member Author:

That was my thought too. It's pretty close, although we do some different handling in ArrowStreamPandasSerializer that needs to fit in somewhere. Maybe we can look into this as a followup?

"""
Serializes bytes as Arrow data with the Arrow file format.
Serializes Arrow record batches as a stream.
"""

def dumps(self, batch):
def dump_stream(self, iterator, stream):
import pyarrow as pa
import io
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
return sink.getvalue()
writer = None
try:
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()

def loads(self, obj):
def load_stream(self, stream):
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
return reader.read_all()
reader = pa.open_stream(stream)
for batch in reader:
yield batch

def __repr__(self):
return "ArrowSerializer"
return "ArrowStreamSerializer"


def _create_batch(series, timezone):
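What ArrowStreamSerializer does maps directly onto pyarrow's stream IPC format: lazily create a writer from the first batch's schema, write batches, and on the read side iterate batches without materializing them all. A standalone round-trip sketch (uses the newer pa.ipc.new_stream / pa.ipc.open_stream spellings; the PR itself targets the older pa.open_stream API):

import io
import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "label"])

# dump_stream: the writer is created from the schema of the first batch.
sink = io.BytesIO()
writer = pa.ipc.new_stream(sink, batch.schema)
writer.write_batch(batch)
writer.close()

# load_stream: iterate record batches as they are read from the stream.
reader = pa.ipc.open_stream(sink.getvalue())
for b in reader:
    print(b.num_rows, b.schema.names)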
15 changes: 7 additions & 8 deletions python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@

from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -2118,10 +2118,9 @@ def toPandas(self):
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow

tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
batches = self._collectAsArrow()
if len(batches) > 0:
table = pyarrow.Table.from_batches(batches)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
@@ -2170,14 +2169,14 @@ def toPandas(self):

def _collectAsArrow(self):
"""
Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
and available.
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
and available on driver and worker Python environments.

.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
return list(_load_from_socket(sock_info, ArrowSerializer()))
return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
Member:

We also need to update the description of _collectAsArrow()?

Member Author:

Oh yeah, thanks!


##########################################################################################
# Pandas compatibility
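The new toPandas() path above reduces to one pyarrow call chain: collect record batches, combine them into a Table, convert once to pandas. A hedged standalone illustration with made-up data standing in for the result of _collectAsArrow():

import pyarrow as pa

# Stand-ins for the record batches collected from the JVM.
batches = [
    pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array([0.1, 0.2])], names=["id", "x"]),
    pa.RecordBatch.from_arrays([pa.array([3]), pa.array([0.3])], names=["id", "x"]),
]

if len(batches) > 0:
    table = pa.Table.from_batches(batches)  # combine batches without copying values
    pdf = table.to_pandas()                 # single conversion to a pandas DataFrame
    print(pdf)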
12 changes: 7 additions & 5 deletions python/pyspark/sql/session.py
@@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowSerializer, _create_batch
from pyspark.serializers import ArrowStreamSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version
@@ -539,10 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
struct.names[i] = name
schema = struct

# Create the Spark DataFrame directly from the Arrow data and schema
jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
jrdd, schema.json(), self._wrapped._jsqlContext)
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.arrowReadStreamFromFile(
self._wrapped._jsqlContext, temp_filename, schema.json())

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
return df
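Going the other way, _create_from_pandas_with_arrow slices the input pandas DataFrame and builds one Arrow record batch per target partition before handing them to _serialize_to_jvm. A simplified sketch of that slicing (the real _create_batch also applies schema coercion and timezone handling, which is omitted here):

import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"id": range(10), "x": [i / 10.0 for i in range(10)]})
num_partitions = 3

# Ceiling division so every row lands in exactly one slice.
step = -(-len(pdf) // num_partitions)
batches = [pa.RecordBatch.from_pandas(pdf.iloc[start:start + step], preserve_index=False)
           for start in range(0, len(pdf), step)]
print([b.num_rows for b in batches])  # e.g. [4, 4, 2]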
56 changes: 46 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
@@ -3273,13 +3273,49 @@ class Dataset[T] private[sql](
}

/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
* Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

withAction("collectAsArrowToPython", queryExecution) { plan =>
val iter: Iterator[Array[Byte]] =
toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
PythonRDD.serveToStream("serve-Arrow") { out =>
val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length

// Store collection results for worst case of 1 to N-1 partitions
Member:

Is it better 0 to N-1 partitions?

Member Author:

It's not necessary to buffer the first partition because it can be sent to Python right away, so only need an array of size N-1

val results = new Array[Array[Array[Byte]]](numPartitions - 1)
var lastIndex = -1 // index of last partition written

// Handler to eagerly write partitions to Python in order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
// If result is from next partition in order
if (index - 1 == lastIndex) {
batchWriter.writeBatches(arrowBatches.iterator)
lastIndex += 1
// Write stored partitions that come next in order
while (lastIndex < results.length && results(lastIndex) != null) {
batchWriter.writeBatches(results(lastIndex).iterator)
results(lastIndex) = null
lastIndex += 1
}
// After last batch, end the stream
if (lastIndex == results.length) {
batchWriter.end()
}
} else {
// Store partitions received out of order
results(index - 1) = arrowBatches
}
}

sparkSession.sparkContext.runJob(
arrowBatchRdd,
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
Member:

Can we call handlePartitionBatches here before it.toArray? I'd do it.toArray as lazy as possible.

Member Author:

I tried playing around with that a while ago and can't remember if there was some problem, but I'll give it another shot.

Member Author:

Looking at this again, it.toArray is run on the executor, which ends up doing the same thing as collect() and then handlePartitions is run on the results of that in the driver. The task results need to be serialized, so I'm not sure if we can avoid it.toArray here, any thoughts @ueshin ?

Member:

Oh, I see. In that case, we need to do it.toArray. Thanks.

0 until numPartitions,
handlePartitionBatches)
Member:

Instead of collecting partitions back at once and holding out of order partitions in driver waiting for partitions in order, is it better to incrementally run job on partitions in order and send streams to python side? So we don't need to hold out of order partitions in driver.

Member:

+1 chunking if we could. I recall Bryan said for grouped UDF we need the entire set.

Also not sure if python side we have any assumption on how much of the partition is in each chunk (there shouldn't be?)

Member Author (@BryanCutler, Jun 14, 2018):

> is it better to incrementally run job on partitions in order

I believe this is how toLocalIterator works, right? I tried using that because it does only keep 1 partition in memory at a time, but the performance took quite a hit from the multiple jobs. I think we should still prioritize performance over memory for toPandas() since it's assumed the data to be collected should be relatively small.

I did have another idea though, we could stream all partitions to Python out of order, then follow with another small batch of data that contains maps of partitionIndex to orderReceived. Then the partitions could be put into order on the Python side before making the Pandas DataFrame.

Member Author (@BryanCutler, Jun 14, 2018):

> +1 chunking if we could. I recall Bryan said for grouped UDF we need the entire set.

This still keeps Arrow record batches chunked within each partition, which can help the executor memory, but doesn't do anything for the driver side because we still need to collect the entire partition in the driver JVM.

> Also not sure if python side we have any assumption on how much of the partition is in each chunk (there shouldn't be?)

No, Python doesn't care how many chunks the data is in, it's handled by pyarrow

Contributor:

> I did have another idea though, we could stream all partitions to Python out of order, then follow with another small batch of data that contains maps of partitionIndex to orderReceived. Then the partitions could be put into order on the Python side before making the Pandas DataFrame.

This sounds good!

Contributor:

I guess in worst case scenario, the driver still needs to hold all batches in memory. For example, all the batches arrive at the same time.

I wonder if there is a way to:
(1) Compute all tasks in parallel; once tasks are done, store the results in the block manager on the executors.
(2) Return all block ids to the driver.
(3) The driver fetches each block and streams it individually.

This way at least the computation is done in parallel. Fetching the results sequentially is a trade-off of speed vs. memory, something we or the user can choose, but I imagine fetching some 10G - 20G of data from executors sequentially shouldn't be too bad.

}
}
}
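The ordering trick in handlePartitionBatches is independent of Spark: write the next expected partition as soon as it arrives, buffer anything that shows up early, and flush the buffer whenever the order catches up. A small self-contained Python sketch of the same logic (hypothetical names; the real handler writes Arrow batches to the socket stream):

def make_handler(num_partitions, write_batches, end_stream):
    # Buffer for partitions 1..N-1; partition 0 is never buffered because it
    # can be streamed as soon as it arrives.
    buffered = [None] * (num_partitions - 1)
    last_written = -1  # index of the last partition written to the stream

    def handle(index, batches):
        nonlocal last_written
        if index - 1 == last_written:
            write_batches(batches)              # next partition in order
            last_written += 1
            # Flush buffered partitions that are now next in order.
            while last_written < len(buffered) and buffered[last_written] is not None:
                write_batches(buffered[last_written])
                buffered[last_written] = None
                last_written += 1
            if last_written == len(buffered):
                end_stream()                    # all partitions written
        else:
            buffered[index - 1] = batches       # arrived out of order
    return handle

# Results for partitions 2, 0, 1 arrive in that order but are written as 0, 1, 2.
out = []
handler = make_handler(3, out.extend, lambda: out.append("END"))
handler(2, ["c"]); handler(0, ["a"]); handler(1, ["b"])
print(out)  # ['a', 'b', 'c', 'END']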

@@ -3386,20 +3422,20 @@
}
}

/** Convert to an RDD of ArrowPayload byte arrays */
private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(
ArrowConverters.toBatchIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}

// This is only used in tests, for now.
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
toArrowPayload(queryExecution.executedPlan)
private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
}
sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -17,7 +17,6 @@

package org.apache.spark.sql.api.python

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
@@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils {
}

/**
* Python Callable function to convert ArrowPayloads into a [[DataFrame]].
* Python callable function to read a file in Arrow stream format and create a [[DataFrame]]
* using each serialized ArrowRecordBatch as a partition.
*
* @param payloadRDD A JavaRDD of ArrowPayloads.
* @param schemaString JSON Formatted Schema for ArrowPayloads.
* @param sqlContext The active [[SQLContext]].
* @return The converted [[DataFrame]].
* @param filename File to read the Arrow stream from.
* @param schemaString JSON Formatted Spark schema for Arrow batches.
* @return A new [[DataFrame]].
*/
def arrowPayloadToDataFrame(
payloadRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext)
def arrowReadStreamFromFile(
Contributor:

Can we call it arrowFileToDataFrame or something... arrowReadStreamFromFile and readArrowStreamFromFile are just too similar...

Member Author:

arrowStreamFromFile is important to get in the name since it is a stream format being read from a file, but how about arrowStreamFromFileToDataFrame? It's a bit long, but it would be good to indicate that it produces a DataFrame for the call from Python.

sqlContext: SQLContext,
filename: String,
schemaString: String): DataFrame = {
val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext)
}
}
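On the Python side of this exchange, the temp file handed to arrowReadStreamFromFile simply contains the Arrow stream format. For illustration, the same layout can be written and read back with pyarrow alone (hypothetical path; pa.ipc API as in recent pyarrow releases):

import pyarrow as pa

path = "/tmp/example_arrow_stream.bin"  # hypothetical temp file location
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["id"])

# Write one record batch in the Arrow stream format (not the Arrow file format).
with pa.OSFile(path, "wb") as f:
    writer = pa.ipc.new_stream(f, batch.schema)
    writer.write_batch(batch)
    writer.close()

# Read the batches back; each one would become a partition of the DataFrame.
with pa.OSFile(path, "rb") as f:
    batches = list(pa.ipc.open_stream(f))
print(len(batches), batches[0].num_rows)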