diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 86febae8aa079..928eefacfffb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import io.netty.buffer.ArrowBuf
+import org.apache.arrow.flatbuf.Precision
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
@@ -2291,6 +2292,16 @@
     dt match {
       case IntegerType =>
         new ArrowType.Int(8 * IntegerType.defaultSize, true)
+      case StringType =>
+        ArrowType.List.INSTANCE
+      case DoubleType =>
+        new ArrowType.FloatingPoint(Precision.DOUBLE)
+      case FloatType =>
+        new ArrowType.FloatingPoint(Precision.SINGLE)
+      case BooleanType =>
+        ArrowType.Bool.INSTANCE
+      case ByteType =>
+        new ArrowType.Int(8, false)
       case _ =>
         throw new IllegalArgumentException(s"Unsupported data type")
     }
@@ -2302,8 +2313,16 @@ class Dataset[T] private[sql](
   private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
     val arrowFields = schema.fields.map {
       case StructField(name, dataType, nullable, metadata) =>
-        // TODO: Consider nested types
-        new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
+        dataType match {
+          // TODO: Consider other nested types
+          case StringType =>
+            // TODO: Make sure String => List
+            val itemField =
+              new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava)
+            new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava)
+          case _ =>
+            new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
+        }
     }
     val arrowSchema = new Schema(arrowFields.toIterable.asJava)
     arrowSchema
@@ -2319,16 +2338,89 @@
   }
 
   /**
-   * Infer the validity map from the internal rows.
-   * @param rows An array of InternalRows
-   * @param idx Index of current column in the array of InternalRows
-   * @param field StructField related to the current column
-   * @param allocator ArrowBuf allocator
+   * Get an entry from the InternalRow, and then set to ArrowBuf.
+   * Note: No Null check for the entry.
+   */
+  private def getAndSetToArrow(
+      row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = {
+    dataType match {
+      case NullType =>
+      case BooleanType =>
+        buf.writeBoolean(row.getBoolean(ordinal))
+      case ShortType =>
+        buf.writeShort(row.getShort(ordinal))
+      case IntegerType =>
+        buf.writeInt(row.getInt(ordinal))
+      case FloatType =>
+        buf.writeFloat(row.getFloat(ordinal))
+      case DoubleType =>
+        buf.writeDouble(row.getDouble(ordinal))
+      case ByteType =>
+        buf.writeByte(row.getByte(ordinal))
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unsupported data type ${dataType.simpleString}")
+    }
+  }
+
+  /**
+   * Convert an array of InternalRow to an ArrowBuf.
    */
-  private def internalRowToValidityMap(
-      rows: Array[InternalRow], idx: Int, field: StructField, allocator: RootAllocator): ArrowBuf = {
-    val buf = allocator.buffer(numBytesOfBitmap(rows.length))
-    buf
+  private def internalRowToArrowBuf(
+      rows: Array[InternalRow],
+      ordinal: Int,
+      field: StructField,
+      allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
+    val numOfRows = rows.length
+
+    field.dataType match {
+      case IntegerType | DoubleType | FloatType | BooleanType | ByteType =>
+        val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
+        val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
+        var nullCount = 0
+        rows.foreach { row =>
+          if (row.isNullAt(ordinal)) {
+            nullCount += 1
+          } else {
+            getAndSetToArrow(row, buf, field.dataType, ordinal)
+          }
+        }
+
+        val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
+
+        (Array(validity, buf), Array(fieldNode))
+
+      case StringType =>
+        val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
+        val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
+        var bytesCount = 0
+        bufOffset.writeInt(bytesCount) // Start position
+        val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows))
+        val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size?
+        var nullCount = 0
+        rows.foreach { row =>
+          if (row.isNullAt(ordinal)) {
+            nullCount += 1
+            bufOffset.writeInt(bytesCount)
+          } else {
+            val bytes = row.getUTF8String(ordinal).getBytes
+            bytesCount += bytes.length
+            bufOffset.writeInt(bytesCount)
+            bufValues.writeBytes(bytes)
+          }
+        }
+
+        val fieldNodeOffset = if (field.nullable) {
+          new ArrowFieldNode(numOfRows, nullCount)
+        } else {
+          new ArrowFieldNode(numOfRows, 0)
+        }
+
+        val fieldNodeValues = new ArrowFieldNode(bytesCount, 0)
+
+        (Array(validityOffset, bufOffset, validityValues, bufValues),
+          Array(fieldNodeOffset, fieldNodeValues))
+    }
   }
 
   /**
@@ -2336,24 +2428,14 @@ class Dataset[T] private[sql](
    */
   private[sql] def internalRowsToArrowRecordBatch(
       rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = {
-    val numOfRows = rows.length
-
-    val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) =>
-      val validity = internalRowToValidityMap(rows, idx, field, allocator)
-      val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
-      rows.foreach { row => buf.writeInt(row.getInt(idx)) }
-      Array(validity, buf)
-    }.toList.asJava
+    val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) =>
+      internalRowToArrowBuf(rows, ordinal, field, allocator)
+    }
 
-    val fieldNodes = this.schema.fields.zipWithIndex.map { case (field, idx) =>
-      if (field.nullable) {
-        new ArrowFieldNode(numOfRows, 0)
-      } else {
-        new ArrowFieldNode(numOfRows, 0)
-      }
-    }.toList.asJava
+    val buffers = bufAndField.flatMap(_._1).toList.asJava
+    val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
 
-    new ArrowRecordBatch(numOfRows, fieldNodes, buffers)
+    new ArrowRecordBatch(rows.length, fieldNodes, buffers)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
index e954cdc751a6c..8aec3699c9dd1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
@@ -17,100 +17,156 @@
 
 package org.apache.spark.sql
 
-import java.io.{DataInputStream, EOFException, RandomAccessFile}
+import java.io._
 import java.net.{InetAddress, Socket}
+import java.nio.{ByteBuffer, ByteOrder}
 import java.nio.channels.FileChannel
 
+import scala.util.Random
+
 import io.netty.buffer.ArrowBuf
+import org.apache.arrow.flatbuf.Precision
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.file.ArrowReader
-import org.apache.arrow.vector.schema.ArrowRecordBatch
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
 
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
 
-case class ArrowIntTest(a: Int, b: Int)
+case class ArrowTestClass(col1: Int, col2: Double, col3: String)
 
 class DatasetToArrowSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
 
+  final val numElements = 4
+  @transient var data: Seq[ArrowTestClass] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    data = Seq.fill(numElements)(ArrowTestClass(
+      Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100))))
+  }
+
   test("Collect as arrow to python") {
+    val dataset = data.toDS()
+
+    val port = dataset.collectAsArrowToPython()
+
+    val receiver: RecordBatchReceiver = new RecordBatchReceiver
+    val (buffer, numBytesRead) = receiver.connectAndRead(port)
+    val channel = receiver.makeFile(buffer)
+    val reader = new ArrowReader(channel, receiver.allocator)
+
+    val footer = reader.readFooter()
+    val schema = footer.getSchema
+
+    val numCols = schema.getFields.size()
+    assert(numCols === dataset.schema.fields.length)
+    for (i <- 0 until schema.getFields.size()) {
+      val arrowField = schema.getFields.get(i)
+      val sparkField = dataset.schema.fields(i)
+      assert(arrowField.getName === sparkField.name)
+      assert(arrowField.isNullable === sparkField.nullable)
+      assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField))
+    }
+
+    val blockMetadata = footer.getRecordBatches
+    assert(blockMetadata.size() === 1)
+
+    val recordBatch = reader.readRecordBatch(blockMetadata.get(0))
+    val nodes = recordBatch.getNodes
+    assert(nodes.size() === numCols + 1) // +1 for Type String, which has two nodes.
+
+    val firstNode = nodes.get(0)
+    assert(firstNode.getLength === numElements)
+    assert(firstNode.getNullCount === 0)
+
+    val buffers = recordBatch.getBuffers
+    assert(buffers.size() === (numCols + 1) * 2) // +1 for Type String
+
+    assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1))
+    assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2))
+    assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) ===
+      data.map(d => UTF8String.fromString(d.col3)).toArray)
+  }
+}
 
-    val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS()
-
-    val port = ds.collectAsArrowToPython()
-
-    val clientThread: Thread = new Thread(new Runnable() {
-      def run() {
-        try {
-          val receiver: RecordBatchReceiver = new RecordBatchReceiver
-          val record: ArrowRecordBatch = receiver.read(port)
-        }
-        catch {
-          case e: Exception =>
-            throw e
-        }
-      }
-    })
-
-    clientThread.start()
-
-    try {
-      clientThread.join()
-    } catch {
-      case e: InterruptedException =>
-        throw e
-      case _ =>
+object DatasetToArrowSuite {
+  def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = {
+    val arrowType = arrowField.getType
+    val sparkType = sparkField.dataType
+    (arrowType, sparkType) match {
+      case (_: ArrowType.Int, _: IntegerType) => true
+      case (_: ArrowType.FloatingPoint, _: DoubleType) =>
+        arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE
+      case (_: ArrowType.FloatingPoint, _: FloatType) =>
+        arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE
+      case (_: ArrowType.List, _: StringType) =>
+        val subField = arrowField.getChildren
+        (subField.size() == 1) && subField.get(0).getType.isInstanceOf[ArrowType.Utf8]
+      case (_: ArrowType.Bool, _: BooleanType) => true
+      case _ => false
     }
   }
 }
 
 class RecordBatchReceiver {
 
-  def array(buf: ArrowBuf): Array[Byte] = {
-    val bytes = Array.ofDim[Byte](buf.readableBytes())
-    buf.readBytes(bytes)
-    bytes
+  val allocator = new RootAllocator(Long.MaxValue)
+
+  def getIntArray(buf: ArrowBuf): Array[Int] = {
+    val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer()
+    val resultArray = Array.ofDim[Int](buffer.remaining())
+    buffer.get(resultArray)
+    resultArray
   }
 
-  def connectAndRead(port: Int): (Array[Byte], Int) = {
-    val s = new Socket(InetAddress.getByName("localhost"), port)
-    val is = s.getInputStream
+  def getDoubleArray(buf: ArrowBuf): Array[Double] = {
+    val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer()
+    val resultArray = Array.ofDim[Double](buffer.remaining())
+    buffer.get(resultArray)
+    resultArray
+  }
 
-    val dis = new DataInputStream(is)
-    val len = dis.readInt()
+  def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = {
+    val offsets = getIntArray(bufOffsets)
+    val lens = offsets.zip(offsets.drop(1))
+      .map { case (prevOffset, offset) => offset - prevOffset }
 
-    val buffer = Array.ofDim[Byte](len)
-    val bytesRead = dis.read(buffer)
-    if (bytesRead != len) {
-      throw new EOFException("Wrong EOF")
+    val values = array(bufValues)
+    val strings = offsets.zip(lens).map { case (offset, len) =>
+      UTF8String.fromBytes(values, offset, len)
     }
-    (buffer, len)
+    strings
   }
 
-  def makeFile(buffer: Array[Byte]): FileChannel = {
-    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
-    aFile.write(buffer)
-    aFile.close()
-
-    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
-    val fChannel = aFile.getChannel
-    fChannel
+  private def array(buf: ArrowBuf): Array[Byte] = {
+    val bytes = Array.ofDim[Byte](buf.readableBytes())
+    buf.readBytes(bytes)
+    bytes
   }
 
-  def readRecordBatch(fc: FileChannel, len: Int): ArrowRecordBatch = {
-    val allocator = new RootAllocator(len)
-    val reader = new ArrowReader(fc, allocator)
-    val footer = reader.readFooter()
-    val schema = footer.getSchema
-    val blocks = footer.getRecordBatches
-    val recordBatch = reader.readRecordBatch(blocks.get(0))
-    recordBatch
+  def connectAndRead(port: Int): (Array[Byte], Int) = {
+    val clientSocket = new Socket(InetAddress.getByName("localhost"), port)
+    val clientDataIns = new DataInputStream(clientSocket.getInputStream)
+    val messageLength = clientDataIns.readInt()
+    val buffer = Array.ofDim[Byte](messageLength)
+    clientDataIns.readFully(buffer, 0, messageLength)
+    (buffer, messageLength)
  }
 
-  def read(port: Int): ArrowRecordBatch = {
-    val (buffer, len) = connectAndRead(port)
-    val fc = makeFile(buffer)
-    readRecordBatch(fc, len)
+  def makeFile(buffer: Array[Byte]): FileChannel = {
+    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath
+    val arrowFile = new File(tempDir, "arrow-bytes")
+    val arrowOus = new FileOutputStream(arrowFile.getPath)
+    arrowOus.write(buffer)
+    arrowOus.close()
+
+    val arrowIns = new FileInputStream(arrowFile.getPath)
+    arrowIns.getChannel
   }
 }
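
Not part of the patch: a minimal standalone sketch of the offsets/values layout that the StringType branch of internalRowToArrowBuf writes, using plain Scala arrays instead of ArrowBuf so it runs on its own; the object and method names (StringColumnLayoutSketch, stringColumnLayout) are illustrative only.

import java.nio.charset.StandardCharsets

// Mirrors the StringType branch above: the offsets buffer holds numRows + 1
// cumulative byte positions starting at 0, and the values buffer is the
// concatenation of the UTF-8 bytes of the non-null entries. A null row
// repeats the previous offset, i.e. it occupies a zero-length slot.
object StringColumnLayoutSketch {
  def stringColumnLayout(rows: Seq[Option[String]]): (Array[Int], Array[Byte]) = {
    val offsets = new Array[Int](rows.length + 1)
    val values = scala.collection.mutable.ArrayBuffer.empty[Byte]
    var bytesCount = 0
    offsets(0) = 0 // start position, as in bufOffset.writeInt(bytesCount)
    rows.zipWithIndex.foreach { case (row, i) =>
      row.foreach { s =>
        val bytes = s.getBytes(StandardCharsets.UTF_8)
        bytesCount += bytes.length
        values ++= bytes
      }
      offsets(i + 1) = bytesCount
    }
    (offsets, values.toArray)
  }

  def main(args: Array[String]): Unit = {
    val (offsets, values) = stringColumnLayout(Seq(Some("ab"), None, Some("cde")))
    println(offsets.mkString(", "))                     // 0, 2, 2, 5
    println(new String(values, StandardCharsets.UTF_8)) // abcde
  }
}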