
Commit ecaa495

[SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance
## What changes were proposed in this pull request?

When executing `toPandas` with Arrow enabled, partitions that arrive in the JVM out of order must be buffered before they can be sent to Python. This causes excess memory use in the driver JVM and increases the time to complete, because data must sit in the JVM waiting for preceding partitions to come in.

This change sends un-ordered partitions to Python as soon as they arrive in the JVM, followed by a list of partition indices so that Python can assemble the data in the correct order. This way, data is not buffered in the JVM and there is no waiting on particular partitions, so performance is improved.

Followup to #21546

## How was this patch tested?

Added a new test with a large number of batches per partition, and a test that forces a small delay in the first partition. These verify that partitions are collected out of order and are then put in the correct order in Python.

## Performance Tests - toPandas

Tests were run on a 4-node standalone cluster with 32 cores total, 14.04.1-Ubuntu and OpenJDK 8, measuring wall clock time to execute `toPandas()` and taking the average best time of 5 runs / 5 loops each.

Test code

```python
df = spark.range(1 << 25, numPartitions=32).toDF("id").withColumn("x1", rand()).withColumn("x2", rand()).withColumn("x3", rand()).withColumn("x4", rand())

for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```

Spark config

```
spark.driver.memory 5g
spark.executor.memory 5g
spark.driver.maxResultSize 2g
spark.sql.execution.arrow.enabled true
```

Current Master w/ Arrow stream | This PR
-------------------------------|---------
5.162070 | 4.342533
5.133671 | 4.399408
5.147513 | 4.468471
5.105243 | 4.365240
5.018685 | 4.373791

Avg Master | Avg This PR
-----------|-------------
5.1134364  | 4.3898886

Speedup of **1.164821449**

Closes #22275 from BryanCutler/arrow-toPandas-oo-batches-SPARK-25274.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
1 parent ab76900 commit ecaa495
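
A minimal standalone sketch (illustrative data only, not Spark code) of the re-ordering scheme described in the commit message: batches are streamed in whatever order partitions finish, followed by a trailing list that gives, for each logical position, the arrival position of its batch.

```python
# Illustrative data, not Spark code: batches arrive out of order, then the
# trailing index list maps logical position j -> arrival position of its batch.
arrived = ["batch from partition 2", "batch from partition 0", "batch from partition 1"]
batch_order = [1, 2, 0]

reordered = [arrived[i] for i in batch_order]
print(reordered)
# ['batch from partition 0', 'batch from partition 1', 'batch from partition 2']
```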

File tree

4 files changed (+95, −22 lines)

python/pyspark/serializers.py

Lines changed: 33 additions & 0 deletions

```diff
@@ -185,6 +185,39 @@ def loads(self, obj):
         raise NotImplementedError
 
 
+class ArrowCollectSerializer(Serializer):
+    """
+    Deserialize a stream of batches followed by batch order information. Used in
+    DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
+    """
+
+    def __init__(self):
+        self.serializer = ArrowStreamSerializer()
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        """
+        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
+        a list of indices that can be used to put the RecordBatches in the correct order.
+        """
+        # load the batches
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+
+        # load the batch order indices
+        num = read_int(stream)
+        batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            batch_order.append(index)
+        yield batch_order
+
+    def __repr__(self):
+        return "ArrowCollectSerializer(%s)" % self.serializer
+
+
 class ArrowStreamSerializer(Serializer):
     """
     Serializes Arrow record batches as a stream.
```
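
For reference, a standalone sketch of the trailing batch-order framing that `load_stream` consumes after the Arrow stream ends: a count followed by one index per batch, assuming the 4-byte big-endian integer format used by `read_int`/`write_int` in `pyspark.serializers`. The helper functions below are illustrative, not part of PySpark.

```python
import io
import struct


def write_batch_order(stream, indices):
    # Hypothetical writer mirroring what the JVM side appends after the batches:
    # the number of indices, then each index, as 4-byte big-endian ints.
    stream.write(struct.pack("!i", len(indices)))
    for index in indices:
        stream.write(struct.pack("!i", index))


def read_batch_order(stream):
    # Mirrors the loop in ArrowCollectSerializer.load_stream, using struct
    # directly instead of pyspark's read_int helper.
    num = struct.unpack("!i", stream.read(4))[0]
    return [struct.unpack("!i", stream.read(4))[0] for _ in range(num)]


buf = io.BytesIO()
write_batch_order(buf, [2, 0, 1])
buf.seek(0)
assert read_batch_order(buf) == [2, 0, 1]
```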

python/pyspark/sql/dataframe.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -29,7 +29,7 @@
 
 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ def _collectAsArrow(self):
         """
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+
+        # Collect list of un-ordered batches where last element is a list of correct order indices
+        results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+        batches = results[:-1]
+        batch_order = results[-1]
+
+        # Re-order the batch list using the correct order
+        return [batches[i] for i in batch_order]
 
     ##########################################################################################
     # Pandas compatibility
```
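
For context, the re-ordered batches returned by `_collectAsArrow` are typically assembled into a single Arrow table before conversion to pandas. A rough standalone sketch with made-up data (assumes `pyarrow` is installed; this is not the exact `toPandas` code):

```python
# Made-up data; stand-ins for the batches returned by _collectAsArrow(),
# already in the correct order.
import pandas as pd
import pyarrow as pa

source = pd.DataFrame({"id": range(6)})
batches = [
    pa.RecordBatch.from_pandas(source.iloc[i:i + 2], preserve_index=False)
    for i in range(0, 6, 2)
]

# Concatenate the RecordBatches into one Table and convert to pandas.
table = pa.Table.from_batches(batches)
print(table.to_pandas())
```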

python/pyspark/sql/tests/test_arrow.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -381,6 +381,34 @@ def test_timestamp_dst(self):
         self.assertPandasEqual(pdf, df_from_python.toPandas())
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
+    def test_toPandas_batch_order(self):
+
+        def delay_first_part(partition_index, iterator):
+            if partition_index == 0:
+                time.sleep(0.1)
+            return iterator
+
+        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
+        def run_test(num_records, num_parts, max_records, use_delay=False):
+            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
+            if use_delay:
+                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
+            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
+                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+                self.assertPandasEqual(pdf, pdf_arrow)
+
+        cases = [
+            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
+            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
+            (64, 64, 1),       # Test single batch per partition
+            (64, 1, 64),       # Test single partition, single batch
+            (64, 1, 8),        # Test single partition, multiple batches
+            (30, 7, 2),        # Test different sized partitions
+        ]
+
+        for case in cases:
+            run_test(*case)
+
 
 class EncryptionArrowTests(ArrowTests):
 
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 25 additions & 20 deletions

```diff
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql
 
-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
 
-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1  // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0
 
-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python as they arrive, un-ordered
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
             }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
            }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
           }
         }
```
