From 087564e4c202f79b8f01fb82e073a7ad2f0141fd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 29 Aug 2018 14:26:49 -0700 Subject: [PATCH 1/7] changed toPandas to send out of order batches, followed by batch order indices --- python/pyspark/serializers.py | 35 +++++++++++++++ python/pyspark/sql/dataframe.py | 18 +++++--- python/pyspark/sql/tests.py | 6 +++ .../scala/org/apache/spark/sql/Dataset.scala | 43 ++++++++++--------- 4 files changed, 74 insertions(+), 28 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 48006778e86f2..70aade6adb8d5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,6 +185,41 @@ def loads(self, obj): raise NotImplementedError +class BatchOrderSerializer(Serializer): + """ + Deserialize a stream of batches followed by batch order information. + """ + + def __init__(self, serializer): + self.serializer = serializer + self.batch_order = None + + def dump_stream(self, iterator, stream): + return self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + for batch in self.serializer.load_stream(stream): + yield batch + num = read_int(stream) + self.batch_order = [] + for i in xrange(num): + index = read_int(stream) + self.batch_order.append(index) + + def get_batch_order_and_reset(self): + """ + Returns a list of indices to put batches read from load_stream in the correct order. + This must be called after load_stream and will clear the batch order after calling. + """ + assert self.batch_order is not None, "Must call load_stream first to read batch order" + batch_order = self.batch_order + self.batch_order = None + return batch_order + + def __repr__(self): + return "BatchOrderSerializer(%s)" % self.serializer + + class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1affc9b4fcf6c..1e3e09d14d06b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,8 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, BatchOrderSerializer, \ + PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -2118,9 +2118,12 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - batches = self._collectAsArrow() + + # Collect un-ordered list of batches, and list of correct order indices + batches, batch_order = self._collectAsArrow() if len(batches) > 0: - table = pyarrow.Table.from_batches(batches) + # Re-order the batch list with correct order to build a table + table = pyarrow.Table.from_batches([batches[i] for i in batch_order]) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2169,14 +2172,15 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as a list of ArrowRecordBatches, pyarrow must be installed - and available on driver and worker Python environments. + Returns all records as a list of ArrowRecordBatches and batch order as a list of indices, + pyarrow must be installed and available on driver and worker Python environments. .. note:: Experimental. """ + ser = BatchOrderSerializer(ArrowStreamSerializer()) with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowStreamSerializer())) + return list(_load_from_socket(sock_info, ser)), ser.get_batch_order_and_reset() ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 81c0af0b3d81b..00e2b450011b0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4434,6 +4434,12 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_python.toPandas()) self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + def test_toPandas_batch_order(self): + df = self.spark.range(64, numPartitions=8).toDF("a") + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 4}): + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf, pdf_arrow) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index db439b1ee76f1..86d6367e38c70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql -import java.io.CharArrayWriter +import java.io.{CharArrayWriter, DataOutputStream} import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -3279,34 +3280,34 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { out => + PythonRDD.serveToStream("serve-Arrow") { outputStream => + val out = new DataOutputStream(outputStream) val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](numPartitions - 1) - var lastIndex = -1 // index of last partition written + // Batches ordered by (index of partition, batch # in partition) tuple + val batchOrder = new ArrayBuffer[(Int, Int)]() + var partitionCount = 0 - // Handler to eagerly write partitions to Python in order + // Handler to eagerly write batches to Python out of order def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { - // If result is from next partition in order - if (index - 1 == lastIndex) { + if (arrowBatches.nonEmpty) { batchWriter.writeBatches(arrowBatches.iterator) - lastIndex += 1 - // Write stored partitions that come next in order - while (lastIndex < results.length && results(lastIndex) != null) { - batchWriter.writeBatches(results(lastIndex).iterator) - results(lastIndex) = null - lastIndex += 1 - } - // After last batch, end the stream - if (lastIndex == results.length) { - batchWriter.end() + arrowBatches.indices.foreach { i => batchOrder.append((index, i)) } + } + partitionCount += 1 + + // After last batch, end the stream and write batch order + if (partitionCount == numPartitions) { + batchWriter.end() + out.writeInt(batchOrder.length) + // Batch order indices are from 0 to N-1 batches, sorted by order they arrived. + // Re-order batches according to these indices to build a table. + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => + out.writeInt(i) } - } else { - // Store partitions received out of order - results(index - 1) = arrowBatches + out.flush() } } From 6073ed9a5aecea6f8f1c172412fecc1ae4d28720 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 30 Oct 2018 14:53:13 -0700 Subject: [PATCH 2/7] reorg the batch order deser to simplify --- python/pyspark/serializers.py | 32 +++++++++---------- python/pyspark/sql/dataframe.py | 24 ++++++++------ .../scala/org/apache/spark/sql/Dataset.scala | 18 +++++++---- 3 files changed, 40 insertions(+), 34 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 70aade6adb8d5..d285e6518ec1e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,39 +185,37 @@ def loads(self, obj): raise NotImplementedError -class BatchOrderSerializer(Serializer): +class ArrowCollectSerializer(Serializer): """ Deserialize a stream of batches followed by batch order information. """ - def __init__(self, serializer): - self.serializer = serializer - self.batch_order = None + def __init__(self): + self.serializer = ArrowStreamSerializer() def dump_stream(self, iterator, stream): return self.serializer.dump_stream(iterator, stream) def load_stream(self, stream): + """ + Load a stream of un-ordered Arrow RecordBatches, where the last + iteration will yield a list of indices to put the RecordBatches in + the correct order. + """ + # load the batches for batch in self.serializer.load_stream(stream): yield batch + + # load the batch order indices num = read_int(stream) - self.batch_order = [] + batch_order = [] for i in xrange(num): index = read_int(stream) - self.batch_order.append(index) - - def get_batch_order_and_reset(self): - """ - Returns a list of indices to put batches read from load_stream in the correct order. - This must be called after load_stream and will clear the batch order after calling. - """ - assert self.batch_order is not None, "Must call load_stream first to read batch order" - batch_order = self.batch_order - self.batch_order = None - return batch_order + batch_order.append(index) + yield batch_order def __repr__(self): - return "BatchOrderSerializer(%s)" % self.serializer + return "ArrowCollectSerializer(%s)" % self.serializer class ArrowStreamSerializer(Serializer): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e3e09d14d06b..73fbaa01a6e84 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,8 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, BatchOrderSerializer, \ - PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -2119,11 +2119,9 @@ def toPandas(self): _check_dataframe_localize_timestamps import pyarrow - # Collect un-ordered list of batches, and list of correct order indices - batches, batch_order = self._collectAsArrow() + batches = self._collectAsArrow() if len(batches) > 0: - # Re-order the batch list with correct order to build a table - table = pyarrow.Table.from_batches([batches[i] for i in batch_order]) + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2172,15 +2170,21 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as a list of ArrowRecordBatches and batch order as a list of indices, - pyarrow must be installed and available on driver and worker Python environments. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ - ser = BatchOrderSerializer(ArrowStreamSerializer()) with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ser)), ser.get_batch_order_and_reset() + + # Collect list of un-ordered batches where last element is a list of correct order indices + results = list(_load_from_socket(sock_info, ArrowCollectSerializer())) + batches = results[:-1] + batch_order = results[-1] + + # Re-order the batch list using the correct order + return [batches[i] for i in batch_order] ########################################################################################## # Pandas compatibility diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 86d6367e38c70..260d314e1d898 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3286,26 +3286,30 @@ class Dataset[T] private[sql]( val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Batches ordered by (index of partition, batch # in partition) tuple + // Batches ordered by (index of partition, batch index in that partition) tuple val batchOrder = new ArrayBuffer[(Int, Int)]() var partitionCount = 0 // Handler to eagerly write batches to Python out of order def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { if (arrowBatches.nonEmpty) { + // Write all batches (can be more than 1) in the partition, store the batch order tuple batchWriter.writeBatches(arrowBatches.iterator) - arrowBatches.indices.foreach { i => batchOrder.append((index, i)) } + arrowBatches.indices.foreach { + partition_batch_index => batchOrder.append((index, partition_batch_index)) + } } partitionCount += 1 - // After last batch, end the stream and write batch order + // After last batch, end the stream and write batch order indices if (partitionCount == numPartitions) { batchWriter.end() out.writeInt(batchOrder.length) - // Batch order indices are from 0 to N-1 batches, sorted by order they arrived. - // Re-order batches according to these indices to build a table. - batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => - out.writeInt(i) + // Sort by (index of partition, batch index in that partition) tuple to get the + // overall_batch_index from 0 to N-1 batches, which can be used to put the + // transferred batches in the correct order + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) => + out.writeInt(overall_batch_index) } out.flush() } From 0d77b0051546eb3a178a06a95f35ccb20ddeeff6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 30 Oct 2018 15:56:59 -0700 Subject: [PATCH 3/7] expanded test for more cases --- python/pyspark/sql/tests.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 00e2b450011b0..d648e9bdd5869 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4435,10 +4435,26 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_pandas.toPandas()) def test_toPandas_batch_order(self): - df = self.spark.range(64, numPartitions=8).toDF("a") - with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 4}): - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf, pdf_arrow) + + # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python + def run_test(num_records, num_parts, max_records): + df = self.spark.range(num_records, numPartitions=num_parts).toDF("a") + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}): + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf, pdf_arrow) + + cases = [ + (1024, 512, 2), # Try large num partitions for good chance of not collecting in order + (512, 64, 2), # Try medium num partitions to test out of order collection + (64, 8, 2), # Try small number of partitions to test out of order collection + (64, 64, 1), # Test single batch per partition + (64, 1, 64), # Test single partition, single batch + (64, 1, 8), # Test single partition, multiple batches + (30, 7, 2), # Test different sized partitions + ] + + for case in cases: + run_test(num_records=case[0], num_parts=case[1], max_records=case[2]) @unittest.skipIf( From 6457e420e3b8366c1373e7adb0bf56df03b9cc19 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 30 Oct 2018 16:34:16 -0700 Subject: [PATCH 4/7] remove blank line --- python/pyspark/sql/dataframe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef7a4ce5c09c2..1334ef3638389 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2125,7 +2125,6 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) From bf2feec2ef023177d72ac1137dbd1b3a02eb9a89 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 6 Nov 2018 15:13:54 -0800 Subject: [PATCH 5/7] add test case with delay --- python/pyspark/sql/tests.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6498dae988a20..57c8959e0398e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4925,25 +4925,31 @@ def test_timestamp_dst(self): def test_toPandas_batch_order(self): + def delay_first_part(partition_index, iterator): + if partition_index == 0: + time.sleep(0.1) + return iterator + # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python - def run_test(num_records, num_parts, max_records): + def run_test(num_records, num_parts, max_records, use_delay=False): df = self.spark.range(num_records, numPartitions=num_parts).toDF("a") + if use_delay: + df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF() with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}): pdf, pdf_arrow = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf, pdf_arrow) cases = [ - (1024, 512, 2), # Try large num partitions for good chance of not collecting in order - (512, 64, 2), # Try medium num partitions to test out of order collection - (64, 8, 2), # Try small number of partitions to test out of order collection - (64, 64, 1), # Test single batch per partition - (64, 1, 64), # Test single partition, single batch - (64, 1, 8), # Test single partition, multiple batches - (30, 7, 2), # Test different sized partitions + (1024, 512, 2), # Use large num partitions for more likely collecting out of order + (64, 8, 2, True), # Use delay in first partition to force collecting out of order + (64, 64, 1), # Test single batch per partition + (64, 1, 64), # Test single partition, single batch + (64, 1, 8), # Test single partition, multiple batches + (30, 7, 2), # Test different sized partitions ] for case in cases: - run_test(num_records=case[0], num_parts=case[1], max_records=case[2]) + run_test(*case) class EncryptionArrowTests(ArrowTests): From 7dc92c8d0dca69e254088fd6e1f3e15da1f90fbe Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Nov 2018 15:52:26 -0800 Subject: [PATCH 6/7] fixed some comments --- python/pyspark/serializers.py | 8 ++++---- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index c026fda61ad2b..f3ebd3767a0a1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -187,7 +187,8 @@ def loads(self, obj): class ArrowCollectSerializer(Serializer): """ - Deserialize a stream of batches followed by batch order information. + Deserialize a stream of batches followed by batch order information. Used in + DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM. """ def __init__(self): @@ -198,9 +199,8 @@ def dump_stream(self, iterator, stream): def load_stream(self, stream): """ - Load a stream of un-ordered Arrow RecordBatches, where the last - iteration will yield a list of indices to put the RecordBatches in - the correct order. + Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields + a list of indices that can be used to put the RecordBatches in the correct order. """ # load the batches for batch in self.serializer.load_stream(stream): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 11ef30a72c07f..664d852ebd519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3179,7 +3179,7 @@ class Dataset[T] private[sql]( val batchOrder = new ArrayBuffer[(Int, Int)]() var partitionCount = 0 - // Handler to eagerly write batches to Python out of order + // Handler to eagerly write batches to Python un-ordered def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { if (arrowBatches.nonEmpty) { // Write all batches (can be more than 1) in the partition, store the batch order tuple From 8045facbe523c89b91b930203bb6874d82d08a4d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Nov 2018 15:53:12 -0800 Subject: [PATCH 7/7] fix comment --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 664d852ebd519..3f2888d67ff39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3179,7 +3179,7 @@ class Dataset[T] private[sql]( val batchOrder = new ArrayBuffer[(Int, Int)]() var partitionCount = 0 - // Handler to eagerly write batches to Python un-ordered + // Handler to eagerly write batches to Python as they arrive, un-ordered def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { if (arrowBatches.nonEmpty) { // Write all batches (can be more than 1) in the partition, store the batch order tuple