136 changes: 109 additions & 27 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import io.netty.buffer.ArrowBuf
import org.apache.arrow.flatbuf.Precision
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
@@ -2291,6 +2292,16 @@ class Dataset[T] private[sql](
dt match {
case IntegerType =>
new ArrowType.Int(8 * IntegerType.defaultSize, true)
case StringType =>
ArrowType.List.INSTANCE
case DoubleType =>
new ArrowType.FloatingPoint(Precision.DOUBLE)
case FloatType =>
new ArrowType.FloatingPoint(Precision.SINGLE)
case BooleanType =>
ArrowType.Bool.INSTANCE
case ByteType =>
new ArrowType.Int(8, false)
case _ =>
throw new IllegalArgumentException(s"Unsupported data type ${dt.simpleString}")
}
@@ -2302,8 +2313,16 @@ class Dataset[T] private[sql](
private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
val arrowFields = schema.fields.map {
case StructField(name, dataType, nullable, metadata) =>
// TODO: Consider nested types
new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
dataType match {
// TODO: Consider other nested types
case StringType =>
// TODO: Make sure String => List<Utf8>
val itemField =
new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava)
new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava)
case _ =>
new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
}
}
val arrowSchema = new Schema(arrowFields.toIterable.asJava)
arrowSchema
@@ -2319,41 +2338,104 @@ class Dataset[T] private[sql](
}

/**
* Infer the validity map from the internal rows.
* @param rows An array of InternalRows
* @param idx Index of current column in the array of InternalRows
* @param field StructField related to the current column
* @param allocator ArrowBuf allocator
* Get an entry from the InternalRow and write it to the ArrowBuf.
* Note: no null check is performed on the entry.
*/
private def getAndSetToArrow(
row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = {
dataType match {
case NullType =>
case BooleanType =>
buf.writeBoolean(row.getBoolean(ordinal))
case ShortType =>
buf.writeShort(row.getShort(ordinal))
case IntegerType =>
buf.writeInt(row.getInt(ordinal))
case FloatType =>
buf.writeFloat(row.getFloat(ordinal))
case DoubleType =>
buf.writeDouble(row.getDouble(ordinal))
case ByteType =>
buf.writeByte(row.getByte(ordinal))
case _ =>
throw new UnsupportedOperationException(
s"Unsupported data type ${dataType.simpleString}")
}
}

/**
* Convert one column of an array of InternalRows to Arrow buffers and field nodes.
*/
private def internalRowToValidityMap(
rows: Array[InternalRow], idx: Int, field: StructField, allocator: RootAllocator): ArrowBuf = {
val buf = allocator.buffer(numBytesOfBitmap(rows.length))
buf
private def internalRowToArrowBuf(
rows: Array[InternalRow],
ordinal: Int,
field: StructField,
allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
val numOfRows = rows.length

field.dataType match {
case IntegerType | DoubleType | FloatType | BooleanType | ByteType =>
val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
var nullCount = 0
rows.foreach { row =>
if (row.isNullAt(ordinal)) {
nullCount += 1
} else {
getAndSetToArrow(row, buf, field.dataType, ordinal)
}
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validity, buf), Array(fieldNode))

case StringType =>
val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
var bytesCount = 0
bufOffset.writeInt(bytesCount) // Start position
val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows))
val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size?
var nullCount = 0
rows.foreach { row =>
if (row.isNullAt(ordinal)) {
nullCount += 1
bufOffset.writeInt(bytesCount)
} else {
val bytes = row.getUTF8String(ordinal).getBytes
bytesCount += bytes.length
bufOffset.writeInt(bytesCount)
bufValues.writeBytes(bytes)
}
}

val fieldNodeOffset = if (field.nullable) {
new ArrowFieldNode(numOfRows, nullCount)
} else {
new ArrowFieldNode(numOfRows, 0)
}

val fieldNodeValues = new ArrowFieldNode(bytesCount, 0)

(Array(validityOffset, bufOffset, validityValues, bufValues),
Array(fieldNodeOffset, fieldNodeValues))
}
}

/**
* Convert an array of InternalRows to an ArrowRecordBatch.
*/
private[sql] def internalRowsToArrowRecordBatch(
rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = {
val numOfRows = rows.length

val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) =>
val validity = internalRowToValidityMap(rows, idx, field, allocator)
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
rows.foreach { row => buf.writeInt(row.getInt(idx)) }
Array(validity, buf)
}.toList.asJava
val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) =>
internalRowToArrowBuf(rows, ordinal, field, allocator)
}

val fieldNodes = this.schema.fields.zipWithIndex.map { case (field, idx) =>
if (field.nullable) {
new ArrowFieldNode(numOfRows, 0)
} else {
new ArrowFieldNode(numOfRows, 0)
}
}.toList.asJava
val buffers = bufAndField.flatMap(_._1).toList.asJava
val fieldNodes = bufAndField.flatMap(_._2).toList.asJava

new ArrowRecordBatch(numOfRows, fieldNodes, buffers)
new ArrowRecordBatch(rows.length, fieldNodes, buffers)
}

/**
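For reference, a minimal sketch of how the conversion path added above could be driven end to end. This is not part of the patch; it assumes an active SparkSession named `spark` with its implicits imported, and code living in the org.apache.spark.sql package, since the helpers are private[sql]. The method names and signatures come from the diff above.

// Sketch only: builds one ArrowRecordBatch from a small Dataset using the helpers above.
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.schema.ArrowRecordBatch

import org.apache.spark.sql.catalyst.InternalRow

import spark.implicits._

val ds = Seq((1, 1.5, "a"), (2, 2.5, "b")).toDF("col1", "col2", "col3")

// Spark schema -> Arrow schema (StringType maps to List<Utf8> in this patch).
val arrowSchema = ds.schemaToArrowSchema(ds.schema)

// Collect the rows as InternalRows and convert them into a single record batch.
val rows: Array[InternalRow] = ds.queryExecution.toRdd.collect()
val allocator = new RootAllocator(Long.MaxValue)
val batch: ArrowRecordBatch = ds.internalRowsToArrowRecordBatch(rows, allocator)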
180 changes: 118 additions & 62 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
@@ -17,100 +17,156 @@

package org.apache.spark.sql

import java.io.{DataInputStream, EOFException, RandomAccessFile}
import java.io._
import java.net.{InetAddress, Socket}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.channels.FileChannel

import scala.util.Random

import io.netty.buffer.ArrowBuf
import org.apache.arrow.flatbuf.Precision
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowReader
import org.apache.arrow.vector.schema.ArrowRecordBatch
import org.apache.arrow.vector.types.pojo.{ArrowType, Field}

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils


case class ArrowIntTest(a: Int, b: Int)
case class ArrowTestClass(col1: Int, col2: Double, col3: String)

class DatasetToArrowSuite extends QueryTest with SharedSQLContext {

import testImplicits._

final val numElements = 4
@transient var data: Seq[ArrowTestClass] = _

override def beforeAll(): Unit = {
super.beforeAll()
data = Seq.fill(numElements)(ArrowTestClass(
Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100))))
}

test("Collect as arrow to python") {
val dataset = data.toDS()

val port = dataset.collectAsArrowToPython()

val receiver: RecordBatchReceiver = new RecordBatchReceiver
val (buffer, numBytesRead) = receiver.connectAndRead(port)
val channel = receiver.makeFile(buffer)
val reader = new ArrowReader(channel, receiver.allocator)

val footer = reader.readFooter()
val schema = footer.getSchema

val numCols = schema.getFields.size()
assert(numCols === dataset.schema.fields.length)
for (i <- 0 until schema.getFields.size()) {
val arrowField = schema.getFields.get(i)
val sparkField = dataset.schema.fields(i)
assert(arrowField.getName === sparkField.name)
assert(arrowField.isNullable === sparkField.nullable)
assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField))
}

val blockMetadata = footer.getRecordBatches
assert(blockMetadata.size() === 1)

val recordBatch = reader.readRecordBatch(blockMetadata.get(0))
val nodes = recordBatch.getNodes
assert(nodes.size() === numCols + 1) // +1 for Type String, which has two nodes.

val firstNode = nodes.get(0)
assert(firstNode.getLength === numElements)
assert(firstNode.getNullCount === 0)

val buffers = recordBatch.getBuffers
assert(buffers.size() === (numCols + 1) * 2) // +1 for Type String

assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1))
assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2))
assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) ===
data.map(d => UTF8String.fromString(d.col3)).toArray)
}
}

val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS()

val port = ds.collectAsArrowToPython()

val clientThread: Thread = new Thread(new Runnable() {
def run() {
try {
val receiver: RecordBatchReceiver = new RecordBatchReceiver
val record: ArrowRecordBatch = receiver.read(port)
}
catch {
case e: Exception =>
throw e
}
}
})

clientThread.start()

try {
clientThread.join()
} catch {
case e: InterruptedException =>
throw e
case _ =>
object DatasetToArrowSuite {
def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = {
val arrowType = arrowField.getType
val sparkType = sparkField.dataType
(arrowType, sparkType) match {
case (_: ArrowType.Int, _: IntegerType) => true
case (_: ArrowType.FloatingPoint, _: DoubleType) =>
arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE
case (_: ArrowType.FloatingPoint, _: FloatType) =>
arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE
case (_: ArrowType.List, _: StringType) =>
val subField = arrowField.getChildren
(subField.size() == 1) && subField.get(0).getType.isInstanceOf[ArrowType.Utf8]
case (_: ArrowType.Bool, _: BooleanType) => true
case _ => false
}
}
}

class RecordBatchReceiver {

def array(buf: ArrowBuf): Array[Byte] = {
val bytes = Array.ofDim[Byte](buf.readableBytes())
buf.readBytes(bytes)
bytes
val allocator = new RootAllocator(Long.MaxValue)

def getIntArray(buf: ArrowBuf): Array[Int] = {
val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer()
val resultArray = Array.ofDim[Int](buffer.remaining())
buffer.get(resultArray)
resultArray
}

def connectAndRead(port: Int): (Array[Byte], Int) = {
val s = new Socket(InetAddress.getByName("localhost"), port)
val is = s.getInputStream
def getDoubleArray(buf: ArrowBuf): Array[Double] = {
val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer()
val resultArray = Array.ofDim[Double](buffer.remaining())
buffer.get(resultArray)
resultArray
}

val dis = new DataInputStream(is)
val len = dis.readInt()
def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = {
val offsets = getIntArray(bufOffsets)
val lens = offsets.zip(offsets.drop(1))
.map { case (prevOffset, offset) => offset - prevOffset }

val buffer = Array.ofDim[Byte](len)
val bytesRead = dis.read(buffer)
if (bytesRead != len) {
throw new EOFException("Wrong EOF")
val values = array(bufValues)
val strings = offsets.zip(lens).map { case (offset, len) =>
UTF8String.fromBytes(values, offset, len)
}
(buffer, len)
strings
}

def makeFile(buffer: Array[Byte]): FileChannel = {
var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
aFile.write(buffer)
aFile.close()

aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
val fChannel = aFile.getChannel
fChannel
private def array(buf: ArrowBuf): Array[Byte] = {
val bytes = Array.ofDim[Byte](buf.readableBytes())
buf.readBytes(bytes)
bytes
}

def readRecordBatch(fc: FileChannel, len: Int): ArrowRecordBatch = {
val allocator = new RootAllocator(len)
val reader = new ArrowReader(fc, allocator)
val footer = reader.readFooter()
val schema = footer.getSchema
val blocks = footer.getRecordBatches
val recordBatch = reader.readRecordBatch(blocks.get(0))
recordBatch
def connectAndRead(port: Int): (Array[Byte], Int) = {
val clientSocket = new Socket(InetAddress.getByName("localhost"), port)
val clientDataIns = new DataInputStream(clientSocket.getInputStream)
val messageLength = clientDataIns.readInt()
val buffer = Array.ofDim[Byte](messageLength)
clientDataIns.readFully(buffer, 0, messageLength)
(buffer, messageLength)
}

def read(port: Int): ArrowRecordBatch = {
val (buffer, len) = connectAndRead(port)
val fc = makeFile(buffer)
readRecordBatch(fc, len)
def makeFile(buffer: Array[Byte]): FileChannel = {
val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath
val arrowFile = new File(tempDir, "arrow-bytes")
val arrowOus = new FileOutputStream(arrowFile.getPath)
arrowOus.write(buffer)
arrowOus.close()

val arrowIns = new FileInputStream(arrowFile.getPath)
arrowIns.getChannel
}
}
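For reference, the String-column layout written by internalRowToArrowBuf and read back by getStringArray above amounts to an offset buffer holding a running byte count with a leading zero, where a null row repeats the previous offset. A minimal standalone sketch of that bookkeeping (plain Scala, no Arrow dependency; the values are made up for illustration):

// Hypothetical example data: "ab", a null, then "cde".
val values: Seq[Option[String]] = Seq(Some("ab"), None, Some("cde"))

// Offsets start at 0 and accumulate UTF-8 byte lengths; a null adds nothing,
// so its offset equals the previous one.
val offsets: Seq[Int] = values.scanLeft(0) { (acc, v) =>
  acc + v.map(_.getBytes("UTF-8").length).getOrElse(0)
}
// offsets == Seq(0, 2, 2, 5)

// Per-row lengths are the pairwise differences, exactly as getStringArray computes them.
val lens = offsets.zip(offsets.drop(1)).map { case (prev, next) => next - prev }
// lens == Seq(2, 0, 3)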