diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 02ba40cb853b6..86febae8aa079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2285,57 +2285,93 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Return an ArrowRecordBatch
-   *
-   * @group action
-   * @since 2.2.0
+   * Convert a Spark DataType to an Arrow ArrowType.
    */
-  @DeveloperApi
-  def collectAsArrow(): ArrowRecordBatch = {
+  private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = {
+    dt match {
+      case IntegerType =>
+        new ArrowType.Int(8 * IntegerType.defaultSize, true)
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported data type $dt")
+    }
+  }
 
-    // TODO - might be more efficient to do conversion on workers before collect
-    /*
-    val vector = MinorType.LIST.getNewVector("TODO", null, null)
-    withNewExecutionId {
-      queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow)
+  /**
+   * Convert a Spark StructType to an Arrow Schema.
+   */
+  private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
+    val arrowFields = schema.fields.map {
+      case StructField(name, dataType, nullable, metadata) =>
+        // TODO: Consider nested types
+        new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
     }
-    vector.getFieldBuffers.asScala.toArray
-    */
+    val arrowSchema = new Schema(arrowFields.toIterable.asJava)
+    arrowSchema
+  }
+
+  /**
+   * Compute the number of bytes needed to build the validity map. According to the
+   * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps),
+   * the length of the validity bitmap should be a multiple of 64 bytes.
+   */
+  private def numBytesOfBitmap(numOfRows: Int): Int = {
+    Math.ceil(numOfRows / 64.0).toInt * 8
+  }
 
-    val rootAllocator = new RootAllocator(1024)  // TODO - size??
+  /**
+   * Infer the validity map from the internal rows.
+   * @param rows An array of InternalRows
+   * @param idx Index of the current column in the array of InternalRows
+   * @param field StructField related to the current column
+   * @param allocator ArrowBuf allocator
+   */
+  private def internalRowToValidityMap(
+      rows: Array[InternalRow], idx: Int, field: StructField, allocator: RootAllocator): ArrowBuf = {
+    val buf = allocator.buffer(numBytesOfBitmap(rows.length))
+    buf
+  }
 
-    def buf(bytes: Array[Byte]): ArrowBuf = {
-      val buffer = rootAllocator.buffer(bytes.length)
-      buffer.writeBytes(bytes)
-      buffer
-    }
+  /**
+   * Convert an array of InternalRows to an ArrowRecordBatch.
+   */
+  private[sql] def internalRowsToArrowRecordBatch(
+      rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = {
+    val numOfRows = rows.length
+
+    val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) =>
+      val validity = internalRowToValidityMap(rows, idx, field, allocator)
+      val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
+      rows.foreach { row => buf.writeInt(row.getInt(idx)) }
+      Array(validity, buf)
+    }.toList.asJava
+
+    val fieldNodes = this.schema.fields.zipWithIndex.map { case (field, idx) =>
+      if (field.nullable) {
+        new ArrowFieldNode(numOfRows, 0)
+      } else {
+        new ArrowFieldNode(numOfRows, 0)
+      }
+    }.toList.asJava
+
+    new ArrowRecordBatch(numOfRows, fieldNodes, buffers)
+  }
 
+  /**
+   * Collect a Dataset as an ArrowRecordBatch.
+   *
+   * @group action
+   * @since 2.2.0
+   */
+  @DeveloperApi
+  def collectAsArrow(): ArrowRecordBatch = {
+    val allocator = new RootAllocator(Long.MaxValue)
     withNewExecutionId {
       try {
-
-        def toArrow(internalRow: InternalRow): ArrowBuf = {
-          val buf = rootAllocator.buffer(128)  // TODO - size??
-          // TODO internalRow -> buf
-          buf.setInt(0, 1)
-          buf
-        }
-        val iter = queryExecution.executedPlan.executeCollect().map(toArrow)
-        val arrowBufList = iter.toList
-        val nodes: List[ArrowFieldNode] = null  // TODO
-        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)
-
-        /*
-        val validity = Array[Byte](255.asInstanceOf[Byte], 0)
-        val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
-        val validityb = buf(validity)
-        val valuesb = buf(values)
-        new ArrowRecordBatch(
-          16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
-        */
+        val collectedRows = queryExecution.executedPlan.executeCollect()
+        val recordBatch = internalRowsToArrowRecordBatch(collectedRows, allocator)
+        recordBatch
       } catch {
         case e: Exception =>
-          // logError
-          //   (s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
           throw e
       }
     }
@@ -2710,23 +2746,22 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark.
+   */
   private[sql] def collectAsArrowToPython(): Int = {
-    val batch = collectAsArrow()
-    // TODO - temporary schema to test
-    val schema = new Schema(Seq(
-      new Field("testField", true, new ArrowType.Int(8, true), List.empty[Field].asJava)
-    ).asJava)
+    val recordBatch = collectAsArrow()
+    val arrowSchema = schemaToArrowSchema(this.schema)
     val out = new ByteArrayOutputStream()
     try {
-      val writer = new ArrowWriter(Channels.newChannel(out), schema)
-      writer.writeRecordBatch(batch)
+      val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
+      writer.writeRecordBatch(recordBatch)
       writer.close()
     } catch {
       case e: Exception =>
-        // logError
-        //   (s"Error writing ArrowRecordBatch to Python; ${e.getMessage}:\n$queryExecution")
        throw e
     }
+
     withNewExecutionId {
       PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 192c8b9958e07..cc367acae2ba4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,16 +17,9 @@
 
 package org.apache.spark.sql
 
-import java.io._
-import java.net.{InetAddress, InetSocketAddress, Socket}
-import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel, SocketChannel}
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
-import io.netty.buffer.ArrowBuf
-import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.file.ArrowReader
-
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -926,49 +919,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
-
-  def array(buf: ArrowBuf): Array[Byte] = {
-    val bytes = Array.ofDim[Byte](buf.readableBytes())
-    buf.readBytes(bytes)
-    bytes
-  }
-
-  test("Collect as arrow to python") {
-    val ds = Seq(1).toDS()
-    val port = ds.collectAsArrowToPython()
-
-    val s = new Socket(InetAddress.getByName("localhost"), port)
-    val is = s.getInputStream
-
-    val dis = new DataInputStream(is)
-    val len = dis.readInt()
-    val allocator = new RootAllocator(len)
-
-    val buffer = Array.ofDim[Byte](len)
-    val bytes = dis.read(buffer)
-
-
-    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
-    aFile.write(bytes)
-    aFile.close()
-
-    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
-    val fChannel = aFile.getChannel
-
-    val reader = new ArrowReader(fChannel, allocator)
-    val footer = reader.readFooter()
-    val schema = footer.getSchema
-    val blocks = footer.getRecordBatches
-    val recordBatch = reader.readRecordBatch(blocks.get(0))
-
-    val nodes = recordBatch.getNodes
-    val buffers = recordBatch.getBuffers
-
-    // scalastyle:off println
-    println(array(buffers.get(0)).mkString(", "))
-    println(array(buffers.get(1)).mkString(", "))
-    // scalastyle:on println
-  }
 }
 
 case class Generic[T](id: T, value: Double)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
new file mode 100644
index 0000000000000..e954cdc751a6c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.{DataInputStream, EOFException, RandomAccessFile}
+import java.net.{InetAddress, Socket}
+import java.nio.channels.FileChannel
+
+import io.netty.buffer.ArrowBuf
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.file.ArrowReader
+import org.apache.arrow.vector.schema.ArrowRecordBatch
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class ArrowIntTest(a: Int, b: Int)
+
+class DatasetToArrowSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("Collect as arrow to python") {
+
+    val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS()
+
+    val port = ds.collectAsArrowToPython()
+
+    val clientThread: Thread = new Thread(new Runnable() {
+      def run() {
+        try {
+          val receiver: RecordBatchReceiver = new RecordBatchReceiver
+          val record: ArrowRecordBatch = receiver.read(port)
+        }
+        catch {
+          case e: Exception =>
+            throw e
+        }
+      }
+    })
+
+    clientThread.start()
+
+    try {
+      clientThread.join()
+    } catch {
+      case e: InterruptedException =>
+        throw e
+      case _ =>
+    }
+  }
+}
+
+class RecordBatchReceiver {
+
+  def array(buf: ArrowBuf): Array[Byte] = {
+    val bytes = Array.ofDim[Byte](buf.readableBytes())
+    buf.readBytes(bytes)
+    bytes
+  }
+
+  def connectAndRead(port: Int): (Array[Byte], Int) = {
+    val s = new Socket(InetAddress.getByName("localhost"), port)
+    val is = s.getInputStream
+
+    val dis = new DataInputStream(is)
+    val len = dis.readInt()
+
+    val buffer = Array.ofDim[Byte](len)
+    val bytesRead = dis.read(buffer)
+    if (bytesRead != len) {
+      throw new EOFException(s"Expected to read $len bytes but got $bytesRead")
+    }
+    (buffer, len)
+  }
+
+  def makeFile(buffer: Array[Byte]): FileChannel = {
+    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
+    aFile.write(buffer)
+    aFile.close()
+
+    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
+    val fChannel = aFile.getChannel
+    fChannel
+  }
+
+  def readRecordBatch(fc: FileChannel, len: Int): ArrowRecordBatch = {
+    val allocator = new RootAllocator(len)
+    val reader = new ArrowReader(fc, allocator)
+    val footer = reader.readFooter()
+    val schema = footer.getSchema
+    val blocks = footer.getRecordBatches
+    val recordBatch = reader.readRecordBatch(blocks.get(0))
+    recordBatch
+  }
+
+  def read(port: Int): ArrowRecordBatch = {
+    val (buffer, len) = connectAndRead(port)
+    val fc = makeFile(buffer)
+    readRecordBatch(fc, len)
+  }
+}
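
For reference, here is a rough usage sketch, not part of the patch itself, showing how the pieces added to Dataset.scala fit together without the Python-serving socket: it collects a Dataset as an ArrowRecordBatch and writes it to an in-memory Arrow stream, mirroring what collectAsArrowToPython() does before handing the bytes to PythonRDD.serveIterator. It only uses names that appear in the diff (collectAsArrow, schemaToArrowSchema, ArrowWriter, ArrowIntTest, QueryTest, SharedSQLContext); the suite name ArrowInMemorySketch is invented for illustration, and because schemaToArrowSchema is private[sql], the sketch assumes it is compiled as a test under the org.apache.spark.sql package against the same Arrow 0.x APIs this patch builds against.

package org.apache.spark.sql

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.file.ArrowWriter

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: exercises collectAsArrow() and schemaToArrowSchema() from this patch
// and serializes the batch the same way collectAsArrowToPython() does, minus the socket.
class ArrowInMemorySketch extends QueryTest with SharedSQLContext {

  import testImplicits._

  test("collect a Dataset into an in-memory Arrow stream") {
    val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS()

    // Collect rows on the driver as a single ArrowRecordBatch (API added by this patch).
    val recordBatch = ds.collectAsArrow()

    // Convert the Spark schema to an Arrow Schema (private[sql], hence this package).
    val arrowSchema = ds.schemaToArrowSchema(ds.schema)

    // Write the schema and record batch to a byte array instead of serving them over a socket.
    val out = new ByteArrayOutputStream()
    val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
    writer.writeRecordBatch(recordBatch)
    writer.close()

    assert(out.toByteArray.nonEmpty)
  }
}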