@@ -26,10 +26,12 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}

import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
23 changes: 14 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,8 +17,6 @@

package org.apache.spark.sql

import io.netty.buffer.ArrowBuf

import java.io.{ByteArrayOutputStream, CharArrayWriter}
import java.nio.channels.Channels

@@ -27,11 +25,11 @@ import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo.{Field, ArrowType, Schema}

import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -59,6 +57,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


private[sql] object Dataset {
  def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
    new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -2314,23 +2313,29 @@ class Dataset[T] private[sql](
    withNewExecutionId {
      try {

        /*def toArrow(internalRow: InternalRow): ArrowBuf = {
        def toArrow(internalRow: InternalRow): ArrowBuf = {
          val buf = rootAllocator.buffer(128) // TODO - size??
          // TODO internalRow -> buf
          buf.setInt(0, 1)
          buf
        }
        val iter = queryExecution.executedPlan.executeCollect().iterator.map(toArrow)
        val iter = queryExecution.executedPlan.executeCollect().map(toArrow)
        val arrowBufList = iter.toList
        val nodes: List[ArrowFieldNode] = null // TODO
        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)*/
        new ArrowRecordBatch(arrowBufList.length, nodes.asJava, arrowBufList.asJava)

        /*
        val validity = Array[Byte](255.asInstanceOf[Byte], 0)
        val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
        val validityb = buf(validity)
        val valuesb = buf(values)
        new ArrowRecordBatch(16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
        new ArrowRecordBatch(
          16, List(new ArrowFieldNode(16, 8)).asJava, List(validityb, valuesb).asJava)
        */
      } catch {
        case e: Exception =>
          //logError(s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
          // logError
          // (s"Error converting InternalRow to ArrowBuf; ${e.getMessage}:\n$queryExecution")
          throw e
      }
    }
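Note on the hunk above: the now-uncommented toArrow path still returns placeholder data, while the block left commented out builds a record batch entirely by hand. For context, here is a minimal, self-contained sketch of that hand-built construction, assuming the early Arrow Java API this PR depends on (io.netty.buffer.ArrowBuf, RootAllocator, ArrowFieldNode, ArrowRecordBatch); the buf helper and the object name ArrowBatchSketch are hypothetical, introduced only for illustration.

import scala.collection.JavaConverters._

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}

object ArrowBatchSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)

    // Hypothetical helper: copy a byte array into a freshly allocated ArrowBuf.
    def buf(bytes: Array[Byte]): ArrowBuf = {
      val b = allocator.buffer(bytes.length)
      b.setBytes(0, bytes)
      b.writerIndex(bytes.length)
      b
    }

    // 16 one-byte values; the validity bitmap marks the first 8 as non-null (0xFF)
    // and the last 8 as null (0x00), which is why the field node's null count is 8.
    val validity = Array[Byte](255.toByte, 0)
    val values = Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)

    // A record batch is (row count, one field node per vector, the raw buffers).
    val batch = new ArrowRecordBatch(
      16, List(new ArrowFieldNode(16, 8)).asJava, List(buf(validity), buf(values)).asJava)

    println(s"nodes: ${batch.getNodes.size}, buffers: ${batch.getBuffers.size}")
  }
}

The real conversion still has to map each InternalRow onto properly sized Arrow buffers and derive the field nodes from the Dataset schema, which the TODOs above leave open.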
52 changes: 51 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,9 +17,16 @@

package org.apache.spark.sql

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.io._
import java.net.{InetAddress, InetSocketAddress, Socket}
import java.nio.ByteBuffer
import java.nio.channels.{Channels, FileChannel, SocketChannel}
import java.sql.{Date, Timestamp}

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowReader

import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -919,6 +926,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
df.withColumn("b", expr("0")).as[ClassData]
.groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
}

  def array(buf: ArrowBuf): Array[Byte] = {
    val bytes = Array.ofDim[Byte](buf.readableBytes())
    buf.readBytes(bytes)
    bytes
  }

  test("Collect as arrow to python") {
    val ds = Seq(1).toDS()
    val port = ds.collectAsArrowToPython()

    val s = new Socket(InetAddress.getByName("localhost"), port)
    val is = s.getInputStream

    val dis = new DataInputStream(is)
    val len = dis.readInt()
    val allocator = new RootAllocator(len)

    val buffer = Array.ofDim[Byte](len)
    val bytes = dis.read(buffer)


    var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
    aFile.write(bytes)
    aFile.close()

    aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
    val fChannel = aFile.getChannel

    val reader = new ArrowReader(fChannel, allocator)
    val footer = reader.readFooter()
    val schema = footer.getSchema
    val blocks = footer.getRecordBatches
    val recordBatch = reader.readRecordBatch(blocks.get(0))

    val nodes = recordBatch.getNodes
    val buffers = recordBatch.getBuffers

    // scalastyle:off println
    println(array(buffers.get(0)).mkString(", "))
    println(array(buffers.get(1)).mkString(", "))
    // scalastyle:on println
  }
}

case class Generic[T](id: T, value: Double)
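A note on the handshake the new test exercises: collectAsArrowToPython() returns a local port, and the test then reads a 4-byte length prefix followed by the serialized batch from that socket. Below is a hedged client-side sketch of that read; the object and method names are hypothetical. Unlike the single dis.read(buffer) call in the test, DataInputStream.readFully guarantees the whole payload is consumed.

import java.io.DataInputStream
import java.net.{InetAddress, Socket}

object ArrowSocketClientSketch {
  // Connect to the driver-provided port and read one length-prefixed payload,
  // mirroring what the test above does by hand.
  def readArrowPayload(port: Int): Array[Byte] = {
    val socket = new Socket(InetAddress.getByName("localhost"), port)
    try {
      val in = new DataInputStream(socket.getInputStream)
      val len = in.readInt()              // 4-byte big-endian length prefix
      val payload = Array.ofDim[Byte](len)
      in.readFully(payload)               // a bare read() may return fewer than len bytes
      payload
    } finally {
      socket.close()
    }
  }
}

In the test itself, aFile.write(bytes) passes the int returned by dis.read rather than the buffer, so only a single byte reaches the temp file; writing the buffer contents (up to the number of bytes actually read) is likely what was intended.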