From ec8d280640286fd943fdee6847f53f43d39b27d5 Mon Sep 17 00:00:00 2001
From: David Vogelbacher
Date: Mon, 20 May 2019 08:56:30 -0400
Subject: [PATCH 1/5] do after all partitions code when 0 partitions

---
 .../scala/org/apache/spark/sql/Dataset.scala | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 74cb3e627432c..7f86b74ebe3aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3310,16 +3310,20 @@ class Dataset[T] private[sql](
 
           // After last batch, end the stream and write batch order indices
           if (partitionCount == numPartitions) {
-            batchWriter.end()
-            out.writeInt(batchOrder.length)
-            // Sort by (index of partition, batch index in that partition) tuple to get the
-            // overall_batch_index from 0 to N-1 batches, which can be used to put the
-            // transferred batches in the correct order
-            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-              out.writeInt(overallBatchIndex)
-            }
-            out.flush()
+            doAfterLastPartition()
+          }
+        }
+
+        def doAfterLastPartition(): Unit = {
+          batchWriter.end()
+          out.writeInt(batchOrder.length)
+          // Sort by (index of partition, batch index in that partition) tuple to get the
+          // overall_batch_index from 0 to N-1 batches, which can be used to put the
+          // transferred batches in the correct order
+          batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+            out.writeInt(overallBatchIndex)
           }
+          out.flush()
         }
 
         sparkSession.sparkContext.runJob(
@@ -3327,6 +3331,10 @@ class Dataset[T] private[sql](
           (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
           0 until numPartitions,
           handlePartitionBatches)
+
+        if (numPartitions == 0) {
+          doAfterLastPartition()
+        }
       }
     }
   }

From d635a74cf211786a3b2a400c42afc4a4cea36010 Mon Sep 17 00:00:00 2001
From: David Vogelbacher
Date: Mon, 20 May 2019 09:21:07 -0400
Subject: [PATCH 2/5] pytest

---
 python/pyspark/sql/tests/test_arrow.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 22578cbe2e98c..fd7f85a1780aa 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -183,6 +183,13 @@ def test_filtered_frame(self):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_no_partition_frame(self):
+        schema = StructType([StructField("field1", StringType(), True)])
+        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
+        pdf = df.toPandas()
+        self.assertEqual(len(pdf.columns), 1)
+        self.assertTrue(pdf.empty)
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)

From 52a51bf2a83e23e706356def3aa9b1e7ef1e2497 Mon Sep 17 00:00:00 2001
From: David Vogelbacher
Date: Mon, 20 May 2019 21:41:16 -0400
Subject: [PATCH 3/5] assert column name

---
 python/pyspark/sql/tests/test_arrow.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index fd7f85a1780aa..f5b5ad9cdf214 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -188,6 +188,7 @@ def test_no_partition_frame(self):
         df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
         pdf = df.toPandas()
         self.assertEqual(len(pdf.columns), 1)
+        self.assertEqual(pdf.columns[0], "field1")
         self.assertTrue(pdf.empty)
 
     def _createDataFrame_toggle(self, pdf, schema=None):

From 9f4bc3ee0331d091563d96468fe98bb7eb60c58e Mon Sep 17 00:00:00 2001
From: David Vogelbacher
Date: Tue, 21 May 2019 09:31:59 -0400
Subject: [PATCH 4/5] always write batch order after runJob

---
 .../scala/org/apache/spark/sql/Dataset.scala | 40 ++++++------------
 1 file changed, 13 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7f86b74ebe3aa..72b7506b2531c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3290,12 +3290,9 @@ class Dataset[T] private[sql](
       PythonRDD.serveToStream("serve-Arrow") { outputStream =>
         val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        val numPartitions = arrowBatchRdd.partitions.length
 
         // Batches ordered by (index of partition, batch index in that partition) tuple
         val batchOrder = ArrayBuffer.empty[(Int, Int)]
-        var partitionCount = 0
 
         // Handler to eagerly write batches to Python as they arrive, un-ordered
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
@@ -3306,35 +3303,24 @@ class Dataset[T] private[sql](
               partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
             }
           }
-          partitionCount += 1
-
-          // After last batch, end the stream and write batch order indices
-          if (partitionCount == numPartitions) {
-            doAfterLastPartition()
-          }
-        }
-
-        def doAfterLastPartition(): Unit = {
-          batchWriter.end()
-          out.writeInt(batchOrder.length)
-          // Sort by (index of partition, batch index in that partition) tuple to get the
-          // overall_batch_index from 0 to N-1 batches, which can be used to put the
-          // transferred batches in the correct order
-          batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-            out.writeInt(overallBatchIndex)
-          }
-          out.flush()
         }
 
+        val arrowBatchRdd = toArrowBatchRdd(plan)
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
-          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
-          0 until numPartitions,
-          handlePartitionBatches)
-
-        if (numPartitions == 0) {
-          doAfterLastPartition()
+          (it: Iterator[Array[Byte]]) => it.toArray,
+          handlePartitionBatches _)
+
+        // After processing all partitions, end the stream and write batch order indices
+        batchWriter.end()
+        out.writeInt(batchOrder.length)
+        // Sort by (index of partition, batch index in that partition) tuple to get the
+        // overall_batch_index from 0 to N-1 batches, which can be used to put the
+        // transferred batches in the correct order
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
         }
+        out.flush()
       }
     }
   }

From db6f4b1233895083b60a68a40685e50881143dcf Mon Sep 17 00:00:00 2001
From: David Vogelbacher
Date: Tue, 21 May 2019 20:38:19 -0400
Subject: [PATCH 5/5] cr

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 72b7506b2531c..99b9f35d25ea7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3295,7 +3295,7 @@ class Dataset[T] private[sql](
         val batchOrder = ArrayBuffer.empty[(Int, Int)]
 
         // Handler to eagerly write batches to Python as they arrive, un-ordered
-        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+        val handlePartitionBatches = (index: Int, arrowBatches: Array[Array[Byte]]) =>
           if (arrowBatches.nonEmpty) {
             // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
@@ -3303,13 +3303,12 @@ class Dataset[T] private[sql](
               partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
             }
           }
-        }
 
         val arrowBatchRdd = toArrowBatchRdd(plan)
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
           (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches _)
+          handlePartitionBatches)
 
         // After processing all partitions, end the stream and write batch order indices
         batchWriter.end()
@@ -3320,7 +3319,6 @@ class Dataset[T] private[sql](
         batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
           out.writeInt(overallBatchIndex)
         }
-        out.flush()
       }
     }
   }
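
Note (not part of the patch series): a minimal standalone sketch of the scenario the series addresses, assuming a local SparkSession with pandas and pyarrow installed; the Arrow config key is the same one used in test_arrow.py above, and the app name and master are placeholders. Before these patches, the end-of-stream marker and batch-order indices were only written from the per-partition handler, so a DataFrame whose Arrow batch RDD has zero partitions never produced them; with the fix, toPandas() returns an empty pandas DataFrame with the expected column.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType, StringType

# Local session with Arrow-based toPandas conversion enabled,
# using the config key toggled in test_arrow.py.
spark = (SparkSession.builder
         .master("local[2]")
         .appName("no-partition-toPandas")
         .config("spark.sql.execution.arrow.enabled", "true")
         .getOrCreate())

# An empty RDD has zero partitions, so the per-partition Arrow handler never runs.
schema = StructType([StructField("field1", StringType(), True)])
df = spark.createDataFrame(spark.sparkContext.emptyRDD(), schema)

pdf = df.toPandas()
print(list(pdf.columns))  # ['field1']
print(pdf.empty)          # True

spark.stop()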