
Commit b06e11f

Merge pull request #4 from yinxusen/wip-toPandas_with_arrow-SPARK-13534
add a local test for spark-side arrow
2 parents: 3f855ec + 1d6e5b9

2 files changed, +65 -10 lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 14 additions & 9 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql
 
-import io.netty.buffer.ArrowBuf
-
 import java.io.{ByteArrayOutputStream, CharArrayWriter}
 import java.nio.channels.Channels
 
@@ -27,11 +25,11 @@ import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
+import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.pojo.{Field, ArrowType, Schema}
-
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -59,6 +57,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
+
 private[sql] object Dataset {
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2314,23 +2313,29 @@ class Dataset[T] private[sql](
     withNewExecutionId {
       try {
 
-        /*def toArrow(internalRow: InternalRow): ArrowBuf = {
+        def toArrow(internalRow: InternalRow): ArrowBuf = {
           val buf = rootAllocator.buffer(128) // TODO - size??
           // TODO internalRow -> buf
+          buf.setInt(0, 1)
           buf
         }
-        val iter = queryExecution.executedPlan.executeCollect().iterator.map(toArrow)
+        val iter = queryExecution.executedPlan.executeCollect().map(toArrow)
         val arrowBufList = iter.toList
         val nodes: List[ArrowFieldNode] = null // TODO
-        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)*/
+        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)
+
+        /*
         val validity = Array[Byte](255.asInstanceOf[Byte], 0)
         val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
         val validityb = buf(validity)
        val valuesb = buf(values)
-        new ArrowRecordBatch(16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
+        new ArrowRecordBatch(
+          16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
+        */
      } catch {
        case e: Exception =>
-          //logError(s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
+          // logError
+          // (s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
          throw e
      }
    }
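
Note on the hunk above: toArrow is still a placeholder. The row-to-buffer conversion is a TODO, nodes is left null, and the buffer size is hard-coded at 128 bytes, so the ArrowRecordBatch it builds is not yet a valid batch. Purely as an illustrative sketch (not what this commit implements), a single non-nullable Int column could be laid out along these lines, using only the calls already visible in the diff plus Netty ByteBuf write methods on ArrowBuf; the helper name toArrowIntColumn and the all-valid bitmap are assumptions here:

// Hypothetical sketch, not part of this commit: encode one non-nullable Int column
// from the collected rows into a single Arrow record batch.
// Assumes `rootAllocator: RootAllocator` is in scope, as in the surrounding method.
import scala.collection.JavaConverters._

import io.netty.buffer.ArrowBuf
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.spark.sql.catalyst.InternalRow

def toArrowIntColumn(rows: Array[InternalRow]): ArrowRecordBatch = {
  val numRows = rows.length
  val bitmapBytes = (numRows + 7) / 8

  // Validity bitmap: one bit per row, all bits set because the column is non-nullable.
  val validity: ArrowBuf = rootAllocator.buffer(bitmapBytes)
  (0 until bitmapBytes).foreach(_ => validity.writeByte(0xFF))

  // Value buffer: 4 bytes per row for a 32-bit integer column.
  val values: ArrowBuf = rootAllocator.buffer(4 * numRows)
  rows.foreach(row => values.writeInt(row.getInt(0)))

  // One field node per column: (row count, null count).
  val nodes = List(new ArrowFieldNode(numRows, 0))
  new ArrowRecordBatch(numRows, nodes.asJava, List(validity, values).asJava)
}

A real implementation would size and populate the buffers from the Dataset's schema instead of assuming one Int column.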

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 51 additions & 1 deletion
@@ -17,9 +17,16 @@
 
 package org.apache.spark.sql
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io._
+import java.net.{InetAddress, InetSocketAddress, Socket}
+import java.nio.ByteBuffer
+import java.nio.channels.{Channels, FileChannel, SocketChannel}
 import java.sql.{Date, Timestamp}
 
+import io.netty.buffer.ArrowBuf
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.file.ArrowReader
+
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -919,6 +926,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
  }
+
+  def array(buf: ArrowBuf): Array[Byte] = {
+    val bytes = Array.ofDim[Byte](buf.readableBytes())
+    buf.readBytes(bytes)
+    bytes
+  }
+
+  test("Collect as arrow to python") {
+    val ds = Seq(1).toDS()
+    val port = ds.collectAsArrowToPython()
+
+    val s = new Socket(InetAddress.getByName("localhost"), port)
+    val is = s.getInputStream
+
+    val dis = new DataInputStream(is)
+    val len = dis.readInt()
+    val allocator = new RootAllocator(len)
+
+    val buffer = Array.ofDim[Byte](len)
+    val bytes = dis.read(buffer)
+
+
+    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
+    aFile.write(bytes)
+    aFile.close()
+
+    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
+    val fChannel = aFile.getChannel
+
+    val reader = new ArrowReader(fChannel, allocator)
+    val footer = reader.readFooter()
+    val schema = footer.getSchema
+    val blocks = footer.getRecordBatches
+    val recordBatch = reader.readRecordBatch(blocks.get(0))
+
+    val nodes = recordBatch.getNodes
+    val buffers = recordBatch.getBuffers
+
+    // scalastyle:off println
+    println(array(buffers.get(0)).mkString(", "))
+    println(array(buffers.get(1)).mkString(", "))
+    // scalastyle:on println
+  }
 }
 
 case class Generic[T](id: T, value: Double)
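
Two details of the new test are worth noting: dis.read(buffer) returns the number of bytes actually read (which can be less than len, and is what aFile.write(bytes) ends up writing), and the Arrow payload is staged through a hard-coded /tmp/nio-data.txt. A hedged sketch of how that middle section could be tightened with standard java.io calls only (readFully, File.createTempFile); variable names mirror the test, and the temp-file prefix/suffix are arbitrary:

// Sketch only, not the commit's code: read exactly `len` bytes from the socket and
// stage them in a throwaway temp file instead of a fixed /tmp path.
val buffer = Array.ofDim[Byte](len)
dis.readFully(buffer)                        // blocks until all `len` bytes have arrived

val tmpFile = java.io.File.createTempFile("arrow-batch", ".bin")
tmpFile.deleteOnExit()                       // cleaned up when the test JVM exits

var aFile = new java.io.RandomAccessFile(tmpFile, "rw")
aFile.write(buffer)                          // write the full payload, not the read count
aFile.close()

aFile = new java.io.RandomAccessFile(tmpFile, "r")
val fChannel = aFile.getChannel              // handed to ArrowReader exactly as in the test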
