diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 197f4643e6134..e639a842754bd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -399,6 +399,26 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { + serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) + } + } + + /** + * Create a socket server and background thread to execute the writeFunc + * with the given OutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or any + * exceptions happen. + */ + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -410,9 +430,9 @@ private[spark] object PythonRDD extends Logging { val sock = serverSocket.accept() authHelper.authClient(sock) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + val out = new BufferedOutputStream(sock.getOutputStream) Utils.tryWithSafeFinally { - writeIteratorToStream(items, out) + writeFunc(out) } { out.close() sock.close() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b77fa0ee2892b..4cabae4b2f50b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -494,10 +494,14 @@ def f(split, iterator): c = list(c) # Make it a list so we can compute its length batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - jrdd = self._serialize_to_jvm(c, numSlices, serializer) + + def reader_func(temp_filename): + return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices) + + jrdd = self._serialize_to_jvm(c, serializer, reader_func) return RDD(jrdd, self, serializer) - def _serialize_to_jvm(self, data, parallelism, serializer): + def _serialize_to_jvm(self, data, serializer, reader_func): """ Calling the Java parallelize() method with an ArrayList is too slow, because it sends O(n) Py4J commands. As an alternative, serialized @@ -507,8 +511,7 @@ def _serialize_to_jvm(self, data, parallelism, serializer): try: serializer.dump_stream(data, tempFile) tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - return readRDDFromFile(self._jsc, tempFile.name, parallelism) + return reader_func(tempFile.name) finally: # readRDDFromFile eagerily reads the file so we can delete right after. os.unlink(tempFile.name) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 10385589c4d3b..48006778e86f2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -185,27 +185,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): """ - Serializes bytes as Arrow data with the Arrow file format. + Serializes Arrow record batches as a stream. 
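The `serveIterator` change above factors the socket handling out into a reusable `serveToStream(threadName)(writeFunc)`: bind an ephemeral localhost port, accept a single authenticated connection (or give up after 15 seconds), hand the buffered output stream to `writeFunc`, then tear everything down. A minimal Python sketch of that serving pattern, with the secret-based authentication and the returned `[port, secret]` pair left out (the function name here is illustrative, not part of the PR):

```python
import socket
import threading

def serve_to_stream(write_func):
    """Illustrative analogue of PythonRDD.serveToStream: bind an ephemeral
    localhost port, return it immediately, and serve exactly one connection
    on a background thread. The connection must arrive within 15 seconds,
    and the thread exits once write_func finishes or raises."""
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))     # port 0 -> ephemeral port, like ServerSocket(0, 1, ...)
    server.listen(1)                  # accept a single connection only
    server.settimeout(15)             # close if no connection in 15 seconds

    def serve():
        try:
            conn, _ = server.accept()
            try:
                with conn.makefile("wb") as out:
                    write_func(out)   # e.g. lambda out: serializer.dump_stream(data, out)
            finally:
                conn.close()
        finally:
            server.close()

    threading.Thread(target=serve, daemon=True).start()
    return server.getsockname()[1]    # the client connects to this port
```

With this in place, `serveIterator` becomes a thin wrapper that passes a write function around `writeIteratorToStream`, which is exactly the shape of the Scala change above.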
""" - def dumps(self, batch): + def dump_stream(self, iterator, stream): import pyarrow as pa - import io - sink = io.BytesIO() - writer = pa.RecordBatchFileWriter(sink, batch.schema) - writer.write_batch(batch) - writer.close() - return sink.getvalue() + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): import pyarrow as pa - reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) - return reader.read_all() + reader = pa.open_stream(stream) + for batch in reader: + yield batch def __repr__(self): - return "ArrowSerializer" + return "ArrowStreamSerializer" def _create_batch(series, timezone): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 07fb260a77ea0..1affc9b4fcf6c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,7 @@ from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ +from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -2118,10 +2118,9 @@ def toPandas(self): from pyspark.sql.types import _check_dataframe_convert_date, \ _check_dataframe_localize_timestamps import pyarrow - - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) + batches = self._collectAsArrow() + if len(batches) > 0: + table = pyarrow.Table.from_batches(batches) pdf = table.to_pandas() pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) @@ -2170,14 +2169,14 @@ def toPandas(self): def _collectAsArrow(self): """ - Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed - and available. + Returns all records as a list of ArrowRecordBatches, pyarrow must be installed + and available on driver and worker Python environments. .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(sock_info, ArrowSerializer())) + return list(_load_from_socket(sock_info, ArrowStreamSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 19eea2fd2c775..87d8d6a59a6e9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -501,7 +501,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the data types will be used to coerce the data in Pandas to Arrow conversion. 
""" - from pyspark.serializers import ArrowSerializer, _create_batch + from pyspark.serializers import ArrowStreamSerializer, _create_batch from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType from pyspark.sql.utils import require_minimum_pandas_version, \ require_minimum_pyarrow_version @@ -539,10 +539,12 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): struct.names[i] = name schema = struct - # Create the Spark DataFrame directly from the Arrow data and schema - jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer()) - jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame( - jrdd, schema.json(), self._wrapped._jsqlContext) + def reader_func(temp_filename): + return self._jvm.PythonSQLUtils.arrowReadStreamFromFile( + self._wrapped._jsqlContext, temp_filename, schema.json()) + + # Create Spark DataFrame from Arrow stream file, using one batch per partition + jdf = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func) df = DataFrame(jdf, self._wrapped) df._schema = schema return df diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 367b98563e0bf..db439b1ee76f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.python.EvaluatePython @@ -3273,13 +3273,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. 
*/ private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to Python in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } } } @@ -3386,20 +3422,20 @@ class Dataset[T] private[sql]( } } - /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { + /** Convert to an RDD of serialized ArrowRecordBatches. */ + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator( + ArrowConverters.toBatchIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } // This is only used in tests, for now. - private[sql] def toArrowPayload: RDD[ArrowPayload] = { - toArrowPayload(queryExecution.executedPlan) + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { + toArrowBatchRdd(queryExecution.executedPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index b33760b1edbc6..c0830e77b5a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.api.python -import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.ExpressionInfo @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. 
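The notable part of the new `collectAsArrowToPython` is ordering: `runJob` hands partition results to the handler in task-completion order, but the Arrow stream has to be written in partition order, so partitions that finish early are parked in an array sized for the worst case (all but one partition early) and flushed as soon as their turn comes. A plain-Python rendering of that handler logic (function and variable names are illustrative):

```python
def write_partitions_in_order(arriving, write_batches, end_stream, num_partitions):
    """Sketch of the result handler in collectAsArrowToPython: `arriving`
    yields (partition_index, serialized_batches) pairs in whatever order the
    tasks complete, while the output must be written in partition order."""
    buffered = [None] * (num_partitions - 1)  # worst case: partitions 1..N-1 arrive early
    last_written = -1                         # index of the last partition written

    for index, batches in arriving:
        if index - 1 == last_written:
            write_batches(batches)            # next partition in order: write it through
            last_written += 1
            # Flush any buffered partitions that now come next in order.
            while last_written < len(buffered) and buffered[last_written] is not None:
                write_batches(buffered[last_written])
                buffered[last_written] = None
                last_written += 1
            if last_written == len(buffered):
                end_stream()                  # all partitions written: end the Arrow stream
        else:
            buffered[index - 1] = batches     # out of order: hold until predecessors are written
```

Writing eagerly like this lets the driver start streaming results to Python as soon as the first partitions finish, instead of first collecting every payload the way the old `serveIterator` path did.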
+ * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { - ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { + val jrdd = ArrowConverters.readArrowStreamFromFile(sqlContext, filename) + ArrowConverters.toDataFrame(jrdd, schemaString, sqlContext) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 6a5ac2413d73c..1a48bc8398a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,73 +17,75 @@ package org.apache.spark.sql.execution.arrow -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.nio.channels.{Channels, SeekableByteChannel} import scala.collection.JavaConverters._ +import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter} -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD +import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( + schema: StructType, + out: OutputStream, + timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { - ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. 
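The new `ArrowBatchStreamWriter` (its `writeBatches` and `end` methods follow just below) relies on the streaming format being nothing more than a schema message, followed by RecordBatch messages, followed by an end-of-stream marker, so already-serialized batches can be copied through without being deserialized. A hypothetical Python analogue of that interface; the 4-byte zero end-of-stream marker matches the IPC framing of the Arrow version this PR builds against, and later format revisions changed the framing, so treat that constant as version-specific:

```python
class BatchStreamWriter(object):
    """Illustrative counterpart of ArrowBatchStreamWriter: the schema message
    (a pyarrow.Schema here) is written up front, write_batches() copies
    already-serialized RecordBatch messages straight through, and end()
    terminates the stream without closing the underlying output."""

    def __init__(self, schema, out):
        self._out = out
        self._out.write(schema.serialize())   # schema message goes first

    def write_batches(self, serialized_batches):
        for raw in serialized_batches:        # bytes of one RecordBatch message each
            self._out.write(raw)

    def end(self):
        # Zero-length message = end of stream (framing is IPC-version-specific).
        self._out.write(b"\x00\x00\x00\x00")
```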
+ * Consume iterator to write each serialized ArrowRecordBatch to the stream. */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { + arrowBatchIter.foreach(writeChannel.write) + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { + ArrowStreamWriter.writeEndOfStream(writeChannel) + } } private[sql] object ArrowConverters { /** - * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload - * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. + * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size + * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[sql] def toPayloadIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext): Iterator[ArrowPayload] = { + context: TaskContext): Iterator[Array[Byte]] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = - ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) + val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) context.addTaskCompletionListener[Unit] { _ => @@ -91,7 +93,7 @@ private[sql] object ArrowConverters { allocator.close() } - new Iterator[ArrowPayload] { + new Iterator[Array[Byte]] { override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -99,9 +101,9 @@ private[sql] object ArrowConverters { false } - override def next(): ArrowPayload = { + override def next(): Array[Byte] = { val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) Utils.tryWithSafeFinally { var rowCount = 0 @@ -111,45 +113,46 @@ private[sql] object ArrowConverters { rowCount += 1 } arrowWriter.finish() - writer.writeBatch() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() } { arrowWriter.reset() - writer.close() } - new ArrowPayload(out.toByteArray) + out.toByteArray } } } /** - * Maps Iterator from ArrowPayload to InternalRow. Returns a pair containing the row iterator - * and the schema from the first batch of Arrow data read. + * Maps iterator from serialized ArrowRecordBatches to InternalRows. 
*/ - private[sql] def fromPayloadIterator( - payloadIter: Iterator[ArrowPayload], - context: TaskContext): ArrowRowIterator = { + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) + + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) - new ArrowRowIterator { - private var reader: ArrowFileReader = null - private var schemaRead = StructType(Seq.empty) - private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty context.addTaskCompletionListener[Unit] { _ => - closeReader() + root.close() allocator.close() } - override def schema: StructType = schemaRead - override def hasNext: Boolean = rowIter.hasNext || { - closeReader() - if (payloadIter.hasNext) { + if (arrowBatchIter.hasNext) { rowIter = nextBatch() true } else { + root.close() allocator.close() false } @@ -157,19 +160,11 @@ private[sql] object ArrowConverters { override def next(): InternalRow = rowIter.next() - private def closeReader(): Unit = { - if (reader != null) { - reader.close() - reader = null - } - } - private def nextBatch(): Iterator[InternalRow] = { - val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) - reader = new ArrowFileReader(in, allocator) - reader.loadNextBatch() // throws IOException - val root = reader.getVectorSchemaRoot // throws IOException - schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() val columns = root.getFieldVectors.asScala.map { vector => new ArrowColumnVector(vector).asInstanceOf[ColumnVector] @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
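Because each element is now a bare RecordBatch message with no schema header, `fromBatchIterator` takes the Spark schema explicitly and loads every batch into one reused `VectorSchemaRoot`; the old `ArrowRowIterator` trait, which existed only to expose the schema read back from the file header, disappears. A rough pyarrow rendering of the read side, assuming `pa.ipc.read_record_batch` accepts the serialized message bytes plus a known schema (names are illustrative):

```python
import pyarrow as pa

def from_batch_iterator(serialized_batches, schema):
    """Sketch of ArrowConverters.fromBatchIterator: each input element is the
    bytes of one RecordBatch message (no schema, no end-of-stream marker), so
    the schema must be supplied by the caller; rows come back as plain tuples."""
    for raw in serialized_batches:
        batch = pa.ipc.read_record_batch(raw, schema)
        for row in zip(*[column.to_pylist() for column in batch.columns]):
            yield row

# Round trip with the chunking sketch above:
#   list(from_batch_iterator(serialized, schema))  ->  [(0,), (1,), ..., (9,)]
```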
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { + Utils.tryWithResource(new FileInputStream(filename)) { fileStream => + // Create array to consume iterator so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + + // Iterate over the serialized Arrow RecordBatch messages from a stream + new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + batch = readNextBatch() + prevBatch + } + + // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it + // is a RecordBatch message and then returning the complete serialized message which consists + // of a int32 length, serialized message metadata and a serialized RecordBatch message body + def readNextBatch(): Array[Byte] = { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { + return null + } + + // Get the length of the body, which has not been read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt + + // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Buffer backed output large enough to hold the complete serialized message + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) + + // Write message metadata to ByteBuffer output stream + MessageSerializer.writeMessageBuffer( + new WriteChannel(Channels.newChannel(bbout)), + msgMetadata.getMessageLength, + msgMetadata.getMessageBuffer) + + // Get a zero-copy ByteBuffer with already contains message metadata, must close first + bbout.close() + val bb = bbout.toByteBuffer + bb.position(bbout.getCount()) + + // Read message body directly into the ByteBuffer to avoid copy, return backed byte array + bb.limit(bb.capacity()) + JavaUtils.readFully(in, bb) + bb.array() + } else { + if (bodyLength > 0) { + // Skip message body if not a RecordBatch + in.position(in.position() + bodyLength) + } + + // Proceed to next message + readNextBatch() + } + } + } + } } diff --git 
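`readArrowStreamFromFile` splits one Arrow stream file into the serialized bytes of its individual RecordBatch messages and parallelizes them one per partition, and `getBatchesFromStream` does that splitting at the message level, copying each RecordBatch message (length prefix, metadata, body) verbatim while skipping the schema and any unsupported dictionary messages. The same end result can be sketched with pyarrow's high-level reader, at the cost of deserializing and re-serializing each batch rather than the zero-copy message scan the Scala code performs (names are illustrative):

```python
import pyarrow as pa

def get_batches_from_stream(source):
    """Sketch of ArrowConverters.getBatchesFromStream: turn one Arrow stream
    (file or buffer) into a list of per-batch message bytes, dropping the
    schema message, so that each batch can later become its own partition."""
    return [batch.serialize().to_pybytes() for batch in pa.ipc.open_stream(source)]

# readArrowStreamFromFile then amounts to, on the JVM side:
#   batches = get_batches_from_stream(open(filename, "rb"))
#   sc.parallelize(batches, len(batches))   # one Arrow batch per partition
```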
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 261df06100aef..c36872a6a5289 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.arrow -import java.io.File +import java.io.{ByteArrayOutputStream, DataOutputStream, File} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader -import org.apache.arrow.vector.util.Validator +import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SparkException, TaskContext} @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") - val arrowPayloads = indexData.toArrowPayload.collect() - assert(arrowPayloads.nonEmpty) - assert(arrowPayloads.length == indexData.rdd.getNumPartitions) + val arrowBatches = indexData.toArrowBatchRdd.collect() + assert(arrowBatches.nonEmpty) + assert(arrowBatches.length == indexData.rdd.getNumPartitions) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) val rowCount = arrowRecordBatches.map(_.getLength).sum assert(rowCount === indexData.count()) arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0)) @@ -1153,9 +1153,9 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val arrowPayloads = testData2.toArrowPayload.collect() - // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload - assert(arrowPayloads.length === 2) + val arrowBatches = testData2.toArrowBatchRdd.collect() + // NOTE: testData2 should have 2 partitions -> 2 arrow batches + assert(arrowBatches.length === 2) val schema = testData2.schema val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") @@ -1163,25 +1163,25 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { Files.write(json1, tempFile1, StandardCharsets.UTF_8) Files.write(json2, tempFile2, StandardCharsets.UTF_8) - validateConversion(schema, arrowPayloads(0), tempFile1) - validateConversion(schema, arrowPayloads(1), tempFile2) + validateConversion(schema, arrowBatches(0), tempFile1) + validateConversion(schema, arrowBatches(1), tempFile2) } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect() - assert(arrowPayload.isEmpty) + val arrowBatches = spark.emptyDataFrame.toArrowBatchRdd.collect() + assert(arrowBatches.isEmpty) val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i") - val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect() - assert(filteredArrowPayload.isEmpty) + val filteredArrowBatches = filteredDF.filter("i < 0").toArrowBatchRdd.collect() + assert(filteredArrowBatches.isEmpty) } test("empty partition collect") { val emptyPart = spark.sparkContext.parallelize(Seq(1), 
2).toDF("i") - val arrowPayloads = emptyPart.toArrowPayload.collect() - assert(arrowPayloads.length === 1) + val arrowBatches = emptyPart.toArrowBatchRdd.collect() + assert(arrowBatches.length === 1) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) assert(arrowRecordBatches.head.getLength == 1) arrowRecordBatches.foreach(_.close()) allocator.close() @@ -1192,10 +1192,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val maxRecordsPerBatch = 3 spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") - val arrowPayloads = df.toArrowPayload.collect() - assert(arrowPayloads.length >= 4) + val arrowBatches = df.toArrowBatchRdd.collect() + assert(arrowBatches.length >= 4) val allocator = new RootAllocator(Long.MaxValue) - val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) + val arrowRecordBatches = arrowBatches.map(ArrowConverters.loadBatch(_, allocator)) var recordCount = 0 arrowRecordBatches.foreach { batch => assert(batch.getLength > 0) @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { mapData.toDF().toArrowPayload.collect() } - runUnsupported { complexData.toArrowPayload.collect() } + runUnsupported { mapData.toDF().toArrowBatchRdd.collect() } + runUnsupported { complexData.toArrowBatchRdd.collect() } } test("test Arrow Validator") { @@ -1318,7 +1318,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) @@ -1326,10 +1326,41 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) - val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) - assert(schema == outputRowIter.schema) + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { + val inputRows = (0 until 9).map(InternalRow(_)) :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + + // Write batches to Arrow stream format as a byte array + val out = new ByteArrayOutputStream() + Utils.tryWithResource(new DataOutputStream(out)) { dataOut => + val writer = new ArrowBatchStreamWriter(schema, dataOut, null) + writer.writeBatches(batchIter) + writer.end() + } + + // Read Arrow stream into batches, then convert back to rows + val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray) + val readBatches = 
ArrowConverters.getBatchesFromStream(in) + val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => @@ -1348,15 +1379,15 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator - val arrowPayload = df.coalesce(1).toArrowPayload.collect().head + val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) + validateConversion(df.schema, batchBytes, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, - arrowPayload: ArrowPayload, + batchBytes: Array[Byte], jsonFile: File, timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) @@ -1368,7 +1399,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator) val vectorLoader = new VectorLoader(arrowRoot) - val arrowRecordBatch = arrowPayload.loadBatch(allocator) + val arrowRecordBatch = ArrowConverters.loadBatch(batchBytes, allocator) vectorLoader.load(arrowRecordBatch) val jsonRoot = jsonReader.read() Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
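The reworked roundtrip tests have a direct pyarrow analogue. Reusing the `to_batch_iterator`/`from_batch_iterator` sketches above (hypothetical helpers, not part of this PR), the nine-ints-plus-a-null case from "roundtrip arrow batches" looks like:

```python
import pyarrow as pa

schema = pa.schema([("int", pa.int32())])
input_rows = [(i,) for i in range(9)] + [(None,)]

serialized = list(to_batch_iterator(iter(input_rows), schema, 5))  # batches of at most 5 rows
assert len(serialized) == 2

output_rows = list(from_batch_iterator(serialized, schema))
assert output_rows[:9] == input_rows[:9] and output_rows[9] == (None,)
```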