From 9af482170ee95831bbda139e6e931ba2631df386 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:02:15 -0800 Subject: [PATCH 01/30] change ArrowConverters to stream format --- .../spark/sql/execution/arrow/ArrowConverters.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7487564ed64da..fd329afbda089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.arrow.vector.ipc.message.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel @@ -101,7 +101,7 @@ private[sql] object ArrowConverters { override def next(): ArrowPayload = { val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -133,7 +133,7 @@ private[sql] object ArrowConverters { ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) new ArrowRowIterator { - private var reader: ArrowFileReader = null + private var reader: ArrowStreamReader = null private var schemaRead = StructType(Seq.empty) private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty @@ -166,7 +166,7 @@ private[sql] object ArrowConverters { private def nextBatch(): Iterator[InternalRow] = { val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) - reader = new ArrowFileReader(in, allocator) + reader = new ArrowStreamReader(in, allocator) reader.loadNextBatch() // throws IOException val root = reader.getVectorSchemaRoot // throws IOException schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) @@ -189,7 +189,7 @@ private[sql] object ArrowConverters { batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) + val reader = new ArrowStreamReader(in, allocator) // Read a batch from a byte stream, ensure the reader is closed Utils.tryWithSafeFinally { From d617f0da8eff1509da465bb707340e391314bec4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:14:07 -0800 Subject: [PATCH 02/30] Change ArrowSerializer to use stream format --- python/pyspark/serializers.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..5e98574eeb462 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -184,24 +184,28 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowSerializer(Serializer): """ - Serializes bytes as Arrow data with the Arrow file format. + Serializes Arrow record batches as a stream. 
""" - def dumps(self, batch): + def dump_stream(self, iterator, stream): import pyarrow as pa - import io - sink = io.BytesIO() - writer = pa.RecordBatchFileWriter(sink, batch.schema) - writer.write_batch(batch) - writer.close() - return sink.getvalue() + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() + reader = pa.open_stream(stream) + for batch in reader: + yield batch def __repr__(self): return "ArrowSerializer" From f10d5d9cd3cece7f56749e1de7fe01699e4759a0 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 11 Jan 2018 16:40:36 -0800 Subject: [PATCH 03/30] toPandas is working with RecordBatch payloads, using custom handler to stream ordered partitions --- .../apache/spark/api/python/PythonRDD.scala | 9 +- python/pyspark/sql/dataframe.py | 7 +- .../scala/org/apache/spark/sql/Dataset.scala | 58 +++++++++-- .../sql/execution/arrow/ArrowConverters.scala | 99 +++++++++++++++---- 4 files changed, 141 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a1ee2f7d1b119..3b66d565ee8f4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -398,6 +398,13 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { + serveToStream(threadName) { out => + writeIteratorToStream(items, out) + } + } + + // TODO: scaladoc + def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -411,7 +418,7 @@ private[spark] object PythonRDD extends Logging { val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) + dataWriteBlock(out) } { out.close() sock.close() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e6a1acebb5ca..c80340f1ccc9d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2095,11 +2095,10 @@ def toPandas(self): _check_dataframe_localize_timestamps import pyarrow - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) + batch_iter = self._collectAsArrow() + if batch_iter: + table = pyarrow.Table.from_batches(batch_iter) pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5526104690d2..338d7b7bdb268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -24,9 +24,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import 
scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal - import org.apache.commons.lang3.StringUtils - import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD @@ -40,7 +38,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ @@ -48,7 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload, ArrowPayloadStreamWriter} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -3240,9 +3238,55 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Array[Any] = { withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + //val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + + PythonRDD.serveToStream("serve-Arrow") { out => + val payloadWriter = new ArrowPayloadStreamWriter(schema, out) + + /* + val payloadRDD = toArrowPayload + val results = new Array[Array[ArrowPayload]](payloadRDD.partitions.size) + sparkSession.sparkContext.runJob[ArrowPayload, Array[ArrowPayload]]( + payloadRDD, + (ctx: TaskContext, it: Iterator[ArrowPayload]) => it.toArray, + 0 until payloadRDD.partitions.length, + (index, res) => results(index) = res) + val payloads = Array.concat(results: _*) + */ + + val payloadRDD = toArrowPayload(plan) + + val results = new Array[Array[ArrowPayload]](payloadRDD.partitions.size - 1) + var lastIndex = -1 + + def handlePartitionPayloads(index: Int, payloads: Array[ArrowPayload]): Unit = { + if (index - 1 == lastIndex) { + payloadWriter.writePayloads(payloads.iterator) + lastIndex += 1 + while (lastIndex < results.length && results(lastIndex) != null) { + payloadWriter.writePayloads(results(lastIndex).iterator) + lastIndex += 1 + } + } else { + results(index - 1) = payloads + } + } + + sparkSession.sparkContext.runJob( + payloadRDD, + (ctx: TaskContext, it: Iterator[ArrowPayload]) => it.toArray, + 0 until payloadRDD.partitions.length, + handlePartitionPayloads) + //(_: Int, payloads: Array[ArrowPayload]) => payloadWriter.writePayloads(payloads.iterator)) + + //ArrowConverters.writePayloadsToStream(out, schema, toArrowPayload.collect().iterator) + + /*val payloadWriter = new ArrowPayloadStreamWriter(schema, out) + toArrowPayload.foreachPartition { payloadIter => + payloadWriter.writePayloads(payloadIter) + } + payloadWriter.close()*/ + } } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index fd329afbda089..cd7e810ffb916 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.execution.arrow -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, DataOutputStream} import java.nio.channels.Channels -import scala.collection.JavaConverters._ +import io.netty.buffer.ArrowBuf +import org.apache.arrow.flatbuf.{Message, MessageHeader} +import scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageChannelReader, MessageSerializer} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} import org.apache.spark.util.Utils @@ -66,8 +67,72 @@ private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { def schema: StructType } + +private[sql] class ArrowPayloadStreamWriter(schema: StructType, out: DataOutputStream) { + + //val allocator = + // ArrowUtils.rootAllocator.newChildAllocator("ArrowPayloadStreamWriter", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, /*timeZoneId*/"") + + //val root = VectorSchemaRoot.create(arrowSchema, allocator) + //val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, arrowSchema) + + /* + class PayloadReader(in: ReadChannel) extends MessageChannelReader(in) { + + override def readNextMessage(): Message = { + val msg = super.readNextMessage() + if (msg.headerType() != MessageHeader.Schema){ + outChan.write(msg.getByteBuffer()) + } + msg + } + + override def readMessageBody(message: Message, allocator: BufferAllocator): ArrowBuf = { + val buf = super.readMessageBody(message, allocator) + + buf + } + } + */ + + def writePayloads(payloadIter: Iterator[ArrowPayload]): Unit = { + payloadIter.foreach { payload => + writeChannel.write(payload.asPythonSerializable) + //val in = new ByteArrayReadableSeekableByteChannel(payload.asPythonSerializable) + //val reader = new ArrowStreamReader(new PayloadReader(new ReadChannel(in)), allocator) + //val reader = new ArrowStreamReader(in, allocator) + //while (reader.loadNextBatch()) {} // throws IOException) + //reader.close() + } + } + + def close(): Unit = { + // Write End of Stream + writeChannel.writeIntLittleEndian(0) + } +} + + + private[sql] object ArrowConverters { + private[sql] def writePayloadsToStream(out: DataOutputStream, schema: StructType, payloadIter: Iterator[ArrowPayload]): Unit = { + val arrowSchema = 
ArrowUtils.toArrowSchema(schema, /*timeZoneId*/"") + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, arrowSchema) + + payloadIter.foreach { payload => + writeChannel.write(payload.asPythonSerializable) + } + + // Write End of Stream + writeChannel.writeIntLittleEndian(0) + } + /** * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. @@ -84,6 +149,7 @@ private[sql] object ArrowConverters { ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) context.addTaskCompletionListener { _ => @@ -101,7 +167,7 @@ private[sql] object ArrowConverters { override def next(): ArrowPayload = { val out = new ByteArrayOutputStream() - val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -111,12 +177,14 @@ private[sql] object ArrowConverters { rowCount += 1 } arrowWriter.finish() - writer.writeBatch() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() } { arrowWriter.reset() - writer.close() } + // TODO: ??? writeChannel.close() new ArrowPayload(out.toByteArray) } } @@ -189,17 +257,8 @@ private[sql] object ArrowConverters { batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowStreamReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + MessageSerializer.deserializeMessageBatch(new ReadChannel(in), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } private[sql] def toDataFrame( From 03653c687473b82bbfb6653504479498a2a3c63b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 9 Feb 2018 16:23:17 -0800 Subject: [PATCH 04/30] cleanup and removed ArrowPayload, createDataFrame now working --- python/pyspark/sql/session.py | 15 +- .../scala/org/apache/spark/sql/Dataset.scala | 60 +++----- .../spark/sql/api/python/PythonSQLUtils.scala | 8 +- .../sql/execution/arrow/ArrowConverters.scala | 135 +++++------------- .../arrow/ArrowConvertersSuite.scala | 34 ++--- 5 files changed, 88 insertions(+), 164 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1a..09bc13895b089 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
""" - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowSerializer, FramedSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,9 +539,18 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct + class ArrowFramedSerializer(FramedSerializer): + + def dumps(self, batch): + import io + sink = io.BytesIO() + serializer = ArrowSerializer() + serializer.dump_stream([batch], sink) + return sink.getvalue() + # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( + jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowFramedSerializer()) + jdf = self._jvm.PythonSQLUtils.arrowStreamToDataFrame( jrdd, schema.json(), self._wrapped._jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 338d7b7bdb268..66242d8edfe90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload, ArrowPayloadStreamWriter} +import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowBatchStreamWriter} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -3234,58 +3234,38 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. 
*/ private[sql] def collectAsArrowToPython(): Array[Any] = { + withAction("collectAsArrowToPython", queryExecution) { plan => - //val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) PythonRDD.serveToStream("serve-Arrow") { out => - val payloadWriter = new ArrowPayloadStreamWriter(schema, out) - - /* - val payloadRDD = toArrowPayload - val results = new Array[Array[ArrowPayload]](payloadRDD.partitions.size) - sparkSession.sparkContext.runJob[ArrowPayload, Array[ArrowPayload]]( - payloadRDD, - (ctx: TaskContext, it: Iterator[ArrowPayload]) => it.toArray, - 0 until payloadRDD.partitions.length, - (index, res) => results(index) = res) - val payloads = Array.concat(results: _*) - */ - - val payloadRDD = toArrowPayload(plan) - - val results = new Array[Array[ArrowPayload]](payloadRDD.partitions.size - 1) + val batchWriter = new ArrowBatchStreamWriter(schema, out) + + val arrowBatchRDD = toArrowBatches(plan) + + val results = new Array[Array[Array[Byte]]](arrowBatchRDD.partitions.size - 1) var lastIndex = -1 - def handlePartitionPayloads(index: Int, payloads: Array[ArrowPayload]): Unit = { + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { if (index - 1 == lastIndex) { - payloadWriter.writePayloads(payloads.iterator) + batchWriter.writeBatches(arrowBatches.iterator) lastIndex += 1 while (lastIndex < results.length && results(lastIndex) != null) { - payloadWriter.writePayloads(results(lastIndex).iterator) + batchWriter.writeBatches(results(lastIndex).iterator) lastIndex += 1 } } else { - results(index - 1) = payloads + results(index - 1) = arrowBatches } } sparkSession.sparkContext.runJob( - payloadRDD, - (ctx: TaskContext, it: Iterator[ArrowPayload]) => it.toArray, - 0 until payloadRDD.partitions.length, - handlePartitionPayloads) - //(_: Int, payloads: Array[ArrowPayload]) => payloadWriter.writePayloads(payloads.iterator)) - - //ArrowConverters.writePayloadsToStream(out, schema, toArrowPayload.collect().iterator) - - /*val payloadWriter = new ArrowPayloadStreamWriter(schema, out) - toArrowPayload.foreachPartition { payloadIter => - payloadWriter.writePayloads(payloadIter) - } - payloadWriter.close()*/ + arrowBatchRDD, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until arrowBatchRDD.partitions.length, + handlePartitionBatches) } } } @@ -3394,19 +3374,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { + private[sql] def toArrowBatches(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator( + ArrowConverters.toBatchIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } // This is only used in tests, for now. 
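(For reference, a minimal Python sketch of the ordering logic used by handlePartitionBatches above. Illustrative only, not part of the patch: partition results can reach the driver out of order, so results that arrive early are buffered and flushed as soon as the partition that comes next in order has been written. The names make_handler and write_batches are placeholders invented for this sketch.)

    def make_handler(num_partitions, write_batches):
        # slots for partitions 1..N-1 that finish before their turn
        pending = [None] * (num_partitions - 1)
        state = {"last_written": -1}

        def handle(index, batches):
            if index - 1 == state["last_written"]:
                # this partition is next in order: write it immediately
                write_batches(batches)
                state["last_written"] += 1
                # then drain any buffered partitions that are now next in order
                while (state["last_written"] < len(pending)
                       and pending[state["last_written"]] is not None):
                    write_batches(pending[state["last_written"]])
                    pending[state["last_written"]] = None
                    state["last_written"] += 1
            else:
                # arrived out of order: buffer until its turn comes
                pending[index - 1] = batches

        return handle

(With two partitions finishing in reverse order, handle(1, b1) is buffered and a later handle(0, b0) writes b0 followed immediately by b1, so the stream stays in partition order.)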
- private[sql] def toArrowPayload: RDD[ArrowPayload] = { - toArrowPayload(queryExecution.executedPlan) + private[sql] def toArrowBatches: RDD[Array[Byte]] = { + toArrowBatches(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b33760b1edbc6..9ba4958a32f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -36,15 +36,15 @@ private[sql] object PythonSQLUtils { /** * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. + * @param arrowStreamRDD A JavaRDD of Arrow data in stream protocol. * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( + arrowStreamRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + ArrowConverters.toDataFrame(arrowStreamRDD, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index cd7e810ffb916..5ba2d803a568f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayOutputStream, DataOutputStream} import java.nio.channels.Channels -import io.netty.buffer.ArrowBuf -import org.apache.arrow.flatbuf.{Message, MessageHeader} - import scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} -import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageChannelReader, MessageSerializer} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD @@ -38,75 +35,15 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, Columna import org.apache.spark.util.Utils -/** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. - */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { - - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } - - /** - * Get the ArrowPayload as a type that can be served to Python. 
- */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { - - /** - * Return the schema loaded from the Arrow record batch being iterated over - */ - def schema: StructType -} - - -private[sql] class ArrowPayloadStreamWriter(schema: StructType, out: DataOutputStream) { - - //val allocator = - // ArrowUtils.rootAllocator.newChildAllocator("ArrowPayloadStreamWriter", 0, Long.MaxValue) +private[sql] class ArrowBatchStreamWriter(schema: StructType, out: DataOutputStream) { val arrowSchema = ArrowUtils.toArrowSchema(schema, /*timeZoneId*/"") - - //val root = VectorSchemaRoot.create(arrowSchema, allocator) - //val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) val writeChannel = new WriteChannel(Channels.newChannel(out)) MessageSerializer.serialize(writeChannel, arrowSchema) - /* - class PayloadReader(in: ReadChannel) extends MessageChannelReader(in) { - - override def readNextMessage(): Message = { - val msg = super.readNextMessage() - if (msg.headerType() != MessageHeader.Schema){ - outChan.write(msg.getByteBuffer()) - } - msg - } - - override def readMessageBody(message: Message, allocator: BufferAllocator): ArrowBuf = { - val buf = super.readMessageBody(message, allocator) - - buf - } - } - */ - - def writePayloads(payloadIter: Iterator[ArrowPayload]): Unit = { - payloadIter.foreach { payload => - writeChannel.write(payload.asPythonSerializable) - //val in = new ByteArrayReadableSeekableByteChannel(payload.asPythonSerializable) - //val reader = new ArrowStreamReader(new PayloadReader(new ReadChannel(in)), allocator) - //val reader = new ArrowStreamReader(in, allocator) - //while (reader.loadNextBatch()) {} // throws IOException) - //reader.close() + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { + arrowBatchIter.foreach { batchBytes => + writeChannel.write(batchBytes) } } @@ -117,36 +54,34 @@ private[sql] class ArrowPayloadStreamWriter(schema: StructType, out: DataOutputS } +/** + * Iterator interface to iterate over Arrow record batches and return rows + */ +private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { -private[sql] object ArrowConverters { - - private[sql] def writePayloadsToStream(out: DataOutputStream, schema: StructType, payloadIter: Iterator[ArrowPayload]): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, /*timeZoneId*/"") - val writeChannel = new WriteChannel(Channels.newChannel(out)) - MessageSerializer.serialize(writeChannel, arrowSchema) + /** + * Return the schema loaded from the Arrow record batch being iterated over + */ + def schema: StructType +} - payloadIter.foreach { payload => - writeChannel.write(payload.asPythonSerializable) - } - // Write End of Stream - writeChannel.writeIntLittleEndian(0) - } +private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload + * Maps Iterator from InternalRow to Arrow batches. Limit ArrowRecordBatch size in a batch * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. 
*/ - private[sql] def toPayloadIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext): Iterator[ArrowPayload] = { + context: TaskContext): Iterator[Array[Byte]] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) val unloader = new VectorUnloader(root) @@ -157,7 +92,7 @@ private[sql] object ArrowConverters { allocator.close() } - new Iterator[ArrowPayload] { + new Iterator[Array[Byte]] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -165,7 +100,7 @@ private[sql] object ArrowConverters { false } - override def next(): ArrowPayload = { + override def next(): Array[Byte] = { val out = new ByteArrayOutputStream() val writeChannel = new WriteChannel(Channels.newChannel(out)) @@ -185,25 +120,25 @@ private[sql] object ArrowConverters { } // TODO: ??? writeChannel.close() - new ArrowPayload(out.toByteArray) + out.toByteArray } } } /** - * Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator - * and the schema from the first batch of Arrow data read. + * Maps Iterator from Arrow batches to InternalRow. Returns an ArrowRowIterator that can iterate + * over record batch rows and has the schema from the first batch of Arrow data read. */ - private[sql] def fromPayloadIterator( - payloadIter: Iterator[ArrowPayload], + private[sql] def fromStreamIterator( + arrowStreamIter: Iterator[Array[Byte]], context: TaskContext): ArrowRowIterator = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromStreamIterator", 0, Long.MaxValue) new ArrowRowIterator { private var reader: ArrowStreamReader = null private var schemaRead = StructType(Seq.empty) - private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + private var rowIter = if (arrowStreamIter.hasNext) nextBatch() else Iterator.empty context.addTaskCompletionListener { _ => closeReader() @@ -214,7 +149,7 @@ private[sql] object ArrowConverters { override def hasNext: Boolean = rowIter.hasNext || { closeReader() - if (payloadIter.hasNext) { + if (arrowStreamIter.hasNext) { rowIter = nextBatch() true } else { @@ -233,7 +168,7 @@ private[sql] object ArrowConverters { } private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) + val in = new ByteArrayReadableSeekableByteChannel(arrowStreamIter.next()) reader = new ArrowStreamReader(in, allocator) reader.loadNextBatch() // throws IOException val root = reader.getVectorSchemaRoot // throws IOException @@ -253,7 +188,7 @@ private[sql] object ArrowConverters { /** * Convert a byte array to an ArrowRecordBatch. 
*/ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayReadableSeekableByteChannel(batchBytes) @@ -262,12 +197,12 @@ private[sql] object ArrowConverters { } private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowStreamRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val rdd = arrowStreamRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromStreamIterator(iter, context) } val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 261df06100aef..e66add6de0d3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() + val arrowPayloads = indexData.toArrowBatches.collect() assert(arrowPayloads.nonEmpty) assert(arrowPayloads.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,7 +1153,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowPayload.collect() + val arrowPayloads = testData2.toArrowBatches.collect() // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload assert(arrowPayloads.length === 2) val schema = testData2.schema @@ -1168,20 +1168,20 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() + val arrowPayload = spark.emptyDataFrame.toArrowBatches.collect() assert(arrowPayload.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() + val filteredArrowPayload = filteredDF.filter("i < 0").toArrowBatches.collect() assert(filteredArrowPayload.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() + val arrowPayloads = emptyPart.toArrowBatches.collect() assert(arrowPayloads.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext 
with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() + val arrowPayloads = df.toArrowBatches.collect() assert(arrowPayloads.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowBatches.collect() } + runUnsupported { complexData.toArrowBatches.collect() } } test("test Arrow Validator") { @@ -1319,7 +1319,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("roundtrip payloads") { - val inputRows = (0 until 9).map { i => + /*val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) @@ -1341,22 +1341,22 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { count += 1 } - assert(count == inputRows.length) + assert(count == inputRows.length)*/ } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val batchBytes = df.coalesce(1).toArrowBatches.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, - arrowPayload: ArrowPayload, + batchBytes: Array[Byte], jsonFile: File, timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) @@ -1368,7 +1368,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) + val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator) vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) From 1b932463bca0815e79f3a8d61d1c816e62949698 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Mar 2018 16:14:06 -0800 Subject: [PATCH 05/30] toPandas and createDataFrame working but tests fail with date cols --- python/pyspark/sql/session.py | 30 +++++---- .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../spark/sql/api/python/PythonSQLUtils.scala | 15 ++++- .../sql/execution/arrow/ArrowConverters.scala | 63 ++++++++++++++++++- 4 files changed, 94 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 09bc13895b089..4e636dabd038d 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 
@@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowSerializer, FramedSerializer, _create_batch + from pyspark.serializers import ArrowSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,19 +539,23 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - class ArrowFramedSerializer(FramedSerializer): - - def dumps(self, batch): - import io - sink = io.BytesIO() - serializer = ArrowSerializer() - serializer.dump_stream([batch], sink) - return sink.getvalue() + import os + from tempfile import NamedTemporaryFile + serializer = ArrowSerializer() + temp_filenames = [] + try: + for batch in batches: + temp_file = NamedTemporaryFile(delete=False, dir=self.sparkContext._temp_dir) + temp_filenames.append(temp_file.name) + serializer.dump_stream([batch], temp_file) + temp_file.close() + jdf = self._jvm.PythonSQLUtils.arrowReadStreamFromFiles( + self._wrapped._jsqlContext, schema.json(), temp_filenames) + finally: + # arrowReadStreamFromFile eagerly reads the file so we can delete right after. + for temp_filename in temp_filenames: + os.unlink(temp_filename) - # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowFramedSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowStreamToDataFrame( - jrdd, schema.json(), self._wrapped._jsqlContext) df = DataFrame(jdf, self._wrapped) df._schema = schema return df diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 66242d8edfe90..02f0fd40f4adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3237,11 +3237,12 @@ class Dataset[T] private[sql]( * Collect a Dataset as Arrow batches and serve stream to PySpark. 
*/ private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => PythonRDD.serveToStream("serve-Arrow") { out => - val batchWriter = new ArrowBatchStreamWriter(schema, out) + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRDD = toArrowBatches(plan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 9ba4958a32f9a..236eb235d3c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.JavaRDD +import java.util.{ArrayList => JArrayList} + +import scala.collection.JavaConverters._ + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -47,4 +51,13 @@ private[sql] object PythonSQLUtils { sqlContext: SQLContext): DataFrame = { ArrowConverters.toDataFrame(arrowStreamRDD, schemaString, sqlContext) } + + def arrowReadStreamFromFiles( + sqlContext: SQLContext, + schemaString: String, + filenames: JArrayList[String]): DataFrame = { + JavaSparkContext.fromSparkContext(sqlContext.sparkContext) + val jrdd = ArrowConverters.readArrowStreamFromFiles(sqlContext, filenames.asScala.toArray) + arrowStreamToDataFrame(jrdd, schemaString, sqlContext) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 5ba2d803a568f..37470e7f48bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayOutputStream, DataOutputStream} import java.nio.channels.Channels +import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ import org.apache.arrow.memory.BufferAllocator @@ -35,9 +36,12 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, Columna import org.apache.spark.util.Utils -private[sql] class ArrowBatchStreamWriter(schema: StructType, out: DataOutputStream) { +private[sql] class ArrowBatchStreamWriter( + schema: StructType, + out: DataOutputStream, + timeZoneId: String) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, /*timeZoneId*/"") + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val writeChannel = new WriteChannel(Channels.newChannel(out)) MessageSerializer.serialize(writeChannel, arrowSchema) @@ -125,6 +129,53 @@ private[sql] object ArrowConverters { } } + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("fromStreamIterator", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, "***TODO***") + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) 
nextBatch() else Iterator.empty + + context.addTaskCompletionListener { _ => + allocator.close() + root.close() + } + + override def hasNext: Boolean = rowIter.hasNext || { + if (arrowBatchIter.hasNext) { + rowIter = nextBatch() + true + } else { + allocator.close() + root.close() + false + } + } + + override def next(): InternalRow = rowIter.next() + + private def nextBatch(): Iterator[InternalRow] = { + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + + val columns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector).asInstanceOf[ColumnVector] + }.toArray + + val batch = new ColumnarBatch(columns) + batch.setNumRows(root.getRowCount) + batch.rowIterator().asScala + } + } + } + /** * Maps Iterator from Arrow batches to InternalRow. Returns an ArrowRowIterator that can iterate * over record batch rows and has the schema from the first batch of Arrow data read. @@ -207,4 +258,12 @@ private[sql] object ArrowConverters { val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + private[sql] def readArrowStreamFromFiles(sqlContext: SQLContext, filenames: Array[String]): + JavaRDD[Array[Byte]] = { + val fileData = filenames.map { filename => + Files.readAllBytes(Paths.get(filename)) // throws IOException + } + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(fileData, filenames.length)) + } } From ce22d8ad18e052d150528752b727c6cfe11485f7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 26 Mar 2018 17:32:03 -0700 Subject: [PATCH 06/30] removed usage of seekableByteChannel --- .../spark/sql/execution/arrow/ArrowConverters.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 37470e7f48bb2..32c403bf73341 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.arrow -import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream} import java.nio.channels.Channels import java.nio.file.{Files, Paths} @@ -26,7 +26,6 @@ import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} @@ -219,7 +218,7 @@ private[sql] object ArrowConverters { } private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(arrowStreamIter.next()) + val in = new ByteArrayInputStream(arrowStreamIter.next()) reader = new ArrowStreamReader(in, allocator) reader.loadNextBatch() // throws IOException val root = reader.getVectorSchemaRoot // throws IOException @@ -242,8 +241,8 @@ private[sql] object ArrowConverters { private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new 
ByteArrayReadableSeekableByteChannel(batchBytes) - MessageSerializer.deserializeMessageBatch(new ReadChannel(in), allocator) + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) .asInstanceOf[ArrowRecordBatch] // throws IOException } From dede0bd96921c439747a9176f24c9ecbb9c8ce0a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 27 Mar 2018 17:28:54 -0700 Subject: [PATCH 07/30] for toPandas, set old collection result to null and add comments --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 02f0fd40f4adb..09606298c625d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3246,15 +3246,18 @@ class Dataset[T] private[sql]( val arrowBatchRDD = toArrowBatches(plan) + // Store collection results for worst case of 1 to N-1 partitions val results = new Array[Array[Array[Byte]]](arrowBatchRDD.partitions.size - 1) - var lastIndex = -1 + var lastIndex = -1 // index of last partition written def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order if (index - 1 == lastIndex) { batchWriter.writeBatches(arrowBatches.iterator) lastIndex += 1 while (lastIndex < results.length && results(lastIndex) != null) { batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null lastIndex += 1 } } else { From 9e29b092cb7d45fa486db0215c3bd4a99c5f8d98 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 28 Mar 2018 11:28:18 -0700 Subject: [PATCH 08/30] cleanup, not yet passing ArrowConvertersSuite --- .../apache/spark/api/python/PythonRDD.scala | 17 ++++- .../scala/org/apache/spark/sql/Dataset.scala | 22 +++--- .../sql/execution/arrow/ArrowConverters.scala | 16 ++-- .../arrow/ArrowConvertersSuite.scala | 76 +++++++++++++------ 4 files changed, 88 insertions(+), 43 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 3b66d565ee8f4..6bfefb2baf98b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -403,8 +403,21 @@ private[spark] object PythonRDD extends Logging { } } - // TODO: scaladoc - def serveToStream(threadName: String)(dataWriteBlock: DataOutputStream => Unit): Array[Any] = { + /** + * Create a socket server and background thread to execute the `dataWriteBlock` + * for the given DataOutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the `dataWriteBlock` and pass in + * the socket output stream. + * + * The thread will terminate after the `dataWriteBlock` is executed or any + * exceptions happen. 
+ */ + private[spark] def serveToStream( + threadName: String)(dataWriteBlock: DataOutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 09606298c625d..c936a9e044059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3240,35 +3240,37 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { out => val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) - - val arrowBatchRDD = toArrowBatches(plan) + val arrowBatchRdd = getArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](arrowBatchRDD.partitions.size - 1) + val results = new Array[Array[Array[Byte]]](numPartitions - 1) var lastIndex = -1 // index of last partition written + // Handler to eagerly write partitions to Python in order def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { // If result is from next partition in order if (index - 1 == lastIndex) { batchWriter.writeBatches(arrowBatches.iterator) lastIndex += 1 + // Write stored partitions that come next in order while (lastIndex < results.length && results(lastIndex) != null) { batchWriter.writeBatches(results(lastIndex).iterator) results(lastIndex) = null lastIndex += 1 } } else { + // Store partitions received out of order results(index - 1) = arrowBatches } } sparkSession.sparkContext.runJob( - arrowBatchRDD, + arrowBatchRdd, (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, - 0 until arrowBatchRDD.partitions.length, + 0 until numPartitions, handlePartitionBatches) } } @@ -3377,8 +3379,8 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowBatches(plan: SparkPlan): RDD[Array[Byte]] = { + /** Convert to an RDD of Arrow record batch byte arrays */ + private[sql] def getArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone @@ -3390,7 +3392,7 @@ class Dataset[T] private[sql]( } // This is only used in tests, for now. - private[sql] def toArrowBatches: RDD[Array[Byte]] = { - toArrowBatches(queryExecution.executedPlan) + private[sql] def getArrowBatchRdd: RDD[Array[Byte]] = { + getArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 32c403bf73341..e34c35dccbd60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -72,8 +72,8 @@ private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to Arrow batches. 
Limit ArrowRecordBatch size in a batch - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + * Maps Iterator from InternalRow to Arrow batches as byte arrays. Limit ArrowRecordBatch size + * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], @@ -128,14 +128,18 @@ private[sql] object ArrowConverters { } } + /** + * Maps iterator from Arrow batches as byte arrays to InternalRows. + */ private[sql] def fromBatchIterator( arrowBatchIter: Iterator[Array[Byte]], schema: StructType, + timeZoneId: String, context: TaskContext): Iterator[InternalRow] = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromStreamIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) - val arrowSchema = ArrowUtils.toArrowSchema(schema, "***TODO***") + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, allocator) new Iterator[InternalRow] { @@ -176,8 +180,8 @@ private[sql] object ArrowConverters { } /** - * Maps Iterator from Arrow batches to InternalRow. Returns an ArrowRowIterator that can iterate - * over record batch rows and has the schema from the first batch of Arrow data read. + * Maps Iterator from Arrow stream format to InternalRow. Returns an ArrowRowIterator that can + * iterate over record batch rows and has the schema from the first batch of Arrow data read. */ private[sql] def fromStreamIterator( arrowStreamIter: Iterator[Array[Byte]], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index e66add6de0d3e..48cf808b50566 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -30,8 +30,8 @@ import org.apache.arrow.vector.util.Validator import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowBatches.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val arrowBatches = indexData.getArrowBatchRdd.collect() + assert(arrowBatches.nonEmpty) + assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowBatches.collect() + val arrowBatches = 
testData2.getArrowBatchRdd.collect() // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) + assert(arrowBatches.length === 2) val schema = testData2.schema val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") @@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Files.write(json1, tempFile1, StandardCharsets.UTF_8) Files.write(json2, tempFile2, StandardCharsets.UTF_8) - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) + validateConversion(schema, arrowBatches(0), tempFile1) + validateConversion(schema, arrowBatches(1), tempFile2) } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowBatches.collect() + val arrowPayload = spark.emptyDataFrame.getArrowBatchRdd.collect() assert(arrowPayload.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowBatches.collect() + val filteredArrowPayload = filteredDF.filter("i < 0").getArrowBatchRdd.collect() assert(filteredArrowPayload.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowPayloads = emptyPart.toArrowBatches.collect() - assert(arrowPayloads.length === 1) + val arrowBatches = emptyPart.getArrowBatchRdd.collect() + assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowBatches.collect() - assert(arrowPayloads.length >= 4) + val arrowBatches = df.getArrowBatchRdd.collect() + assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(ArrowConverters.loadBatch(_, allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowBatches.collect() } - runUnsupported { complexData.toArrowBatches.collect() } + runUnsupported { mapData.toDF().getArrowBatchRdd.collect() } + runUnsupported { complexData.getArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1318,16 +1318,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { - /*val inputRows = (0 until 9).map { i => + test("roundtrip arrow batches") { + val inputRows = (0 until 9).map { i => + InternalRow(i) + } :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 0, null, ctx) + val outputRowIter = 
ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + /* + test("roundtrip arrow stream") { + val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) - val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + val streamIter = ArrowConverters.toStreamIterator(inputRows.toIterator, schema, 0, null, ctx) + val outputRowIter = ArrowConverters.fromStreamIterator(streamIter, ctx) assert(schema == outputRowIter.schema) @@ -1341,14 +1366,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { count += 1 } - assert(count == inputRows.length)*/ + assert(count == inputRows.length) } + */ /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val batchBytes = df.coalesce(1).toArrowBatches.collect().head + val batchBytes = df.coalesce(1).getArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) validateConversion(df.schema, batchBytes, tempFile, timeZoneId) From ceb8d38a6c83c3b6dae040c9e8d860811ecad0cc Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Mar 2018 14:14:03 -0700 Subject: [PATCH 09/30] fix to read Arrow stream with multiple batches, cleanup, add docs, scala tests pass, style pass --- .../apache/spark/api/python/PythonRDD.scala | 12 +-- .../scala/org/apache/spark/sql/Dataset.scala | 6 +- .../sql/execution/arrow/ArrowConverters.scala | 98 ++++++++++++++++--- .../arrow/ArrowConvertersSuite.scala | 57 ++++++++--- 4 files changed, 142 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 6bfefb2baf98b..4a0e72050cde4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -404,20 +404,20 @@ private[spark] object PythonRDD extends Logging { } /** - * Create a socket server and background thread to execute the `dataWriteBlock` + * Create a socket server and background thread to execute the block of code * for the given DataOutputStream. * * The socket server can only accept one connection, or close if no connection * in 15 seconds. * - * Once a connection comes in, it will execute the `dataWriteBlock` and pass in + * Once a connection comes in, it will execute the block of code and pass in * the socket output stream. * - * The thread will terminate after the `dataWriteBlock` is executed or any + * The thread will terminate after the block of code is executed or any * exceptions happen. 
*/ - private[spark] def serveToStream( - threadName: String)(dataWriteBlock: DataOutputStream => Unit): Array[Any] = { + private[spark] def serveToStream(threadName: String) + (block: DataOutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -431,7 +431,7 @@ private[spark] object PythonRDD extends Logging { val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { - dataWriteBlock(out) + block(out) } { out.close() sock.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c936a9e044059..4667a7ec9e997 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -24,7 +24,9 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils + import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD @@ -38,7 +40,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ @@ -46,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowBatchStreamWriter} +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index e34c35dccbd60..e20f1c5132242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -22,19 +22,23 @@ import java.nio.channels.Channels import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ + import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} + import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils - +/** + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format + */ private[sql] class ArrowBatchStreamWriter( schema: StructType, out: DataOutputStream, @@ -122,7 +126,6 @@ private[sql] object ArrowConverters { arrowWriter.reset() } - // TODO: ??? writeChannel.close() out.toByteArray } } @@ -146,8 +149,8 @@ private[sql] object ArrowConverters { private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty context.addTaskCompletionListener { _ => - allocator.close() root.close() + allocator.close() } override def hasNext: Boolean = rowIter.hasNext || { @@ -155,8 +158,8 @@ private[sql] object ArrowConverters { rowIter = nextBatch() true } else { - allocator.close() root.close() + allocator.close() false } } @@ -167,6 +170,7 @@ private[sql] object ArrowConverters { val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) val vectorLoader = new VectorLoader(root) vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] @@ -179,6 +183,63 @@ private[sql] object ArrowConverters { } } + /** + * Maps Iterator from InternalRow to Arrow stream format as a byte array. Each Arrow stream + * will have 1 ArrowRecordBatch. Limit ArrowRecordBatch size in a batch by setting + * maxRecordsPerBatch or use 0 to fully consume rowIter. Once this limit is reach, a new Arrow + * stream will be started. + */ + private[sql] def toStreamIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Int, + timeZoneId: String, + context: TaskContext): Iterator[Array[Byte]] = { + + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("toStreamIterator", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + context.addTaskCompletionListener { _ => + root.close() + allocator.close() + } + + new Iterator[Array[Byte]] { + + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + false + } + + override def next(): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writer = new ArrowStreamWriter(root, null, out) + writer.start() + + Utils.tryWithSafeFinally { + var rowCount = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + } { + writer.close() + arrowWriter.reset() + } + + out.toByteArray + } + } + } + /** * Maps Iterator from Arrow stream format to InternalRow. Returns an ArrowRowIterator that can * iterate over record batch rows and has the schema from the first batch of Arrow data read. 
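[Editor's note] For orientation, the round trip these converters provide — mirroring the "roundtrip arrow batches" test added later in this patch — looks roughly like the sketch below. It is illustrative only: the integer schema, the batch size of 5, the null time zone id, and TaskContext.empty() are taken from the test setup rather than production code, and the ArrowConverters methods are package-private to org.apache.spark.sql, so this would only compile in a similar package context.

    // Hedged sketch: rows -> serialized ArrowRecordBatches -> rows.
    import org.apache.spark.TaskContext
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val schema = StructType(Seq(StructField("i", IntegerType, nullable = true)))
    val rows = (0 until 10).map(i => InternalRow(i)).toIterator
    val ctx = TaskContext.empty()   // as in the test suite

    // Serialize to Arrow record batch bytes, at most 5 rows per batch;
    // the null timeZoneId mirrors the test, not production use.
    val batchIter = ArrowConverters.toBatchIterator(rows, schema, 5, null, ctx)

    // Read the serialized batches back into InternalRows.
    val rowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
    assert(rowIter.length == 10)
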
@@ -192,7 +253,7 @@ private[sql] object ArrowConverters { new ArrowRowIterator { private var reader: ArrowStreamReader = null private var schemaRead = StructType(Seq.empty) - private var rowIter = if (arrowStreamIter.hasNext) nextBatch() else Iterator.empty + private var rowIter = if (arrowStreamIter.hasNext) nextStream() else Iterator.empty context.addTaskCompletionListener { _ => closeReader() @@ -202,11 +263,16 @@ private[sql] object ArrowConverters { override def schema: StructType = schemaRead override def hasNext: Boolean = rowIter.hasNext || { - closeReader() - if (arrowStreamIter.hasNext) { - rowIter = nextBatch() + if (reader != null && reader.loadNextBatch()) { + rowIter = nextBatch(reader.getVectorSchemaRoot) + true + } + else if (arrowStreamIter.hasNext) { + closeReader() + rowIter = nextStream() true } else { + closeReader() allocator.close() false } @@ -221,17 +287,19 @@ private[sql] object ArrowConverters { } } - private def nextBatch(): Iterator[InternalRow] = { + private def nextStream(): Iterator[InternalRow] = { val in = new ByteArrayInputStream(arrowStreamIter.next()) reader = new ArrowStreamReader(in, allocator) reader.loadNextBatch() // throws IOException val root = reader.getVectorSchemaRoot // throws IOException schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + nextBatch(root) + } + private def nextBatch(root: VectorSchemaRoot): Iterator[InternalRow] = { val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala @@ -250,6 +318,9 @@ private[sql] object ArrowConverters { .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Convert a JavaRDD of Arrow streams as byte arrays to a DataFrame + */ private[sql] def toDataFrame( arrowStreamRDD: JavaRDD[Array[Byte]], schemaString: String, @@ -262,6 +333,9 @@ private[sql] object ArrowConverters { sqlContext.internalCreateDataFrame(rdd, schema) } + /** + * Read files entirely and parallelize into a JavaRDD with 1 partition per file + */ private[sql] def readArrowStreamFromFiles(sqlContext: SQLContext, filenames: Array[String]): JavaRDD[Array[Byte]] = { val fileData = filenames.map { filename => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 48cf808b50566..31325e89d09e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.arrow -import java.io.File +import java.io.{ByteArrayOutputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -30,8 +30,8 @@ import org.apache.arrow.vector.util.Validator import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -1154,7 +1154,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { """.stripMargin val arrowBatches = 
testData2.getArrowBatchRdd.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + // NOTE: testData2 should have 2 partitions -> 2 arrow batches assert(arrowBatches.length === 2) val schema = testData2.schema @@ -1168,12 +1168,12 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.getArrowBatchRdd.collect() - assert(arrowPayload.isEmpty) + val arrowBatches = spark.emptyDataFrame.getArrowBatchRdd.collect() + assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").getArrowBatchRdd.collect() - assert(filteredArrowPayload.isEmpty) + val filteredArrowBatches = filteredDF.filter("i < 0").getArrowBatchRdd.collect() + assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { @@ -1326,7 +1326,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 0, null, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) var count = 0 @@ -1342,7 +1342,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(count == inputRows.length) } - /* test("roundtrip arrow stream") { val inputRows = (0 until 9).map { i => InternalRow(i) @@ -1351,7 +1350,44 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val streamIter = ArrowConverters.toStreamIterator(inputRows.toIterator, schema, 0, null, ctx) + val streamIter = ArrowConverters.toStreamIterator(inputRows.toIterator, schema, 5, null, ctx) + val outputRowIter = ArrowConverters.fromStreamIterator(streamIter, ctx) + + assert(schema == outputRowIter.schema) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { + val inputRows = (0 until 9).map { i => + InternalRow(i) + } :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + + // Write batches to Arrow stream format as a byte array + val out = new ByteArrayOutputStream() + val dataOut = new DataOutputStream(out) + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.close() + out.close() + + // Convert Arrow stream format to Rows + val streamIter = Iterator(out.toByteArray) val outputRowIter = ArrowConverters.fromStreamIterator(streamIter, ctx) assert(schema == outputRowIter.schema) @@ -1368,7 +1404,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(count == inputRows.length) } - */ /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( From f42e4ea7b4fb944eeefd39a0fd6a1428b527214a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Mar 
2018 15:17:25 -0700 Subject: [PATCH 10/30] use base OutputStream for serveToStream instead of DataOutputStream --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 7 +++---- .../apache/spark/sql/execution/arrow/ArrowConverters.scala | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4a0e72050cde4..81e9eb75d5646 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -399,7 +399,7 @@ private[spark] object PythonRDD extends Logging { */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { serveToStream(threadName) { out => - writeIteratorToStream(items, out) + writeIteratorToStream(items, new DataOutputStream(out)) } } @@ -416,8 +416,7 @@ private[spark] object PythonRDD extends Logging { * The thread will terminate after the block of code is executed or any * exceptions happen. */ - private[spark] def serveToStream(threadName: String) - (block: DataOutputStream => Unit): Array[Any] = { + private[spark] def serveToStream(threadName: String)(block: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -429,7 +428,7 @@ private[spark] object PythonRDD extends Logging { val sock = serverSocket.accept() authHelper.authClient(sock) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { block(out) } { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index e20f1c5132242..72b0dd438710a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.arrow -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, OutputStream} import java.nio.channels.Channels import java.nio.file.{Files, Paths} @@ -41,7 +41,7 @@ import org.apache.spark.util.Utils */ private[sql] class ArrowBatchStreamWriter( schema: StructType, - out: DataOutputStream, + out: OutputStream, timeZoneId: String) { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) From 951843d760aa6d29ff18112e82d28f4f6dc09907 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Mar 2018 15:21:13 -0700 Subject: [PATCH 11/30] accidentally removed date type checks, passing pyspark tests --- python/pyspark/sql/dataframe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c80340f1ccc9d..44537b87047f5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2099,6 +2099,7 @@ def toPandas(self): if batch_iter: table = pyarrow.Table.from_batches(batch_iter) pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) From af03c6b384fe4ea73d67ad1d3f46be4a1e027e9e Mon Sep 17 00:00:00 
2001 From: Bryan Cutler Date: Mon, 11 Jun 2018 17:29:12 -0700 Subject: [PATCH 12/30] Changed to only use Arrow batch bytes as payload, had to hack Arrow MessageChannelReader --- python/pyspark/context.py | 11 +- python/pyspark/sql/session.py | 21 +- .../spark/sql/api/python/PythonSQLUtils.scala | 27 +- .../sql/execution/arrow/ArrowConverters.scala | 233 +++++++----------- .../arrow/ArrowConvertersSuite.scala | 37 +-- 5 files changed, 117 insertions(+), 212 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ede3b6af0a8cf..a8acf56178a8a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -493,10 +493,14 @@ def f(split, iterator): c = list(c) # Make it a list so we can compute its length batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - jrdd = self._serialize_to_jvm(c, numSlices, serializer) + + def reader_func(temp_filename): + return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, parallelism, serializer): + def _serialize_to_jvm(self, data, serializer, reader_func): """ Calling the Java parallelize() method with an ArrayList is too slow, because it sends O(n) Py4J commands. As an alternative, serialized @@ -506,8 +510,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer): try: serializer.dump_stream(data, tempFile) tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) + return reader_func(tempFile.name) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 4e636dabd038d..d69b5eac08151 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -539,23 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - import os - from tempfile import NamedTemporaryFile - serializer = ArrowSerializer() - temp_filenames = [] - try: - for batch in batches: - temp_file = NamedTemporaryFile(delete=False, dir=self.sparkContext._temp_dir) - temp_filenames.append(temp_file.name) - serializer.dump_stream([batch], temp_file) - temp_file.close() - jdf = self._jvm.PythonSQLUtils.arrowReadStreamFromFiles( - self._wrapped._jsqlContext, schema.json(), temp_filenames) - finally: - # arrowReadStreamFromFile eagerly reads the file so we can delete right after. 
- for temp_filename in temp_filenames: - os.unlink(temp_filename) + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( + self._wrapped._jsqlContext, temp_filename, schema.json()) + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jdf = self._sc._serialize_to_jvm(batches, ArrowSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 236eb235d3c52..a530c4dcc15e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.api.python -import java.util.{ArrayList => JArrayList} - -import scala.collection.JavaConverters._ - import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry @@ -38,10 +34,10 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of Arrow record batches into a [[DataFrame]]. * - * @param arrowStreamRDD A JavaRDD of Arrow data in stream protocol. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowStreamRDD A JavaRDD of Arrow record batches as byte arrays. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ @@ -52,12 +48,21 @@ private[sql] object PythonSQLUtils { ArrowConverters.toDataFrame(arrowStreamRDD, schemaString, sqlContext) } - def arrowReadStreamFromFiles( + /** + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each batch as a partition. + * + * @param sqlContext The active [[SQLContext]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. 
+ */ + def arrowReadStreamFromFile( sqlContext: SQLContext, - schemaString: String, - filenames: JArrayList[String]): DataFrame = { + filename: String, + schemaString: String): DataFrame = { JavaSparkContext.fromSparkContext(sqlContext.sparkContext) - val jrdd = ArrowConverters.readArrowStreamFromFiles(sqlContext, filenames.asScala.toArray) + val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) arrowStreamToDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 72b0dd438710a..80d0da53bcb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution.arrow -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, OutputStream} -import java.nio.channels.Channels -import java.nio.file.{Files, Paths} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import org.apache.arrow.flatbuf.{Message, MessageHeader} import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext @@ -36,6 +38,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils + /** * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format */ @@ -60,19 +63,6 @@ private[sql] class ArrowBatchStreamWriter( } } - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { - - /** - * Return the schema loaded from the Arrow record batch being iterated over - */ - def schema: StructType -} - - private[sql] object ArrowConverters { /** @@ -183,130 +173,6 @@ private[sql] object ArrowConverters { } } - /** - * Maps Iterator from InternalRow to Arrow stream format as a byte array. Each Arrow stream - * will have 1 ArrowRecordBatch. Limit ArrowRecordBatch size in a batch by setting - * maxRecordsPerBatch or use 0 to fully consume rowIter. Once this limit is reach, a new Arrow - * stream will be started. 
- */ - private[sql] def toStreamIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - context: TaskContext): Iterator[Array[Byte]] = { - - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toStreamIterator", 0, Long.MaxValue) - - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - new Iterator[Array[Byte]] { - - override def hasNext: Boolean = rowIter.hasNext || { - root.close() - allocator.close() - false - } - - override def next(): Array[Byte] = { - val out = new ByteArrayOutputStream() - val writer = new ArrowStreamWriter(root, null, out) - writer.start() - - Utils.tryWithSafeFinally { - var rowCount = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) - rowCount += 1 - } - arrowWriter.finish() - writer.writeBatch() - } { - writer.close() - arrowWriter.reset() - } - - out.toByteArray - } - } - } - - /** - * Maps Iterator from Arrow stream format to InternalRow. Returns an ArrowRowIterator that can - * iterate over record batch rows and has the schema from the first batch of Arrow data read. - */ - private[sql] def fromStreamIterator( - arrowStreamIter: Iterator[Array[Byte]], - context: TaskContext): ArrowRowIterator = { - val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromStreamIterator", 0, Long.MaxValue) - - new ArrowRowIterator { - private var reader: ArrowStreamReader = null - private var schemaRead = StructType(Seq.empty) - private var rowIter = if (arrowStreamIter.hasNext) nextStream() else Iterator.empty - - context.addTaskCompletionListener { _ => - closeReader() - allocator.close() - } - - override def schema: StructType = schemaRead - - override def hasNext: Boolean = rowIter.hasNext || { - if (reader != null && reader.loadNextBatch()) { - rowIter = nextBatch(reader.getVectorSchemaRoot) - true - } - else if (arrowStreamIter.hasNext) { - closeReader() - rowIter = nextStream() - true - } else { - closeReader() - allocator.close() - false - } - } - - override def next(): InternalRow = rowIter.next() - - private def closeReader(): Unit = { - if (reader != null) { - reader.close() - reader = null - } - } - - private def nextStream(): Iterator[InternalRow] = { - val in = new ByteArrayInputStream(arrowStreamIter.next()) - reader = new ArrowStreamReader(in, allocator) - reader.loadNextBatch() // throws IOException - val root = reader.getVectorSchemaRoot // throws IOException - schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) - nextBatch(root) - } - - private def nextBatch(root: VectorSchemaRoot): Iterator[InternalRow] = { - val columns = root.getFieldVectors.asScala.map { vector => - new ArrowColumnVector(vector).asInstanceOf[ColumnVector] - }.toArray - val batch = new ColumnarBatch(columns) - batch.setNumRows(root.getRowCount) - batch.rowIterator().asScala - } - } - } - /** * Convert a byte array to an ArrowRecordBatch. 
*/ @@ -319,28 +185,97 @@ private[sql] object ArrowConverters { } /** - * Convert a JavaRDD of Arrow streams as byte arrays to a DataFrame + * Create a DataFrame from a JavaRDD of Arrow record batches */ private[sql] def toDataFrame( arrowStreamRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone val rdd = arrowStreamRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromStreamIterator(iter, context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } /** - * Read files entirely and parallelize into a JavaRDD with 1 partition per file + * Read a file as an Arrow stream and create an RDD from record batches */ - private[sql] def readArrowStreamFromFiles(sqlContext: SQLContext, filenames: Array[String]): + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { - val fileData = filenames.map { filename => - Files.readAllBytes(Paths.get(filename)) // throws IOException + val fileStream = new FileInputStream(filename) + try { + val batches = getBatchesFromStream(fileStream.getChannel) + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } finally { + fileStream.close() } - JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(fileData, filenames.length)) + } + + /** + * Read input of an Arrow stream and return all record batches read as byte arrays + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Array[Array[Byte]] = { + + // TODO: simplify in super class + class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { + // TODO: need ReadChannel to be protected + // extends MessageChannelReader(new ReadChannel(fileChannel)) { + val in = new ReadChannel(inputChannel) + private val batches = new ArrayBuffer[Array[Byte]] + + def getRecordBatchBytes() = batches.toArray + + def readNextMessage(): Message = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { + return null + } + val messageLength = MessageSerializer.bytesToInt(buffer.array()) + if (messageLength == 0) { + return null + } + + loadMessageBuffer(readMessageBuffer(messageLength), messageLength) + } + + protected def readMessageBuffer(messageLength: Int): ByteBuffer = { + // Read the message size. There is an i32 little endian prefix. 
+ val buffer = ByteBuffer.allocate(messageLength) + if (in.readFully(buffer) != messageLength) { + throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + buffer + } + + // Load a Message and read RecordBatch, storing it in an array + protected def loadMessageBuffer(buffer: ByteBuffer, messageLength: Int): Message = { + val msg = Message.getRootAsMessage(buffer) + val bodyLength = msg.bodyLength().asInstanceOf[Int] + + if (msg.headerType() == MessageHeader.RecordBatch) { + val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) + allbuf.put(WriteChannel.intToBytes(messageLength)) + allbuf.put(buffer) + in.readFully(allbuf) + batches.append(allbuf.array()) + } else if (bodyLength > 0) { + // Skip message body if not a record batch + inputChannel.position(inputChannel.position() + bodyLength) + } + + msg + } + } + + // Read the input stream and store all record batches in an array + val msgReader = new RecordBatchMessageReader(in) + while (msgReader.readNextMessage() != null) {} + msgReader.getRecordBatchBytes() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 31325e89d09e4..dd7b0816fff44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader -import org.apache.arrow.vector.util.Validator +import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} @@ -1342,32 +1342,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(count == inputRows.length) } - test("roundtrip arrow stream") { - val inputRows = (0 until 9).map { i => - InternalRow(i) - } :+ InternalRow(null) - - val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) - - val ctx = TaskContext.empty() - val streamIter = ArrowConverters.toStreamIterator(inputRows.toIterator, schema, 5, null, ctx) - val outputRowIter = ArrowConverters.fromStreamIterator(streamIter, ctx) - - assert(schema == outputRowIter.schema) - - var count = 0 - outputRowIter.zipWithIndex.foreach { case (row, i) => - if (i != 9) { - assert(row.getInt(0) == i) - } else { - assert(row.isNullAt(0)) - } - count += 1 - } - - assert(count == inputRows.length) - } - test("ArrowBatchStreamWriter roundtrip") { val inputRows = (0 until 9).map { i => InternalRow(i) @@ -1386,11 +1360,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { writer.close() out.close() - // Convert Arrow stream format to Rows - val streamIter = Iterator(out.toByteArray) - val outputRowIter = ArrowConverters.fromStreamIterator(streamIter, ctx) - - assert(schema == outputRowIter.schema) + // Read Arrow stream into batches, then convert back to rows + val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) + val readBatches = ArrowConverters.getBatchesFromStream(in) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches.toIterator, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => From 
b047c1624429ea579aa279e92909b90400b40c58 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Jun 2018 17:30:48 -0700 Subject: [PATCH 13/30] added todo comment --- .../org/apache/spark/sql/execution/arrow/ArrowConverters.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 80d0da53bcb34..5f5786ec4106d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -59,6 +59,7 @@ private[sql] class ArrowBatchStreamWriter( def close(): Unit = { // Write End of Stream + // TODO: this could be a static function in ArrowStreamWriter writeChannel.writeIntLittleEndian(0) } } From a77b89ea0357e3ce146ff35537eb7da8a8c80bad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Jun 2018 10:35:05 -0700 Subject: [PATCH 14/30] change getBatchesFromStream to return iterator --- .../sql/execution/arrow/ArrowConverters.scala | 43 +++++++++++++++---- .../arrow/ArrowConvertersSuite.scala | 2 +- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 5f5786ec4106d..a78c9fdd25163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -208,7 +208,8 @@ private[sql] object ArrowConverters { JavaRDD[Array[Byte]] = { val fileStream = new FileInputStream(filename) try { - val batches = getBatchesFromStream(fileStream.getChannel) + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray // Parallelize the record batches to create an RDD JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) } finally { @@ -219,18 +220,20 @@ private[sql] object ArrowConverters { /** * Read input of an Arrow stream and return all record batches read as byte arrays */ - private[sql] def getBatchesFromStream(in: SeekableByteChannel): Array[Array[Byte]] = { + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { // TODO: simplify in super class class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { // TODO: need ReadChannel to be protected // extends MessageChannelReader(new ReadChannel(fileChannel)) { val in = new ReadChannel(inputChannel) - private val batches = new ArrayBuffer[Array[Byte]] + //private val batches = new ArrayBuffer[Array[Byte]] + private var lastBatch: Array[Byte] = null - def getRecordBatchBytes() = batches.toArray + def getLastBatch() = lastBatch def readNextMessage(): Message = { + lastBatch = null val buffer = ByteBuffer.allocate(4) if (in.readFully(buffer) != 4) { return null @@ -264,7 +267,7 @@ private[sql] object ArrowConverters { allbuf.put(WriteChannel.intToBytes(messageLength)) allbuf.put(buffer) in.readFully(allbuf) - batches.append(allbuf.array()) + lastBatch = allbuf.array() } else if (bodyLength > 0) { // Skip message body if not a record batch inputChannel.position(inputChannel.position() + bodyLength) @@ -274,9 +277,31 @@ private[sql] object ArrowConverters { } } - // Read the input stream and store all record batches in an array - val msgReader = new 
RecordBatchMessageReader(in) - while (msgReader.readNextMessage() != null) {} - msgReader.getRecordBatchBytes() + new Iterator[Array[Byte]] { + + // Read the input stream and store the next batch read + val msgReader = new RecordBatchMessageReader(in) + var batch: Array[Byte] = null + readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + readNextBatch() + prevBatch + } + + def readNextBatch(): Unit = { + var stop = false + while (!stop) { + val msg = msgReader.readNextMessage() + batch = msgReader.getLastBatch() + if (msg == null || (msg != null && batch != null)) { + stop = true + } + } + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index dd7b0816fff44..d8a979cf6e7b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1363,7 +1363,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { // Read Arrow stream into batches, then convert back to rows val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) val readBatches = ArrowConverters.getBatchesFromStream(in) - val outputRowIter = ArrowConverters.fromBatchIterator(readBatches.toIterator, schema, null, ctx) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => From 81c82093edf78d36a1de850b3d8faede88fb0524 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Jun 2018 11:06:22 -0700 Subject: [PATCH 15/30] need to end stream on toPandas after all batches sent to python, and added some comments --- .../scala/org/apache/spark/sql/Dataset.scala | 4 +++ .../spark/sql/api/python/PythonSQLUtils.scala | 5 ++-- .../sql/execution/arrow/ArrowConverters.scala | 25 ++++++++++++------- .../arrow/ArrowConvertersSuite.scala | 2 +- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4667a7ec9e997..2cf7fbf0adc76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3263,6 +3263,10 @@ class Dataset[T] private[sql]( results(lastIndex) = null lastIndex += 1 } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + } } else { // Store partitions received out of order results(index - 1) = arrowBatches diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index a530c4dcc15e6..910cd4fc29346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -34,7 +34,8 @@ private[sql] object PythonSQLUtils { } /** - * Python callable function to convert an RDD of Arrow record batches into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * * @param arrowStreamRDD A JavaRDD of Arrow record batches as byte arrays. * @param schemaString JSON Formatted Spark schema for Arrow batches. 
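[Editor's note] For context on the end-of-stream change in this patch: the server side of toPandas frames its output with ArrowBatchStreamWriter, writing the Arrow schema once at construction, then each partition's serialized batches in order, then an explicit end-of-stream marker via end(). A minimal sketch follows; the ByteArrayOutputStream, schema, and empty batch iterator are placeholder stand-ins for the socket stream and the Dataset's actual data.

    // Hedged sketch of the Arrow stream framing used when serving toPandas.
    import java.io.ByteArrayOutputStream
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    val schema = StructType(Seq(StructField("i", IntegerType, nullable = true)))
    val out = new ByteArrayOutputStream()                        // stand-in for the socket output stream
    val batches: Iterator[Array[Byte]] = Iterator.empty          // serialized ArrowRecordBatches (placeholder)

    val writer = new ArrowBatchStreamWriter(schema, out, "UTC")  // constructor writes the Arrow schema message
    writer.writeBatches(batches)                                 // append batches in partition order
    writer.end()                                                 // write the end-of-stream marker; does not close `out`
    out.close()
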
@@ -50,7 +51,7 @@ private[sql] object PythonSQLUtils { /** * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] - * using each batch as a partition. + * using each serialized ArrowRecordBatch as a partition. * * @param sqlContext The active [[SQLContext]]. * @param filename File to read the Arrow stream from. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index a78c9fdd25163..fa5600793bd45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -22,7 +22,6 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.arrow.flatbuf.{Message, MessageHeader} import org.apache.arrow.memory.BufferAllocator @@ -40,7 +39,7 @@ import org.apache.spark.util.Utils /** - * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ private[sql] class ArrowBatchStreamWriter( schema: StructType, @@ -49,15 +48,23 @@ private[sql] class ArrowBatchStreamWriter( val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches MessageSerializer.serialize(writeChannel, arrowSchema) + /** + * Consume iterator to write each serialized ArrowRecordBatch to the stream. + */ def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { arrowBatchIter.foreach { batchBytes => writeChannel.write(batchBytes) } } - def close(): Unit = { + /** + * End the Arrow stream, does not close output stream. + */ + def end(): Unit = { // Write End of Stream // TODO: this could be a static function in ArrowStreamWriter writeChannel.writeIntLittleEndian(0) @@ -67,7 +74,7 @@ private[sql] class ArrowBatchStreamWriter( private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to Arrow batches as byte arrays. Limit ArrowRecordBatch size + * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ private[sql] def toBatchIterator( @@ -123,7 +130,7 @@ private[sql] object ArrowConverters { } /** - * Maps iterator from Arrow batches as byte arrays to InternalRows. + * Maps iterator from serialized ArrowRecordBatches to InternalRows. */ private[sql] def fromBatchIterator( arrowBatchIter: Iterator[Array[Byte]], @@ -175,7 +182,7 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ private[arrow] def loadBatch( batchBytes: Array[Byte], @@ -186,7 +193,7 @@ private[sql] object ArrowConverters { } /** - * Create a DataFrame from a JavaRDD of Arrow record batches + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. */ private[sql] def toDataFrame( arrowStreamRDD: JavaRDD[Array[Byte]], @@ -202,7 +209,7 @@ private[sql] object ArrowConverters { } /** - * Read a file as an Arrow stream and create an RDD from record batches + * Read a file as an Arrow stream and create an RDD of serialized ArrowRecordBatches. 
*/ private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { @@ -218,7 +225,7 @@ private[sql] object ArrowConverters { } /** - * Read input of an Arrow stream and return all record batches read as byte arrays + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. */ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index d8a979cf6e7b0..80c852a08541d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1357,7 +1357,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val dataOut = new DataOutputStream(out) val writer = new ArrowBatchStreamWriter(schema, dataOut, null) writer.writeBatches(batchIter) - writer.close() + writer.end() out.close() // Read Arrow stream into batches, then convert back to rows From 5f46a02aa34e9f51ac310d27e3272b883c67cc37 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Jun 2018 11:11:19 -0700 Subject: [PATCH 16/30] forgot to remove old comment --- .../org/apache/spark/sql/execution/arrow/ArrowConverters.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index fa5600793bd45..0d9c79a992193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -234,7 +234,6 @@ private[sql] object ArrowConverters { // TODO: need ReadChannel to be protected // extends MessageChannelReader(new ReadChannel(fileChannel)) { val in = new ReadChannel(inputChannel) - //private val batches = new ArrayBuffer[Array[Byte]] private var lastBatch: Array[Byte] = null def getLastBatch() = lastBatch From 7694b8fe6970789d4e88f9c5df46c97a4b235f02 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Jun 2018 11:47:49 -0700 Subject: [PATCH 17/30] fixed up comments --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../apache/spark/sql/execution/arrow/ArrowConverters.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2cf7fbf0adc76..c5b65b9bd52d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3385,7 +3385,7 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of Arrow record batch byte arrays */ + /** Convert to an RDD of serialized ArrowRecordBatches. 
*/ private[sql] def getArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 0d9c79a992193..9157703525495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -263,7 +263,7 @@ private[sql] object ArrowConverters { buffer } - // Load a Message and read RecordBatch, storing it in an array + // Load a Message, if it is a RecordBatch then read body and store as serialized bytes protected def loadMessageBuffer(buffer: ByteBuffer, messageLength: Int): Message = { val msg = Message.getRootAsMessage(buffer) val bodyLength = msg.bodyLength().asInstanceOf[Int] @@ -283,9 +283,8 @@ private[sql] object ArrowConverters { } } + // Create an iterator to get each serialized ArrowRecordBatch in an stream new Iterator[Array[Byte]] { - - // Read the input stream and store the next batch read val msgReader = new RecordBatchMessageReader(in) var batch: Array[Byte] = null readNextBatch() From a5a1fbe7121c5b0dd93876a56c29ad17dcd9b168 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Jun 2018 13:22:30 -0700 Subject: [PATCH 18/30] fixed up some wording --- .../org/apache/spark/sql/api/python/PythonSQLUtils.scala | 6 +++--- .../apache/spark/sql/execution/arrow/ArrowConverters.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 910cd4fc29346..ac32abdeb61f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -37,16 +37,16 @@ private[sql] object PythonSQLUtils { * Python callable function to convert an RDD of serialized ArrowRecordBatches into * a [[DataFrame]]. * - * @param arrowStreamRDD A JavaRDD of Arrow record batches as byte arrays. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ def arrowStreamToDataFrame( - arrowStreamRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(arrowStreamRDD, schemaString, sqlContext) + ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 9157703525495..b1a8269d43942 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -196,12 +196,12 @@ private[sql] object ArrowConverters { * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
*/ private[sql] def toDataFrame( - arrowStreamRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone - val rdd = arrowStreamRDD.rdd.mapPartitions { iter => + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } @@ -209,7 +209,7 @@ private[sql] object ArrowConverters { } /** - * Read a file as an Arrow stream and create an RDD of serialized ArrowRecordBatches. + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. */ private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { From 555605a01edc9cfbf15f6523e32326edb2debd0d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Jun 2018 10:03:17 -0700 Subject: [PATCH 19/30] Updated MessageChannelReader to reflect Arrow changes --- .../spark/sql/execution/arrow/ArrowConverters.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index b1a8269d43942..5a3a99036a895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -249,23 +249,20 @@ private[sql] object ArrowConverters { return null } - loadMessageBuffer(readMessageBuffer(messageLength), messageLength) + loadMessageOverride(messageLength, ByteBuffer.allocate(messageLength)) } - protected def readMessageBuffer(messageLength: Int): ByteBuffer = { - // Read the message size. There is an i32 little endian prefix. 
- val buffer = ByteBuffer.allocate(messageLength) + protected def loadMessage(messageLength: Int, buffer: ByteBuffer): Message = { if (in.readFully(buffer) != messageLength) { throw new java.io.IOException( "Unexpected end of stream trying to read message.") } buffer.rewind() - buffer + Message.getRootAsMessage(buffer) } - // Load a Message, if it is a RecordBatch then read body and store as serialized bytes - protected def loadMessageBuffer(buffer: ByteBuffer, messageLength: Int): Message = { - val msg = Message.getRootAsMessage(buffer) + protected def loadMessageOverride(messageLength: Int, buffer: ByteBuffer): Message = { + val msg = loadMessage(messageLength, buffer) val bodyLength = msg.bodyLength().asInstanceOf[Int] if (msg.headerType() == MessageHeader.RecordBatch) { From 54d697986f8f4be7b09e4a0ad5ae25dd9ec24352 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Jun 2018 10:39:26 -0700 Subject: [PATCH 20/30] move arrowStreamToDataFrame to arrowReadStreamFromFile as not being called from python, close resource in test, cleanup --- .../spark/sql/api/python/PythonSQLUtils.scala | 20 +------------------ .../sql/execution/arrow/ArrowConverters.scala | 7 ++++--- .../arrow/ArrowConvertersSuite.scala | 10 +++++----- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index ac32abdeb61f2..c0830e77b5a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -33,22 +32,6 @@ private[sql] object PythonSQLUtils { FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray } - /** - * Python callable function to convert an RDD of serialized ArrowRecordBatches into - * a [[DataFrame]]. - * - * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. - * @param schemaString JSON Formatted Spark schema for Arrow batches. - * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. - */ - def arrowStreamToDataFrame( - arrowBatchRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) - } - /** * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] * using each serialized ArrowRecordBatch as a partition. 
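For context, the file consumed by arrowReadStreamFromFile is a plain Arrow IPC stream (a schema message followed by record batch messages), and each batch read back becomes one partition of the resulting DataFrame. A minimal pyarrow sketch of producing such a file; the path and data are illustrative only, not taken from the patch:

    import pyarrow as pa

    # Write one record batch in the Arrow stream format to a hypothetical temp file
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['i'])
    with open('/tmp/arrow_stream.bin', 'wb') as sink:
        writer = pa.RecordBatchStreamWriter(sink, batch.schema)
        writer.write_batch(batch)
        writer.close()
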
@@ -62,8 +45,7 @@ private[sql] object PythonSQLUtils { sqlContext: SQLContext, filename: String, schemaString: String): DataFrame = { - JavaSparkContext.fromSparkContext(sqlContext.sparkContext) val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) - arrowStreamToDataFrame(jrdd, schemaString, sqlContext) + ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 5a3a99036a895..5ce12b979e031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -209,10 +209,11 @@ private[sql] object ArrowConverters { } /** - * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. */ - private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): - JavaRDD[Array[Byte]] = { + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { val fileStream = new FileInputStream(filename) try { // Create array so that we can safely close the file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 80c852a08541d..657f0d93bac66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1354,11 +1354,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { // Write batches to Arrow stream format as a byte array val out = new ByteArrayOutputStream() - val dataOut = new DataOutputStream(out) - val writer = new ArrowBatchStreamWriter(schema, dataOut, null) - writer.writeBatches(batchIter) - writer.end() - out.close() + Utils.tryWithResource(new DataOutputStream(out)) { dataOut => + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.end() + } // Read Arrow stream into batches, then convert back to rows val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) From 4af58f9539ea12c8c309790001efe497d18f0129 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Jun 2018 10:41:39 -0700 Subject: [PATCH 21/30] rename ArrowSerializer to ArrowStreamSerializer --- python/pyspark/serializers.py | 4 ++-- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/session.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 5e98574eeb462..0d2b1d93baa74 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -184,7 +184,7 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(Serializer): +class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. 
""" @@ -208,7 +208,7 @@ def load_stream(self, stream): yield batch def __repr__(self): - return "ArrowSerializer" + return "ArrowStreamSerializer" def _create_batch(series, timezone): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 44537b87047f5..136f9a41e03e3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -2153,7 +2153,7 @@ def _collectAsArrow(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowSerializer())) + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d69b5eac08151..6d5080a845b41 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. """ - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowStreamSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -544,7 +544,7 @@ def reader_func(temp_filename): self._wrapped._jsqlContext, temp_filename, schema.json()) # Create Spark DataFrame from Arrow stream file, using one batch per partition - jdf = self._sc._serialize_to_jvm(batches, ArrowSerializer(), reader_func) + jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df From c6d24f2579780ebe63653af069cd0db942afc2ed Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 22 Jun 2018 14:57:00 -0700 Subject: [PATCH 22/30] try using static utilty functions instead of MessageChannelReader subclass --- .../sql/execution/arrow/ArrowConverters.scala | 98 ++++++++----------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 5ce12b979e031..db8d26cb16044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -230,79 +230,61 @@ private[sql] object ArrowConverters { */ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { - // TODO: simplify in super class - class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { - // TODO: need ReadChannel to be protected - // extends MessageChannelReader(new ReadChannel(fileChannel)) { - val in = new ReadChannel(inputChannel) - private var lastBatch: Array[Byte] 
= null - - def getLastBatch() = lastBatch - - def readNextMessage(): Message = { - lastBatch = null - val buffer = ByteBuffer.allocate(4) - if (in.readFully(buffer) != 4) { - return null - } - val messageLength = MessageSerializer.bytesToInt(buffer.array()) - if (messageLength == 0) { - return null - } - - loadMessageOverride(messageLength, ByteBuffer.allocate(messageLength)) + // TODO: this could be moved to Arrow + def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { + return 0 } + MessageSerializer.bytesToInt(buffer.array()) + } - protected def loadMessage(messageLength: Int, buffer: ByteBuffer): Message = { - if (in.readFully(buffer) != messageLength) { - throw new java.io.IOException( - "Unexpected end of stream trying to read message.") - } - buffer.rewind() - Message.getRootAsMessage(buffer) - } - - protected def loadMessageOverride(messageLength: Int, buffer: ByteBuffer): Message = { - val msg = loadMessage(messageLength, buffer) - val bodyLength = msg.bodyLength().asInstanceOf[Int] - - if (msg.headerType() == MessageHeader.RecordBatch) { - val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) - allbuf.put(WriteChannel.intToBytes(messageLength)) - allbuf.put(buffer) - in.readFully(allbuf) - lastBatch = allbuf.array() - } else if (bodyLength > 0) { - // Skip message body if not a record batch - inputChannel.position(inputChannel.position() + bodyLength) - } - - msg + // TODO: this could be moved to Arrow + def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { + throw new java.io.IOException( + "Unexpected end of stream trying to read message.") } + buffer.rewind() + Message.getRootAsMessage(buffer) } - // Create an iterator to get each serialized ArrowRecordBatch in an stream + + // Create an iterator to get each serialized ArrowRecordBatch from a stream new Iterator[Array[Byte]] { - val msgReader = new RecordBatchMessageReader(in) - var batch: Array[Byte] = null - readNextBatch() + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() override def hasNext: Boolean = batch != null override def next(): Array[Byte] = { val prevBatch = batch - readNextBatch() + batch = readNextBatch() prevBatch } - def readNextBatch(): Unit = { - var stop = false - while (!stop) { - val msg = msgReader.readNextMessage() - batch = msgReader.getLastBatch() - if (msg == null || (msg != null && batch != null)) { - stop = true + def readNextBatch(): Array[Byte] = { + val messageLength = readMessageLength(inputChannel) + if (messageLength == 0) { + return null + } + + val buffer = ByteBuffer.allocate(messageLength) + val msg = loadMessage(inputChannel, messageLength, buffer) + val bodyLength = msg.bodyLength().asInstanceOf[Int] + + if (msg.headerType() == MessageHeader.RecordBatch) { + val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) + allbuf.put(WriteChannel.intToBytes(messageLength)) + allbuf.put(buffer) + inputChannel.readFully(allbuf) + allbuf.array() + } else { + if (bodyLength > 0) { + // Skip message body if not a record batch + in.position(in.position() + bodyLength) } + readNextBatch() } } } From b971e42ee2973ba72a7668a39e1bf4c3de919289 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 22 Jun 2018 15:04:15 -0700 Subject: [PATCH 23/30] fixed wording of _collectAsArrow --- python/pyspark/sql/dataframe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 136f9a41e03e3..1d5b3cc91a2de 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2095,9 +2095,9 @@ def toPandas(self): _check_dataframe_localize_timestamps import pyarrow - batch_iter = self._collectAsArrow() - if batch_iter: - table = pyarrow.Table.from_batches(batch_iter) + batches = self._collectAsArrow() + if batches: + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2146,8 +2146,8 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ From 876c066bfb35296ca4c3961e598582eb04e3aacd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Jul 2018 17:38:47 -0700 Subject: [PATCH 24/30] forgot to inline test data --- .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 657f0d93bac66..e3d9e8312a67d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1343,12 +1343,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("ArrowBatchStreamWriter roundtrip") { - val inputRows = (0 until 9).map { i => - InternalRow(i) - } :+ InternalRow(null) + val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null) val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) - val ctx = TaskContext.empty() val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) From a25248e6d5fb04bdaddfc185801ff4439f9d1654 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 27 Jun 2018 13:40:19 -0700 Subject: [PATCH 25/30] changed toPandas to send out of order batches, followed by batch order indices fixed batch indexing when multiple batches per partition change batch order to tuple of ints added assert and reset to BatchOrderSerializer remove StopIteration from BatchOrderSerializer --- python/pyspark/serializers.py | 35 ++++++++++++++++ python/pyspark/sql/dataframe.py | 17 ++++---- python/pyspark/sql/tests.py | 6 +++ .../scala/org/apache/spark/sql/Dataset.scala | 42 +++++++++---------- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 0d2b1d93baa74..61234f672a528 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -184,6 +184,41 @@ def loads(self, obj): raise NotImplementedError +class BatchOrderSerializer(Serializer): + """ + Deserialize a stream of batches followed by batch order information. 
+ """ + + def __init__(self, serializer): + self.serializer = serializer + self.batch_order = None + + def dump_stream(self, iterator, stream): + return self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + for batch in self.serializer.load_stream(stream): + yield batch + num = read_int(stream) + self.batch_order = [] + for i in xrange(num): + index = read_int(stream) + self.batch_order.append(index) + + def get_batch_order_and_reset(self): + """ + Returns a list of indices to put batches read from load_stream in the correct order. + This must be called after load_stream and will clear the batch order after calling. + """ + assert self.batch_order is not None, "Must call load_stream first to read batch order" + batch_order = self.batch_order + self.batch_order = None + return batch_order + + def __repr__(self): + return "BatchOrderSerializer(%s)" % self.serializer + + class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1d5b3cc91a2de..ba521f326e3be 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,8 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, BatchOrderSerializer, \ + PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -2095,9 +2095,11 @@ def toPandas(self): _check_dataframe_localize_timestamps import pyarrow - batches = self._collectAsArrow() + # Collect un-ordered list of batches, and list of correct order indices + batches, batch_order = self._collectAsArrow() if batches: - table = pyarrow.Table.from_batches(batches) + # Re-order the batch list with correct order to build a table + table = pyarrow.Table.from_batches([batches[i] for i in batch_order]) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2146,14 +2148,15 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as a list of ArrowRecordBatches, pyarrow must be installed - and available on driver and worker Python environments. + Returns all records as a list of ArrowRecordBatches and batch order as a list of indices, + pyarrow must be installed and available on driver and worker Python environments. .. note:: Experimental. 
""" + ser = BatchOrderSerializer(ArrowStreamSerializer()) with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowStreamSerializer())) + return list(_load_from_socket(sock_info, ser)), ser.get_batch_order_and_reset() ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 487eb19c3b98a..3bf506bee96a2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4018,6 +4018,12 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_python.toPandas()) self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + def test_toPandas_batch_order(self): + df = self.spark.range(64, numPartitions=8).toDF("a") + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 4}): + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf, pdf_arrow) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c5b65b9bd52d1..48fe17bca9d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql -import java.io.CharArrayWriter +import java.io.{CharArrayWriter, DataOutputStream} import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -3242,34 +3243,33 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { out => + PythonRDD.serveToStream("serve-Arrow") { outputStream => + val out = new DataOutputStream(outputStream) val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRdd = getArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Store collection results for worst case of 1 to N-1 partitions - val results = new Array[Array[Array[Byte]]](numPartitions - 1) - var lastIndex = -1 // index of last partition written + // Batches ordered by (index of partition, batch # in partition) tuple + val batchOrder = new ArrayBuffer[(Int, Int)]() + var partitionCount = 0 - // Handler to eagerly write partitions to Python in order + // Handler to eagerly write batches to Python out of order def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { - // If result is from next partition in order - if (index - 1 == lastIndex) { + if (arrowBatches.nonEmpty) { batchWriter.writeBatches(arrowBatches.iterator) - lastIndex += 1 - // Write stored partitions that come next in order - while (lastIndex < results.length && results(lastIndex) != null) { - batchWriter.writeBatches(results(lastIndex).iterator) - results(lastIndex) = null - lastIndex += 1 - } - // After last batch, end the stream - if (lastIndex == results.length) { - batchWriter.end() + arrowBatches.indices.foreach { i => batchOrder.append((index, i)) } + } + partitionCount += 1 + + // After last batch, end the stream and write batch order + if (partitionCount == numPartitions) { + batchWriter.end() + out.writeInt(batchOrder.length) + 
// Batch order indices are from 0 to N-1 batches, sorted by order they arrived + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => + out.writeInt(i) } - } else { - // Store partitions received out of order - results(index - 1) = arrowBatches + out.flush() } } From daa907470c6b8f4c7ec2606baf90f2e817c60e6c Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 21 Aug 2018 14:19:40 -0700 Subject: [PATCH 26/30] cleanup from review --- .../scala/org/apache/spark/api/python/PythonRDD.scala | 9 +++++---- .../spark/sql/execution/arrow/ArrowConverters.scala | 7 ++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 81e9eb75d5646..e5dd9617d55cc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -404,8 +404,8 @@ private[spark] object PythonRDD extends Logging { } /** - * Create a socket server and background thread to execute the block of code - * for the given DataOutputStream. + * Create a socket server and background thread to execute the writeFunc + * with the given OutputStream. * * The socket server can only accept one connection, or close if no connection * in 15 seconds. @@ -416,7 +416,8 @@ private[spark] object PythonRDD extends Logging { * The thread will terminate after the block of code is executed or any * exceptions happen. */ - private[spark] def serveToStream(threadName: String)(block: OutputStream => Unit): Array[Any] = { + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -430,7 +431,7 @@ private[spark] object PythonRDD extends Logging { val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { - block(out) + writeFunc(out) } { out.close() sock.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index db8d26cb16044..e4e7ca9dfa980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -56,9 +56,7 @@ private[sql] class ArrowBatchStreamWriter( * Consume iterator to write each serialized ArrowRecordBatch to the stream. 
*/ def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { - arrowBatchIter.foreach { batchBytes => - writeChannel.write(batchBytes) - } + arrowBatchIter.foreach(writeChannel.write) } /** @@ -249,7 +247,6 @@ private[sql] object ArrowConverters { Message.getRootAsMessage(buffer) } - // Create an iterator to get each serialized ArrowRecordBatch from a stream new Iterator[Array[Byte]] { val inputChannel = new ReadChannel(in) @@ -271,7 +268,7 @@ private[sql] object ArrowConverters { val buffer = ByteBuffer.allocate(messageLength) val msg = loadMessage(inputChannel, messageLength, buffer) - val bodyLength = msg.bodyLength().asInstanceOf[Int] + val bodyLength = msg.bodyLength().toInt if (msg.headerType() == MessageHeader.RecordBatch) { val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) From 66b59c76e1fbefeabd139f3c9a2651e6d6d4805d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 21 Aug 2018 14:21:39 -0700 Subject: [PATCH 27/30] change naming per requests --- python/pyspark/sql/dataframe.py | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 8 ++++---- .../execution/arrow/ArrowConvertersSuite.scala | 18 +++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ba521f326e3be..ceea0c45b0202 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2097,7 +2097,7 @@ def toPandas(self): # Collect un-ordered list of batches, and list of correct order indices batches, batch_order = self._collectAsArrow() - if batches: + if len(batches) > 0: # Re-order the batch list with correct order to build a table table = pyarrow.Table.from_batches([batches[i] for i in batch_order]) pdf = table.to_pandas() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 48fe17bca9d93..c2143590fa952 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3246,7 +3246,7 @@ class Dataset[T] private[sql]( PythonRDD.serveToStream("serve-Arrow") { outputStream => val out = new DataOutputStream(outputStream) val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) - val arrowBatchRdd = getArrowBatchRdd(plan) + val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length // Batches ordered by (index of partition, batch # in partition) tuple @@ -3386,7 +3386,7 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of serialized ArrowRecordBatches. */ - private[sql] def getArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone @@ -3398,7 +3398,7 @@ class Dataset[T] private[sql]( } // This is only used in tests, for now. 
- private[sql] def getArrowBatchRdd: RDD[Array[Byte]] = { - getArrowBatchRdd(queryExecution.executedPlan) + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { + toArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index e3d9e8312a67d..c36872a6a5289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -51,7 +51,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowBatches = indexData.getArrowBatchRdd.collect() + val arrowBatches = indexData.toArrowBatchRdd.collect() assert(arrowBatches.nonEmpty) assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) @@ -1153,7 +1153,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowBatches = testData2.getArrowBatchRdd.collect() + val arrowBatches = testData2.toArrowBatchRdd.collect() // NOTE: testData2 should have 2 partitions -> 2 arrow batches assert(arrowBatches.length === 2) val schema = testData2.schema @@ -1168,17 +1168,17 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("empty frame collect") { - val arrowBatches = spark.emptyDataFrame.getArrowBatchRdd.collect() + val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect() assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowBatches = filteredDF.filter("i < 0").getArrowBatchRdd.collect() + val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect() assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") - val arrowBatches = emptyPart.getArrowBatchRdd.collect() + val arrowBatches = emptyPart.toArrowBatchRdd.collect() assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) @@ -1192,7 +1192,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowBatches = df.getArrowBatchRdd.collect() + val arrowBatches = df.toArrowBatchRdd.collect() assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().getArrowBatchRdd.collect() } - runUnsupported { complexData.getArrowBatchRdd.collect() } + runUnsupported { mapData.toDF().toArrowBatchRdd.collect() } + runUnsupported { complexData.toArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1379,7 +1379,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = 
null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val batchBytes = df.coalesce(1).getArrowBatchRdd.collect().head + val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) validateConversion(df.schema, batchBytes, tempFile, timeZoneId) From ed248f92badbf6e35baabc0fdaa49705403a7996 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 21 Aug 2018 14:30:22 -0700 Subject: [PATCH 28/30] Revert "changed toPandas to send out of order batches, followed by batch order indices" This reverts commit a25248e6d5fb04bdaddfc185801ff4439f9d1654. --- python/pyspark/serializers.py | 35 ---------------- python/pyspark/sql/dataframe.py | 18 ++++---- python/pyspark/sql/tests.py | 6 --- .../scala/org/apache/spark/sql/Dataset.scala | 42 +++++++++---------- 4 files changed, 28 insertions(+), 73 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 61234f672a528..0d2b1d93baa74 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -184,41 +184,6 @@ def loads(self, obj): raise NotImplementedError -class BatchOrderSerializer(Serializer): - """ - Deserialize a stream of batches followed by batch order information. - """ - - def __init__(self, serializer): - self.serializer = serializer - self.batch_order = None - - def dump_stream(self, iterator, stream): - return self.serializer.dump_stream(iterator, stream) - - def load_stream(self, stream): - for batch in self.serializer.load_stream(stream): - yield batch - num = read_int(stream) - self.batch_order = [] - for i in xrange(num): - index = read_int(stream) - self.batch_order.append(index) - - def get_batch_order_and_reset(self): - """ - Returns a list of indices to put batches read from load_stream in the correct order. - This must be called after load_stream and will clear the batch order after calling. - """ - assert self.batch_order is not None, "Must call load_stream first to read batch order" - batch_order = self.batch_order - self.batch_order = None - return batch_order - - def __repr__(self): - return "BatchOrderSerializer(%s)" % self.serializer - - class ArrowStreamSerializer(Serializer): """ Serializes Arrow record batches as a stream. 
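The serializer removed in the hunk above existed so the driver could accept Arrow batches in whatever order the partitions finished and restore the logical order afterwards from a trailing list of indices. The reordering step itself is plain index selection; a small pyarrow sketch with made-up batches, not code from the patch:

    import pyarrow as pa

    # Suppose logical batches 2, 0, 1 arrived in that order
    arrived = [pa.RecordBatch.from_arrays([pa.array([v])], ['i']) for v in (2, 0, 1)]
    batch_order = [1, 2, 0]   # arrival position of each logical batch

    # Rebuild the table in logical order, as the driver-side reorder did
    table = pa.Table.from_batches([arrived[i] for i in batch_order])
    assert table.to_pydict() == {'i': [0, 1, 2]}

With this revert the series falls back to the earlier approach of writing partitions to Python in order on the JVM side.
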
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ceea0c45b0202..13c805680a295 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,8 +29,8 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, BatchOrderSerializer, \ - PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -2094,12 +2094,9 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - - # Collect un-ordered list of batches, and list of correct order indices - batches, batch_order = self._collectAsArrow() + batches = self._collectAsArrow() if len(batches) > 0: - # Re-order the batch list with correct order to build a table - table = pyarrow.Table.from_batches([batches[i] for i in batch_order]) + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2148,15 +2145,14 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as a list of ArrowRecordBatches and batch order as a list of indices, - pyarrow must be installed and available on driver and worker Python environments. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. 
""" - ser = BatchOrderSerializer(ArrowStreamSerializer()) with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ser)), ser.get_batch_order_and_reset() + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3bf506bee96a2..487eb19c3b98a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4018,12 +4018,6 @@ def test_timestamp_dst(self): self.assertPandasEqual(pdf, df_from_python.toPandas()) self.assertPandasEqual(pdf, df_from_pandas.toPandas()) - def test_toPandas_batch_order(self): - df = self.spark.range(64, numPartitions=8).toDF("a") - with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 4}): - pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf, pdf_arrow) - @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c2143590fa952..24a84e35aa06b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import java.io.{CharArrayWriter, DataOutputStream} +import java.io.CharArrayWriter import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -3243,33 +3242,34 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => - PythonRDD.serveToStream("serve-Arrow") { outputStream => - val out = new DataOutputStream(outputStream) + PythonRDD.serveToStream("serve-Arrow") { out => val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length - // Batches ordered by (index of partition, batch # in partition) tuple - val batchOrder = new ArrayBuffer[(Int, Int)]() - var partitionCount = 0 + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written - // Handler to eagerly write batches to Python out of order + // Handler to eagerly write partitions to Python in order def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { - if (arrowBatches.nonEmpty) { + // If result is from next partition in order + if (index - 1 == lastIndex) { batchWriter.writeBatches(arrowBatches.iterator) - arrowBatches.indices.foreach { i => batchOrder.append((index, i)) } - } - partitionCount += 1 - - // After last batch, end the stream and write batch order - if (partitionCount == numPartitions) { - batchWriter.end() - out.writeInt(batchOrder.length) - // Batch order indices are from 0 to N-1 batches, sorted by order they arrived - batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => - out.writeInt(i) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + 
batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() } - out.flush() + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches } } From 92b8e2669b96e76c0fa712438f0a7b8f52324dea Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 21 Aug 2018 15:56:47 -0700 Subject: [PATCH 29/30] cleanup after Java Arrow 0.10.0, fixup and simplify getBatchesFromStream --- .../sql/execution/arrow/ArrowConverters.scala | 74 +++++++++---------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 8712c0845b097..b1731633d5222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -18,24 +18,24 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} -import java.nio.ByteBuffer import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ -import org.apache.arrow.flatbuf.{Message, MessageHeader} +import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** @@ -63,9 +63,7 @@ private[sql] class ArrowBatchStreamWriter( * End the Arrow stream, does not close output stream. 
*/ def end(): Unit = { - // Write End of Stream - // TODO: this could be a static function in ArrowStreamWriter - writeChannel.writeIntLittleEndian(0) + ArrowStreamWriter.writeEndOfStream(writeChannel) } } @@ -186,8 +184,8 @@ private[sql] object ArrowConverters { batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayInputStream(batchBytes) - MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) - .asInstanceOf[ArrowRecordBatch] // throws IOException + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } /** @@ -228,28 +226,8 @@ private[sql] object ArrowConverters { */ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { - // TODO: this could be moved to Arrow - def readMessageLength(in: ReadChannel): Int = { - val buffer = ByteBuffer.allocate(4) - if (in.readFully(buffer) != 4) { - return 0 - } - MessageSerializer.bytesToInt(buffer.array()) - } - - // TODO: this could be moved to Arrow - def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { - if (in.readFully(buffer) != messageLength) { - throw new java.io.IOException( - "Unexpected end of stream trying to read message.") - } - buffer.rewind() - Message.getRootAsMessage(buffer) - } - // Create an iterator to get each serialized ArrowRecordBatch from a stream new Iterator[Array[Byte]] { - val inputChannel = new ReadChannel(in) var batch: Array[Byte] = readNextBatch() override def hasNext: Boolean = batch != null @@ -261,26 +239,42 @@ private[sql] object ArrowConverters { } def readNextBatch(): Array[Byte] = { - val messageLength = readMessageLength(inputChannel) - if (messageLength == 0) { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { return null } - val buffer = ByteBuffer.allocate(messageLength) - val msg = loadMessage(inputChannel, messageLength, buffer) - val bodyLength = msg.bodyLength().toInt + // Get the length of the body, which has not be read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt - if (msg.headerType() == MessageHeader.RecordBatch) { - val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) - allbuf.put(WriteChannel.intToBytes(messageLength)) - allbuf.put(buffer) - inputChannel.readFully(allbuf) - allbuf.array() + // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Create output backed by buffer to hold msg length (int32), msg metadata, msg body + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) + + // Write message metadata to buffer output stream + MessageSerializer.writeMessageBuffer( + new WriteChannel(Channels.newChannel(bbout)), + msgMetadata.getMessageLength, + msgMetadata.getMessageBuffer) + + // Get a zero-copy ByteBuffer with metadata already written + bbout.close() + val bb = bbout.toByteBuffer + bb.position(bbout.getCount()) + + // Read message body directly into the ByteBuffer to avoid copy, return backed byte array + bb.limit(bb.capacity()) + JavaUtils.readFully(in, bb) + bb.array() } else { if (bodyLength > 0) { // Skip message body if not a record batch in.position(in.position() + bodyLength) } + + // Proceed to next message readNextBatch() } } From 554964465dbcb99cc313620fafb0fc41acfd4304 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 22 Aug 2018 
21:18:09 -0700 Subject: [PATCH 30/30] used tryWithResource, improved comments --- .../sql/execution/arrow/ArrowConverters.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index b1731633d5222..4aea1fa6f9d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -210,14 +210,11 @@ private[sql] object ArrowConverters { private[sql] def readArrowStreamFromFile( sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = { - val fileStream = new FileInputStream(filename) - try { - // Create array so that we can safely close the file + Utils.tryWithResource(new FileInputStream(filename)) { fileStream => + // Create array to consume iterator so that we can safely close the file val batches = getBatchesFromStream(fileStream.getChannel).toArray // Parallelize the record batches to create an RDD JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) - } finally { - fileStream.close() } } @@ -226,7 +223,7 @@ private[sql] object ArrowConverters { */ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { - // Create an iterator to get each serialized ArrowRecordBatch from a stream + // Iterate over the serialized Arrow RecordBatch messages from a stream new Iterator[Array[Byte]] { var batch: Array[Byte] = readNextBatch() @@ -238,28 +235,31 @@ private[sql] object ArrowConverters { prevBatch } + // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it + // is a RecordBatch message and then returning the complete serialized message which consists + // of a int32 length, serialized message metadata and a serialized RecordBatch message body def readNextBatch(): Array[Byte] = { val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) if (msgMetadata == null) { return null } - // Get the length of the body, which has not be read at this point + // Get the length of the body, which has not been read at this point val bodyLength = msgMetadata.getMessageBodyLength.toInt - // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages + // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { - // Create output backed by buffer to hold msg length (int32), msg metadata, msg body + // Buffer backed output large enough to hold the complete serialized message val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) - // Write message metadata to buffer output stream + // Write message metadata to ByteBuffer output stream MessageSerializer.writeMessageBuffer( new WriteChannel(Channels.newChannel(bbout)), msgMetadata.getMessageLength, msgMetadata.getMessageBuffer) - // Get a zero-copy ByteBuffer with metadata already written + // Get a zero-copy ByteBuffer with already contains message metadata, must close first bbout.close() val bb = bbout.toByteBuffer bb.position(bbout.getCount()) @@ -270,7 +270,7 @@ private[sql] object ArrowConverters { bb.array() } else { if (bodyLength > 0) { - // Skip message body if not a record batch + // Skip message body if not a RecordBatch in.position(in.position() + bodyLength) }
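
Taken together, the series leaves toPandas() consuming a single Arrow stream over the local socket: the JVM writes each partition's serialized batches in partition order, and the driver reads them back with ArrowStreamSerializer and rebuilds a table. A compact sketch of that read path, using an in-memory buffer in place of the socket and pa.open_stream as the patch does (newer pyarrow spells it pa.ipc.open_stream):

    import io
    import pyarrow as pa

    # Stand-in for the socket: an Arrow stream holding two partitions' batches
    buf = io.BytesIO()
    writer = None
    for start in (0, 3):
        batch = pa.RecordBatch.from_arrays([pa.array([start, start + 1, start + 2])], ['i'])
        if writer is None:
            writer = pa.RecordBatchStreamWriter(buf, batch.schema)
        writer.write_batch(batch)
    writer.close()

    # What _collectAsArrow() plus toPandas() do with the stream
    batches = list(pa.open_stream(io.BytesIO(buf.getvalue())))
    pdf = pa.Table.from_batches(batches).to_pandas()
    assert pdf['i'].tolist() == [0, 1, 2, 3, 4, 5]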