diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ff9a612b77f61..f3ebd3767a0a1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,6 +185,39 @@ def loads(self, obj): raise NotImplementedError +class ArrowCollectSerializer(Serializer): + """ + Deserialize a stream of batches followed by batch order information. Used in + DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM. + """ + + def __init__(self): + self.serializer = ArrowStreamSerializer() + + def dump_stream(self, iterator, stream): + return self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + """ + Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields + a list of indices that can be used to put the RecordBatches in the correct order. + """ + # load the batches + for batch in self.serializer.load_stream(stream): + yield batch + + # load the batch order indices + num = read_int(stream) + batch_order = [] + for i in xrange(num): + index = read_int(stream) + batch_order.append(index) + yield batch_order + + def __repr__(self): + return "ArrowCollectSerializer(%s)" % self.serializer + + class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1b1092c409be0..a1056d0b787e3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -2168,7 +2168,14 @@ def _collectAsArrow(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowStreamSerializer())) + + # Collect list of un-ordered batches where last element is a list of correct order indices + results = list(_load_from_socket(sock_info, ArrowCollectSerializer())) + batches = results[:-1] + batch_order = results[-1] + + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 6e75e82d58009..21fe5000df5d9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -381,6 +381,34 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_python.toPandas()) self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + def test_toPandas_batch_order(self): + + def delay_first_part(partition_index, iterator): + if partition_index == 0: + time.sleep(0.1) + return iterator + + # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python + def run_test(num_records, num_parts, max_records, use_delay=False): + df = self.spark.range(num_records, numPartitions=num_parts).toDF("a") + if use_delay: + df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF() + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}): + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf, pdf_arrow) + + cases = [ + (1024, 512, 2), # Use large num partitions for more likely collecting out of order + (64, 8, 2, True), # Use delay in first partition to force collecting out of order + (64, 64, 1), # Test single batch per partition + (64, 1, 64), # Test single partition, single batch + (64, 1, 8), # Test single partition, multiple batches + (30, 7, 2), # Test different sized partitions + ] + + for case in cases: + run_test(*case) + class EncryptionArrowTests(ArrowTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b10d66dfb1aef..a664c7338badb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql -import java.io.CharArrayWriter +import java.io.{CharArrayWriter, DataOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.util.control.NonFatal @@ -3200,34 +3201,38 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { out => + PythonRDD.serveToStream("serve-Arrow") { outputStream => + val out = new DataOutputStream(outputStream) val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](numPartitions - 1) - var lastIndex = -1 // index of last partition written + // Batches ordered by (index of partition, batch index in that partition) tuple + val batchOrder = new ArrayBuffer[(Int, Int)]() + var partitionCount = 0 - // Handler to eagerly write partitions to Python in order + // Handler to eagerly write batches to Python as they arrive, un-ordered def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { - // If result is from next partition in order - if (index - 1 == lastIndex) { + if (arrowBatches.nonEmpty) { + // Write all batches (can be more than 1) in the partition, store the batch order tuple batchWriter.writeBatches(arrowBatches.iterator) - lastIndex += 1 - // Write stored partitions that come next in order - while (lastIndex < results.length && results(lastIndex) != null) { - batchWriter.writeBatches(results(lastIndex).iterator) - results(lastIndex) = null - lastIndex += 1 + arrowBatches.indices.foreach { + partition_batch_index => batchOrder.append((index, partition_batch_index)) } - // After last batch, end the stream - if (lastIndex == results.length) { - batchWriter.end() + } + partitionCount += 1 + + // After last batch, end the stream and write batch order indices + if (partitionCount == numPartitions) { + batchWriter.end() + out.writeInt(batchOrder.length) + // Sort by (index of partition, batch index in that partition) tuple to get the + // overall_batch_index from 0 to N-1 batches, which can be used to put the + // transferred batches in the correct order + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) => + out.writeInt(overall_batch_index) } - } else { - // Store partitions received out of order - results(index - 1) = arrowBatches + out.flush() } }