diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 22578cbe2e98c..f5b5ad9cdf214 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -183,6 +183,14 @@ def test_filtered_frame(self):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_no_partition_frame(self):
+        schema = StructType([StructField("field1", StringType(), True)])
+        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
+        pdf = df.toPandas()
+        self.assertEqual(len(pdf.columns), 1)
+        self.assertEqual(pdf.columns[0], "field1")
+        self.assertTrue(pdf.empty)
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 74cb3e627432c..99b9f35d25ea7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3290,15 +3290,12 @@ class Dataset[T] private[sql](
       PythonRDD.serveToStream("serve-Arrow") { outputStream =>
         val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        val numPartitions = arrowBatchRdd.partitions.length
 
         // Batches ordered by (index of partition, batch index in that partition) tuple
         val batchOrder = ArrayBuffer.empty[(Int, Int)]
-        var partitionCount = 0
 
         // Handler to eagerly write batches to Python as they arrive, un-ordered
-        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+        val handlePartitionBatches = (index: Int, arrowBatches: Array[Array[Byte]]) =>
           if (arrowBatches.nonEmpty) {
             // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
@@ -3306,27 +3303,22 @@ class Dataset[T] private[sql](
               partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
             }
           }
-          partitionCount += 1
-
-          // After last batch, end the stream and write batch order indices
-          if (partitionCount == numPartitions) {
-            batchWriter.end()
-            out.writeInt(batchOrder.length)
-            // Sort by (index of partition, batch index in that partition) tuple to get the
-            // overall_batch_index from 0 to N-1 batches, which can be used to put the
-            // transferred batches in the correct order
-            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-              out.writeInt(overallBatchIndex)
-            }
-            out.flush()
-          }
-        }
 
+        val arrowBatchRdd = toArrowBatchRdd(plan)
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
-          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
-          0 until numPartitions,
+          (it: Iterator[Array[Byte]]) => it.toArray,
           handlePartitionBatches)
+
+        // After processing all partitions, end the stream and write batch order indices
+        batchWriter.end()
+        out.writeInt(batchOrder.length)
+        // Sort by (index of partition, batch index in that partition) tuple to get the
+        // overall_batch_index from 0 to N-1 batches, which can be used to put the
+        // transferred batches in the correct order
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
+        }
       }
     }
   }
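
The behavioral point of the Dataset.scala change: with the old code, an RDD with zero partitions meant `0 until numPartitions` was empty, the per-partition result handler (which contained `batchWriter.end()`) never ran, and the Arrow stream was never terminated. The patch moves the finalization after `runJob`, which returns normally even when no tasks are launched. Below is a minimal standalone sketch (not part of the patch; the object name, `local[2]` master, and counter variable are illustrative) showing the `SparkContext.runJob` overload the new code relies on and why code placed after the call still runs for an empty RDD.

import org.apache.spark.sql.SparkSession

object RunJobSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("runJob-sketch").getOrCreate()
    val sc = spark.sparkContext

    // An RDD with no partitions, mirroring toPandas() on a DataFrame built from an empty RDD.
    val emptyRdd = sc.emptyRDD[Array[Byte]]

    var handlerCalls = 0
    sc.runJob(
      emptyRdd,
      (it: Iterator[Array[Byte]]) => it.toArray,       // per-partition work, runs on executors
      (index: Int, batches: Array[Array[Byte]]) => {   // result handler, runs on the driver per partition
        handlerCalls += 1
      })

    // With zero partitions the handler never fires, so any end-of-stream bookkeeping
    // kept inside it (as in the pre-patch code) would never execute; code here always does.
    println(s"handler invoked $handlerCalls time(s)")  // prints 0 for an empty RDD

    spark.stop()
  }
}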