Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
Expand Down Expand Up @@ -210,13 +211,48 @@ case class TypedFilter(
def typedCondition(input: Expression): Expression = {
val (funcClass, methodName) = func match {
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
case _ => classOf[Any => Boolean] -> "apply"
case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Invoke(funcObj, methodName, BooleanType, input :: Nil)
}
}

object FunctionUtils {
private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = {
dt match {
case BooleanType if isOutput => Some("Z")
case IntegerType => Some("I")
case LongType => Some("J")
case FloatType => Some("F")
case DoubleType => Some("D")
case _ => None
}
}

def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = {
// load "scala.Function1" using Java API to avoid requirements of type parameters
Utils.classForName("scala.Function1") -> {
// if a pair of an argument and return types is one of specific types
// whose specialized method (apply$mc..$sp) is generated by scalac,
// Catalyst generated a direct method call to the specialized method.
// The followings are references for this specialization:
// http://www.scala-lang.org/api/2.12.0/scala/Function1.html
// https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
// SpecializeTypes.scala
// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
// http://axel22.github.io/2013/11/03/specialization-quirks.html
val inputType = getMethodType(inputDT, false)
val outputType = getMethodType(outputDT, true)
if (inputType.isDefined && outputType.isDefined) {
s"apply$$mc${outputType.get}${inputType.get}$$sp"
} else {
"apply"
}
}
}
}

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
Expand Down Expand Up @@ -219,7 +221,7 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => classOf[Any => Any] -> "apply"
case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
Expand Down
122 changes: 116 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,49 @@ object DatasetBenchmark {

case class Data(l: Long, s: String)

def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

val rdd = spark.sparkContext.range(0, numRows)
val ds = spark.range(0, numRows)
val df = ds.toDF("l")
val func = (l: Long) => l + 1

val benchmark = new Benchmark("back-to-back map long", numRows)

benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.foreach(_ => Unit)
}

benchmark.addCase("DataFrame") { iter =>
var res = df
var i = 0
while (i < numChains) {
res = res.select($"l" + 1 as "l")
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark.addCase("Dataset") { iter =>
var res = ds.as[Long]
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark
}

def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

Expand Down Expand Up @@ -72,6 +115,49 @@ object DatasetBenchmark {
benchmark
}

def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

val rdd = spark.sparkContext.range(1, numRows)
val ds = spark.range(1, numRows)
val df = ds.toDF("l")
val func = (l: Long) => l % 2L == 0L

val benchmark = new Benchmark("back-to-back filter Long", numRows)

benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
while (i < numChains) {
res = res.filter(func)
i += 1
}
res.foreach(_ => Unit)
}

benchmark.addCase("DataFrame") { iter =>
var res = df
var i = 0
while (i < numChains) {
res = res.filter($"l" % 2L === 0L)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark.addCase("Dataset") { iter =>
var res = ds.as[Long]
var i = 0
while (i < numChains) {
res = res.filter(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark
}

def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

Expand Down Expand Up @@ -165,9 +251,22 @@ object DatasetBenchmark {
val numRows = 100000000
val numChains = 10

val benchmark = backToBackMap(spark, numRows, numChains)
val benchmark2 = backToBackFilter(spark, numRows, numChains)
val benchmark3 = aggregate(spark, numRows)
val benchmark0 = backToBackMapLong(spark, numRows, numChains)
val benchmark1 = backToBackMap(spark, numRows, numChains)
val benchmark2 = backToBackFilterLong(spark, numRows, numChains)
val benchmark3 = backToBackFilter(spark, numRows, numChains)
val benchmark4 = aggregate(spark, numRows)

/*
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 1883 / 1892 53.1 18.8 1.0X
DataFrame 502 / 642 199.1 5.0 3.7X
Dataset 657 / 784 152.2 6.6 2.9X
*/
benchmark0.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
Expand All @@ -178,7 +277,18 @@ object DatasetBenchmark {
DataFrame 2647 / 3116 37.8 26.5 1.3X
Dataset 4781 / 5155 20.9 47.8 0.7X
*/
benchmark.run()
benchmark1.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 846 / 1120 118.1 8.5 1.0X
DataFrame 270 / 329 370.9 2.7 3.1X
Dataset 545 / 789 183.5 5.4 1.6X
*/
benchmark2.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
Expand All @@ -189,7 +299,7 @@ object DatasetBenchmark {
DataFrame 59 / 72 1695.4 0.6 22.8X
Dataset 2777 / 2805 36.0 27.8 0.5X
*/
benchmark2.run()
benchmark3.run()

/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
Expand All @@ -201,6 +311,6 @@ object DatasetBenchmark {
Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X
Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X
*/
benchmark3.run()
benchmark4.run()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,64 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}

test("mapPrimitive") {
val dsInt = Seq(1, 2, 3).toDS()
checkDataset(dsInt.map(_ > 1), false, true, true)
checkDataset(dsInt.map(_ + 1), 2, 3, 4)
checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsLong = Seq(1L, 2L, 3L).toDS()
checkDataset(dsLong.map(_ > 1), false, true, true)
checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsFloat = Seq(1F, 2F, 3F).toDS()
checkDataset(dsFloat.map(_ > 1), false, true, true)
checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsDouble = Seq(1D, 2D, 3D).toDS()
checkDataset(dsDouble.map(_ > 1), false, true, true)
checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
8589934593L, 8589934594L, 8589934595L)
checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsBoolean = Seq(true, false).toDS()
checkDataset(dsBoolean.map(e => !e), false, true)
}

test("filter") {
val ds = Seq(1, 2, 3, 4).toDS()
checkDataset(
ds.filter(_ % 2 == 0),
2, 4)
}

test("filterPrimitive") {
val dsInt = Seq(1, 2, 3).toDS()
checkDataset(dsInt.filter(_ > 1), 2, 3)

val dsLong = Seq(1L, 2L, 3L).toDS()
checkDataset(dsLong.filter(_ > 1), 2L, 3L)

val dsFloat = Seq(1F, 2F, 3F).toDS()
checkDataset(dsFloat.filter(_ > 1), 2F, 3F)

val dsDouble = Seq(1D, 2D, 3D).toDS()
checkDataset(dsDouble.filter(_ > 1), 2D, 3D)

val dsBoolean = Seq(true, false).toDS()
checkDataset(dsBoolean.filter(e => !e), false)
}

test("foreach") {
val ds = Seq(1, 2, 3).toDS()
val acc = sparkContext.longAccumulator
Expand Down