diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7ac2c71c18eb3..e26a3c6aff035 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.DeveloperApi
  * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method.
  */
 @DeveloperApi
-sealed abstract class BlockId {
+abstract class BlockId {
   /** A globally unique identifier for this Block. Can be used for ser/de. */
   def name: String
 
@@ -49,6 +49,11 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
   override def name: String = "rdd_" + rddId + "_" + splitIndex
 }
 
+@DeveloperApi
+case class RDDPartitionMetadataBlockId(rddId: Int, splitIndex: Int) extends BlockId {
+  override def name: String = "rdd_" + rddId + "_" + splitIndex + ".metadata"
+}
+
 // Format of the shuffle block ids (including data and index) should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData().
 @DeveloperApi
@@ -103,6 +108,7 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
+  val PARTITION_METADATA = "rdd_([0-9]+)_([0-9]+).metadata".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
@@ -116,6 +122,8 @@ object BlockId {
   def apply(name: String): BlockId = name match {
     case RDD(rddId, splitIndex) =>
       RDDBlockId(rddId.toInt, splitIndex.toInt)
+    case PARTITION_METADATA(rddId, splitIndex) =>
+      RDDPartitionMetadataBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 21e4685fcc456..a079b2042c48b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -137,10 +137,18 @@ object SQLConf {
   val IN_MEMORY_PARTITION_PRUNING =
     buildConf("spark.sql.inMemoryColumnarStorage.partitionPruning")
       .internal()
-      .doc("When true, enable partition pruning for in-memory columnar tables.")
+      .doc("When true, enable partition batch pruning for in-memory columnar tables.")
       .booleanConf
       .createWithDefault(true)
 
+  val IN_MEMORY_PARTITION_METADATA =
+    buildConf("spark.sql.inMemoryColumnarStorage.partitionMetadata")
+      .internal()
+      .doc("When true, Spark SQL will collect partition-level statistics for in-memory" +
+        " columnar tables and use them for coarse-grained partition pruning.")
+      .booleanConf
+      .createWithDefault(false)
+
   val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin")
     .internal()
     .doc("When true, prefer sort merge join over shuffle hash join.")
@@ -1134,6 +1142,8 @@ class SQLConf extends Serializable with Logging {
 
   def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING)
 
+  def inMemoryPartitionMetadata: Boolean = getConf(IN_MEMORY_PARTITION_METADATA)
+
   def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)
 
   def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/CachedColumnarRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/CachedColumnarRDD.scala
new file mode 100644
index 0000000000000..2cd26f72ae0f2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/CachedColumnarRDD.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.storage.{RDDPartitionMetadataBlockId, StorageLevel}
+
+class CachedColumnarRDD(
+    @transient private var _sc: SparkContext,
+    private var dataRDD: RDD[CachedBatch],
+    containsPartitionMetadata: Boolean,
+    expectedStorageLevel: StorageLevel)
+  extends RDD[CachedBatch](_sc, Seq(new OneToOneDependency(dataRDD))) {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[CachedBatch] = {
+    firstParent.iterator(split, context)
+  }
+
+  override protected def getPartitions: Array[Partition] = dataRDD.partitions
+
+  override private[spark] def getOrCompute(split: Partition, context: TaskContext):
+      Iterator[CachedBatch] = {
+    val metadataBlockId = RDDPartitionMetadataBlockId(id, split.index)
+    val superGetOrCompute: (Partition, TaskContext) => Iterator[CachedBatch] = super.getOrCompute
+    SparkEnv.get.blockManager.getSingle[InternalRow](metadataBlockId).map(metadataBlock =>
+      new InterruptibleIterator[CachedBatch](context,
+        new CachedColumnarIterator(metadataBlock, split, context, superGetOrCompute))
+    ).getOrElse {
+      val batchIter = superGetOrCompute(split, context)
+      if (containsPartitionMetadata && getStorageLevel != StorageLevel.NONE && batchIter.hasNext) {
+        val cachedBatch = batchIter.next()
+        SparkEnv.get.blockManager.putSingle(metadataBlockId, cachedBatch.stats,
+          expectedStorageLevel)
+        new InterruptibleIterator[CachedBatch](context, Iterator(cachedBatch))
+      } else {
+        batchIter
+      }
+    }
+  }
+}
+
+private[columnar] class CachedColumnarIterator(
+    val partitionStats: InternalRow,
+    partition: Partition,
+    context: TaskContext,
+    fetchRDDPartition: (Partition, TaskContext) => Iterator[CachedBatch])
+  extends Iterator[CachedBatch] {
+
+  private var delegate: Iterator[CachedBatch] = _
+
+  override def hasNext: Boolean = {
+    if (delegate == null) {
+      delegate = fetchRDDPartition(partition, context)
+    }
+    delegate.hasNext
+  }
+
+  override def next(): CachedBatch = {
+    delegate.next()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index a1c62a729900e..4b5323ec87c76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -52,6 +52,68 @@ object InMemoryRelation {
 private[columnar]
 case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
 
+private[columnar] class CachedBatchIterator(
+    rowIterator: Iterator[InternalRow],
+    output: Seq[Attribute],
+    batchSize: Int,
+    useCompression: Boolean,
+    batchStats: LongAccumulator,
+    singleBatchPerPartition: Boolean) extends Iterator[CachedBatch] {
+
+  def next(): CachedBatch = {
+    val columnBuilders = output.map { attribute =>
+      ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
+    }.toArray
+
+    var rowCount = 0
+    var totalSize = 0L
+
+    val terminateLoop = (singleBatch: Boolean, rowIter: Iterator[InternalRow],
+        rowCount: Int, size: Long) => {
+      if (!singleBatch) {
+        rowIter.hasNext && rowCount < batchSize && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE
+      } else {
+        rowIter.hasNext
+      }
+    }
+
+    while (terminateLoop(singleBatchPerPartition, rowIterator, rowCount, totalSize)) {
+      val row = rowIterator.next()
+
+      // Added for SPARK-6082. This assertion can be useful for scenarios when something
+      // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
+      // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
+      // hard to decipher.
+      assert(
+        row.numFields == columnBuilders.length,
+        s"Row column number mismatch, expected ${output.size} columns, " +
+          s"but got ${row.numFields}." +
+          s"\nRow content: $row")
+
+      var i = 0
+      totalSize = 0
+      while (i < row.numFields) {
+        columnBuilders(i).appendFrom(row, i)
+        totalSize += columnBuilders(i).columnStats.sizeInBytes
+        i += 1
+      }
+      rowCount += 1
+    }
+
+    batchStats.add(totalSize)
+
+    val statsInSeq = columnBuilders.flatMap(_.columnStats.collectedStatistics)
+
+    val stats = InternalRow.fromSeq(statsInSeq)
+
+    CachedBatch(rowCount, columnBuilders.map { builder =>
+      JavaUtils.bufferToArray(builder.build())
+    }, stats)
+  }
+
+  def hasNext: Boolean = rowIterator.hasNext
+}
+
 case class InMemoryRelation(
     output: Seq[Attribute],
     useCompression: Boolean,
@@ -69,6 +131,8 @@ case class InMemoryRelation(
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
+  private val usePartitionLevelMetadata = conf.inMemoryPartitionMetadata
+
   override def computeStats(): Statistics = {
     if (batchStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
@@ -87,51 +151,14 @@ case class InMemoryRelation(
 
   private def buildBuffers(): Unit = {
     val output = child.output
-    val cached = child.execute().mapPartitionsInternal { rowIterator =>
-      new Iterator[CachedBatch] {
-        def next(): CachedBatch = {
-          val columnBuilders = output.map { attribute =>
-            ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
-          }.toArray
-
-          var rowCount = 0
-          var totalSize = 0L
-          while (rowIterator.hasNext && rowCount < batchSize
-            && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) {
-            val row = rowIterator.next()
-
-            // Added for SPARK-6082. This assertion can be useful for scenarios when something
-            // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
-            // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
-            // hard to decipher.
-            assert(
-              row.numFields == columnBuilders.length,
-              s"Row column number mismatch, expected ${output.size} columns, " +
-                s"but got ${row.numFields}." +
-                s"\nRow content: $row")
-
-            var i = 0
-            totalSize = 0
-            while (i < row.numFields) {
-              columnBuilders(i).appendFrom(row, i)
-              totalSize += columnBuilders(i).columnStats.sizeInBytes
-              i += 1
-            }
-            rowCount += 1
-          }
-
-          batchStats.add(totalSize)
-
-          val stats = InternalRow.fromSeq(
-            columnBuilders.flatMap(_.columnStats.collectedStatistics))
-          CachedBatch(rowCount, columnBuilders.map { builder =>
-            JavaUtils.bufferToArray(builder.build())
-          }, stats)
-        }
-
-        def hasNext: Boolean = rowIterator.hasNext
-      }
-    }.persist(storageLevel)
+
+    val batchedRDD = child.execute().mapPartitionsInternal { rowIterator =>
+      new CachedBatchIterator(rowIterator, output, batchSize, useCompression, batchStats,
+        usePartitionLevelMetadata)
+    }
+
+    val cached = new CachedColumnarRDD(batchedRDD.sparkContext, batchedRDD,
+      usePartitionLevelMetadata, storageLevel).persist(storageLevel)
 
     cached.setName(
       tableName.map(n => s"In-memory table $n")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 2ae3f35eb1da1..dde61edbceef4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -17,15 +17,19 @@
 
 package org.apache.spark.sql.execution.columnar
 
+import org.apache.spark.{InterruptibleIterator, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.vectorized._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.RDDPartitionMetadataBlockId
 
 
 case class InMemoryTableScanExec(
@@ -180,37 +184,47 @@ case class InMemoryTableScanExec(
 
   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
 
+  private def doFilterCachedBatches(
+      cachedBatchIterator: Iterator[CachedBatch],
+      partitionStatsSchema: Seq[AttributeReference],
+      partitionFilter: GenPredicate): Iterator[CachedBatch] = {
+    val schemaIndex = partitionStatsSchema.zipWithIndex
+    cachedBatchIterator.filter { cachedBatch =>
+      if (!partitionFilter.eval(cachedBatch.stats)) {
+        logDebug {
+          val statsString = schemaIndex.map { case (a, i) =>
+            val value = cachedBatch.stats.get(i, a.dataType)
+            s"${a.name}: $value"
+          }.mkString(", ")
+          s"Skipping partition based on stats $statsString"
+        }
+        false
+      } else {
+        true
+      }
+    }
+  }
+
   private def filteredCachedBatches(): RDD[CachedBatch] = {
     // Using these variables here to avoid serialization of entire objects (if referenced directly)
     // within the map Partitions closure.
     val schema = relation.partitionStatistics.schema
-    val schemaIndex = schema.zipWithIndex
     val buffers = relation.cachedColumnBuffers
 
     buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
+
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema)
       partitionFilter.initialize(index)
 
-      // Do partition batch pruning if enabled
-      if (inMemoryPartitionPruningEnabled) {
-        cachedBatchIterator.filter { cachedBatch =>
-          if (!partitionFilter.eval(cachedBatch.stats)) {
-            logDebug {
-              val statsString = schemaIndex.map { case (a, i) =>
-                val value = cachedBatch.stats.get(i, a.dataType)
-                s"${a.name}: $value"
-              }.mkString(", ")
-              s"Skipping partition based on stats $statsString"
-            }
-            false
-          } else {
-            true
-          }
-        }
-      } else {
-        cachedBatchIterator
+      cachedBatchIterator.asInstanceOf[InterruptibleIterator[_]].delegate match {
+        case cachedIter: CachedColumnarIterator
+            if !partitionFilter.eval(cachedIter.partitionStats) =>
+          logDebug(s"Skipping partition $index based on its partition-level stats")
+          Iterator()
+        case _ =>
+          doFilterCachedBatches(cachedBatchIterator, schema, partitionFilter)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index e662e294228db..4de87d56b92db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.SparkEnv
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
@@ -29,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.{RDDBlockId, RDDPartitionMetadataBlockId}
 import org.apache.spark.storage.StorageLevel._
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
@@ -479,4 +482,111 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("table cache can prune unnecessary partitions correctly") {
+    var bytesReadWithoutPruning = 0L
+    var bytesReadWithPruning = 0L
+    @volatile var inMemoryPartitionMetadata = false
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        val metrics = taskEnd.taskMetrics
+        if (inMemoryPartitionMetadata) {
+          bytesReadWithPruning += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithoutPruning += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        inMemoryPartitionMetadata = conf.inMemoryPartitionMetadata
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        df1.where("value >= 99999").collect()
+        val resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length == 1)
+        assert(resultArr.head.getInt(0) == 99999)
+        df1.unpersist(true)
+      }
+    }
+    assert(bytesReadWithoutPruning > 0)
+    assert(bytesReadWithPruning > 0)
+    assert(bytesReadWithoutPruning > bytesReadWithPruning * 3)
+  }
+
+  test("generate correct results when metadata block is removed") {
+    var bytesReadWithMetadata = 0L
+    var bytesReadWithoutMetadata = 0L
+    @volatile var removePartitionMetadata = false
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        val metrics = taskEnd.taskMetrics
+        if (removePartitionMetadata) {
+          bytesReadWithoutMetadata += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithMetadata += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        removePartitionMetadata = true
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        val inMemoryRelation = df1.queryExecution.optimizedPlan.collect {
+          case m: InMemoryRelation => m
+        }
+        df1.where("value >= 99999").collect()
+        (0 until 4).foreach(partitionId => SparkEnv.get.blockManager.removeBlock(
+          RDDPartitionMetadataBlockId(inMemoryRelation.head.cachedColumnBuffers.id, partitionId)))
+        var resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        removePartitionMetadata = false
+        resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        df1.unpersist(blocking = true)
+        assert(bytesReadWithMetadata > 0)
+        assert(bytesReadWithoutMetadata > 0)
+        assert(bytesReadWithoutMetadata > bytesReadWithMetadata * 3)
+      }
+    }
+  }
+
+  test("generate correct results when data block is removed") {
+    var bytesReadWithCachedBlock = 0L
+    var bytesReadWithoutCachedBlock = 0L
+    @volatile var removeCachedBlock = false
+    sparkContext.addSparkListener(new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        val metrics = taskEnd.taskMetrics
+        if (removeCachedBlock) {
+          bytesReadWithoutCachedBlock += metrics.inputMetrics.bytesRead
+        } else {
+          bytesReadWithCachedBlock += metrics.inputMetrics.bytesRead
+        }
+      }
+    })
+    Seq("true").foreach { enabled =>
+      withSQLConf(SQLConf.IN_MEMORY_PARTITION_METADATA.key -> enabled) {
+        removeCachedBlock = true
+        val df1 = (0 until 100000).toDF("value").repartition(4).cache()
+        val inMemoryRelation = df1.queryExecution.optimizedPlan.collect {
+          case m: InMemoryRelation => m
+        }
+        df1.where("value >= 99999").collect()
+        (0 until 4).foreach(partitionId => SparkEnv.get.blockManager.removeBlock(
+          RDDBlockId(inMemoryRelation.head.cachedColumnBuffers.id, partitionId)))
+        var resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        removeCachedBlock = false
+        resultArr = df1.where("value >= 99999").collect()
+        assert(resultArr.length === 1)
+        assert(resultArr.head.getInt(0) === 99999)
+        df1.unpersist(blocking = true)
+        assert(bytesReadWithCachedBlock > 0)
+        assert(bytesReadWithoutCachedBlock == 0)
+      }
+    }
+  }
 }
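
For reference, a minimal usage sketch of the feature this patch introduces (not part of the patch itself). It assumes a spark-shell session `spark` on a build that includes this change; the config key, partition count, and query shape are taken from the tests above.

// Minimal sketch, assuming a spark-shell session `spark` built with this patch.
// Enable partition-level metadata for in-memory columnar tables (default is false).
spark.conf.set("spark.sql.inMemoryColumnarStorage.partitionMetadata", "true")

import spark.implicits._

// Cache a table spread over several partitions.
val df = (0 until 100000).toDF("value").repartition(4).cache()

// The first action materializes the cache; with the flag on, each partition's stats row
// is also stored as a separate RDDPartitionMetadataBlockId block next to its data block.
df.where("value >= 99999").collect()

// Later selective scans consult the metadata block first and can skip whole partitions
// whose stats cannot satisfy the predicate, instead of iterating over every CachedBatch.
df.where("value >= 99999").collect()

df.unpersist(true)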