diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1258ccf4579cb..0ca91b9bf86c6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -26,10 +26,12 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.language.existentials
 import scala.util.control.NonFatal
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd7a9de5c1d81..7b69a5c74397a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql
 
-import io.netty.buffer.ArrowBuf
-
 import java.io.{ByteArrayOutputStream, CharArrayWriter}
 import java.nio.channels.Channels
 
@@ -27,11 +25,11 @@ import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
+import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.pojo.{Field, ArrowType, Schema}
-
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -59,6 +57,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
+
 private[sql] object Dataset {
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2314,23 +2313,29 @@ class Dataset[T] private[sql](
 
     withNewExecutionId {
       try {
-        /*def toArrow(internalRow: InternalRow): ArrowBuf = {
+        def toArrow(internalRow: InternalRow): ArrowBuf = {
           val buf = rootAllocator.buffer(128) // TODO - size??
           // TODO internalRow -> buf
+          buf.setInt(0, 1)
           buf
         }
-        val iter = queryExecution.executedPlan.executeCollect().iterator.map(toArrow)
+        val iter = queryExecution.executedPlan.executeCollect().map(toArrow)
         val arrowBufList = iter.toList
         val nodes: List[ArrowFieldNode] = null // TODO
-        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)*/
+        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)
+
+        /*
         val validity = Array[Byte](255.asInstanceOf[Byte], 0)
         val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
         val validityb = buf(validity)
         val valuesb = buf(values)
-        new ArrowRecordBatch(16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
+        new ArrowRecordBatch(
+          16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
+        */
       } catch {
         case e: Exception =>
-          //logError(s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
+          // logError
+          // (s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
           throw e
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cc367acae2ba4..192c8b9958e07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,9 +17,16 @@
 
 package org.apache.spark.sql
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io._
+import java.net.{InetAddress, InetSocketAddress, Socket}
+import java.nio.ByteBuffer
+import java.nio.channels.{Channels, FileChannel, SocketChannel}
 import java.sql.{Date, Timestamp}
 
+import io.netty.buffer.ArrowBuf
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.file.ArrowReader
+
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -919,6 +926,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  def array(buf: ArrowBuf): Array[Byte] = {
+    val bytes = Array.ofDim[Byte](buf.readableBytes())
+    buf.readBytes(bytes)
+    bytes
+  }
+
+  test("Collect as arrow to python") {
+    val ds = Seq(1).toDS()
+    val port = ds.collectAsArrowToPython()
+
+    val s = new Socket(InetAddress.getByName("localhost"), port)
+    val is = s.getInputStream
+
+    val dis = new DataInputStream(is)
+    val len = dis.readInt()
+    val allocator = new RootAllocator(len)
+
+    val buffer = Array.ofDim[Byte](len)
+    val bytes = dis.read(buffer)
+
+
+    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
+    aFile.write(bytes)
+    aFile.close()
+
+    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
+    val fChannel = aFile.getChannel
+
+    val reader = new ArrowReader(fChannel, allocator)
+    val footer = reader.readFooter()
+    val schema = footer.getSchema
+    val blocks = footer.getRecordBatches
+    val recordBatch = reader.readRecordBatch(blocks.get(0))
+
+    val nodes = recordBatch.getNodes
+    val buffers = recordBatch.getBuffers
+
+    // scalastyle:off println
+    println(array(buffers.get(0)).mkString(", "))
+    println(array(buffers.get(1)).mkString(", "))
+    // scalastyle:on println
+  }
 }
 
 case class Generic[T](id: T, value: Double)
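
Not part of the patch: the commented-out block in Dataset.scala above hand-assembles a single-field ArrowRecordBatch from a validity bitmap plus a values buffer. The self-contained Scala sketch below shows that same construction using only the Arrow classes this change already imports. The object name and the `buf` helper (standing in for whatever the commented code calls as `buf(validity)`) are hypothetical, and `writeBytes` assumes ArrowBuf exposes the Netty ByteBuf API, as the new test's `readableBytes`/`readBytes` calls suggest.

// Sketch only; not part of the patch. Builds the same record batch as the
// commented-out example in Dataset.scala: 16 slots, 8 of them null.
import scala.collection.JavaConverters._

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}

object RecordBatchSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)

    // Hypothetical helper: copy a byte array into an allocator-owned ArrowBuf
    // (assumes ArrowBuf supports the Netty ByteBuf write API).
    def buf(bytes: Array[Byte]): ArrowBuf = {
      val b = allocator.buffer(bytes.length)
      b.writeBytes(bytes)
      b
    }

    // Validity bitmap 0xFF, 0x00 -> slots 0-7 valid, slots 8-15 null;
    // the second buffer holds the 16 single-byte values.
    val validity = Array[Byte](255.toByte, 0)
    val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)

    val batch = new ArrowRecordBatch(
      16, List(new ArrowFieldNode(16, 8)).asJava, List(buf(validity), buf(values)).asJava)

    // scalastyle:off println
    println(s"nodes=${batch.getNodes.size()}, buffers=${batch.getBuffers.size()}")
    // scalastyle:on println
  }
}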