diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 1afe83ea3539e..eb01e126bcbef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types.DataType
@@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType
  */
 private[sql] trait ColumnarBatchScan extends CodegenSupport {
 
-  val inMemoryTableScan: InMemoryTableScanExec = null
-
   def vectorTypes: Option[Seq[String]] = None
 
   override lazy val metrics = Map(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1aaaf896692d1..e37d133ff336a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
 
 object WholeStageCodegenExec {
   val PIPELINE_DURATION_METRIC = "duration"
+
+  private def numOfNestedFields(dataType: DataType): Int = dataType match {
+    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
+    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+    case a: ArrayType => numOfNestedFields(a.elementType)
+    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
+    case _ => 1
+  }
+
+  def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = {
+    numOfNestedFields(dataType) > conf.wholeStageMaxNumFields
+  }
 }
 
 /**
@@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
     case _ => true
   }
 
-  private def numOfNestedFields(dataType: DataType): Int = dataType match {
-    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
-    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
-    case a: ArrayType => numOfNestedFields(a.elementType)
-    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
-    case _ => 1
-  }
-
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
       val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val hasTooManyOutputFields =
-        numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
+        WholeStageCodegenExec.isTooManyFields(conf, plan.schema)
      val hasTooManyInputFields =
-        plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
+        plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema))
       !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
     case _ => false
   }
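The field-counting helper moves from `CollapseCodegenStages` into the `WholeStageCodegenExec` companion object so the same guard (driven by `spark.sql.codegen.maxFields`) can also be consulted by the cached scan introduced below. The following standalone sketch is not part of the patch: it re-implements the counting (without the `UserDefinedType` case, which is Spark-internal) to show how a nested schema is scored, and the object name `NestedFieldCountSketch` is invented for illustration.

```scala
import org.apache.spark.sql.types._

// Standalone re-implementation of the nested-field counting, for illustration only.
object NestedFieldCountSketch {
  def numOfNestedFields(dataType: DataType): Int = dataType match {
    case st: StructType => st.fields.map(f => numOfNestedFields(f.dataType)).sum
    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
    case a: ArrayType => numOfNestedFields(a.elementType)
    case _ => 1 // every leaf type counts as one field
  }

  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("id", LongType)
      .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))
      .add("tags", ArrayType(StringType))
    // 1 (id) + 2 (point.x, point.y) + 1 (tags element) = 4 nested fields,
    // well under the whole-stage codegen field limit.
    println(numOfNestedFields(schema)) // 4
  }
}
```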
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 24c8ac81420cb..445933d98e9d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -163,4 +163,12 @@
         throw new RuntimeException("Not support non-primitive type now")
     }
   }
+
+  def decompress(
+      array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int):
+    Unit = {
+    val byteBuffer = ByteBuffer.wrap(array)
+    val columnAccessor = ColumnAccessor(dataType, byteBuffer)
+    decompress(columnAccessor, columnVector, numRows)
+  }
 }
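The new overload wraps a cached column's raw byte array back into a `ColumnAccessor` and delegates to the existing vector-based `decompress`, which is exactly what the scan node in the next file needs. The round-trip sketch below is not part of the patch: it assumes it is compiled inside the Spark source tree (since `ColumnBuilder`, `ColumnAccessor`, and the vectorized classes are `private[sql]` internals), builds an uncompressed int column the same way `InMemoryRelation` builds cache buffers, and reads it back through the new overload; the object name `DecompressRoundTripSketch` is invented.

```scala
// Hypothetical round-trip sketch, NOT part of the patch; assumes the Spark source tree
// (ColumnBuilder, ColumnAccessor and the vectorized classes are private[sql] internals).
package org.apache.spark.sql.execution.columnar

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{IntegerType, StructType}

object DecompressRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val numRows = 16

    // Build a column buffer the same way InMemoryRelation does (PassThrough encoding here).
    val builder = ColumnBuilder(IntegerType, numRows, "col", useCompression = false)
    (0 until numRows).foreach(i => builder.appendFrom(InternalRow(i), 0))
    val bytes = JavaUtils.bufferToArray(builder.build())

    // Decode the raw bytes straight into a writable on-heap vector via the new overload.
    val schema = new StructType().add("col", IntegerType)
    val vector = OnHeapColumnVector.allocateColumns(numRows, schema)(0)
    ColumnAccessor.decompress(bytes, vector, IntegerType, numRows)

    assert((0 until numRows).forall(i => vector.getInt(i) == i))
  }
}
```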
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index af3636a5a2ca7..3307c091a7a86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -23,21 +23,66 @@
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.vectorized._
+import org.apache.spark.sql.types._
 
 
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends LeafExecNode with ColumnarBatchScan {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  override def vectorTypes: Option[Seq[String]] =
+    Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName))
+
+  /**
+   * If true, get data from ColumnVector in ColumnarBatch, which is generally faster.
+   * If false, get data from UnsafeRow built from ColumnVector.
+   */
+  override val supportCodegen: Boolean = {
+    // In the initial implementation, for ease of review, support only primitive data types
+    // and schemas that do not exceed wholeStageMaxNumFields
+    relation.schema.fields.forall(f => f.dataType match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType |
+           FloatType | DoubleType => true
+      case _ => false
+    }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
+  }
+
+  private val columnIndices =
+    attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
+
+  private val relationSchema = relation.schema.toArray
+
+  private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
+
+  private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = {
+    val rowCount = cachedColumnarBatch.numRows
+    val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
+    val columnarBatch = new ColumnarBatch(
+      columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount)
+    columnarBatch.setNumRows(rowCount)
+
+    for (i <- 0 until attributes.length) {
+      ColumnAccessor.decompress(
+        cachedColumnarBatch.buffers(columnIndices(i)),
+        columnarBatch.column(i).asInstanceOf[WritableColumnVector],
+        columnarBatchSchema.fields(i).dataType, rowCount)
+    }
+    columnarBatch
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    assert(supportCodegen)
+    val buffers = relation.cachedColumnBuffers
+    // HACK ALERT: This is actually an RDD[ColumnarBatch].
+    // We're taking advantage of Scala's type erasure here to pass these batches along.
+    Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
+  }
 
   override def output: Seq[Attribute] = attributes
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index fe6ba83b4cbfb..0881212a64de8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(data, schema)
     assert(df.select("b").first() === Row(outerStruct))
   }
+
+  test("primitive data type accesses in persisted data") {
+    val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
+      31.25.toFloat, 63.75, null)
+    val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, IntegerType)
+    val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
+      StructField(s"col$index", dataType, true)
+    }
+    val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+    val df = spark.createDataFrame(rdd, StructType(schemas))
+    val row = df.persist.take(1).apply(0)
+    checkAnswer(df, row)
+  }
+
+  test("access cache multiple times") {
+    val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache
+    df0.count
+    val df1 = df0.filter("x > 1")
+    checkAnswer(df1, Seq(Row(2), Row(3)))
+    val df2 = df0.filter("x > 2")
+    checkAnswer(df2, Row(3))
+
+    val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache
+    for (_ <- 0 to 2) {
+      val df11 = df10.filter("x > 5")
+      checkAnswer(df11, Row(6))
+    }
+  }
+
+  test("access only some columns of all the cached columns") {
+    val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d")
+    df.cache
+    df.count
+    assert(df.filter("d < 3").count == 1)
+  }
 }
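The suite below asserts on the physical plan tree programmatically. The same behaviour can also be eyeballed from a spark-shell session along the following lines; this is a sketch, not part of the patch, it assumes an active `spark` session, and the exact `explain()` formatting differs between Spark versions.

```scala
// Sketch of a spark-shell session; assumes an active `spark` SparkSession.
val df = spark.range(10).toDF("id").cache()
df.count()  // materializes the cached columnar buffers

val filtered = df.filter("id > 5")
// For a primitive-only schema the scan should show up as InMemoryTableScanExec
// inside a WholeStageCodegen stage (marked with `*` in the explain output).
filtered.explain()
println(filtered.queryExecution.executedPlan
  .find(_.isInstanceOf[org.apache.spark.sql.execution.columnar.InMemoryTableScanExec])
  .isDefined)  // true once the DataFrame is cached
```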
3").count == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 098e4cfeb15b2..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { + import testImplicits._ + + val dsInt = spark.range(3).cache + dsInt.count + val dsIntFilter = dsInt.filter(_ > 0) + val planInt = dsIntFilter.queryExecution.executedPlan + assert(planInt.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined + ) + assert(dsIntFilter.collect() === Array(1, 2)) + + // cache for string type is not supported for InMemoryTableScanExec + val dsString = spark.range(3).map(_.toString).cache + dsString.count + val dsStringFilter = dsString.filter(_ == "1") + val planString = dsStringFilter.queryExecution.executedPlan + assert(planString.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec]).isDefined + ) + assert(dsStringFilter.collect() === Array("1")) + } + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10)