diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 479934a7afc75..c68ea0666ed2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.CollectionAccumulator
+import org.apache.spark.util.AccumulatorV2
 
 object InMemoryRelation {
@@ -44,6 +44,70 @@ object InMemoryRelation {
     new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
 }
 
+/**
+ * Accumulator for storing column stats. Summarizes the data in the driver to curb the amount of
+ * memory being used. Only "sizeInBytes" for each column is kept.
+ */
+class ColStatsAccumulator(originalOutput: Seq[Attribute])
+  extends AccumulatorV2[Seq[ColumnStats], Array[Long]] {
+
+  private var stats: Array[Long] = null
+
+  override def isZero: Boolean = stats == null
+
+  override def copy(): AccumulatorV2[Seq[ColumnStats], Array[Long]] = {
+    val newAcc = new ColStatsAccumulator(originalOutput)
+    newAcc.stats = stats
+    newAcc
+  }
+
+  override def reset(): Unit = {
+    stats = null
+  }
+
+  override def add(update: Seq[ColumnStats]): Unit = {
+    if (update != null) {
+      require(isZero || stats.length == update.size, "Input stats doesn't match expected size.")
+
+      val newStats = new Array[Long](update.size)
+
+      update.toIndexedSeq.zipWithIndex.foreach { case (colStats, idx) =>
+        val current = if (!isZero) stats(idx) else 0L
+        newStats(idx) = current + colStats.sizeInBytes
+      }
+
+      stats = newStats
+    }
+  }
+
+  override def merge(other: AccumulatorV2[Seq[ColumnStats], Array[Long]]): Unit = {
+    if (other.value != null) {
+      require(isZero || stats.length == other.value.length,
+        "Merging accumulators of different size.")
+
+      val newStats = new Array[Long](other.value.length)
+      for (i <- 0 until other.value.size) {
+        val current = if (!isZero) stats(i) else 0L
+        newStats(i) = current + other.value(i)
+      }
+      stats = newStats
+    }
+  }
+
+  override def value: Array[Long] = stats
+
+  /**
+   * Calculate the size of the relation for a given output. Adds up all the known column sizes
+   * that match the desired output.
+   */
+  def sizeForOutput(output: Seq[Attribute]): Long = {
+    originalOutput.toIndexedSeq.zipWithIndex.map { case (a, idx) =>
+      val count = output.count(a.semanticEquals)
+      stats(idx) * count
+    }.fold(0L)(_ + _)
+  }
+
+}
 
 /**
  * CachedBatch is a cached batch of rows.
@@ -63,8 +127,7 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: CollectionAccumulator[InternalRow] =
-      child.sqlContext.sparkContext.collectionAccumulator[InternalRow])
+    _batchStats: ColStatsAccumulator = null)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
@@ -73,22 +136,23 @@ case class InMemoryRelation(
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
+  val batchStats = if (_batchStats != null) {
+    _batchStats
+  } else {
+    val _newStats = new ColStatsAccumulator(output)
+    child.sqlContext.sparkContext.register(_newStats)
+    _newStats
+  }
+
   override lazy val statistics: Statistics = {
-    if (batchStats.value.isEmpty) {
+    if (batchStats.isZero) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
       // available, return the default statistics.
       Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
     } else {
       // Underlying columnar RDD has been materialized, required information has also been
       // collected via the `batchStats` accumulator.
-      val sizeOfRow: Expression =
-        BindReferences.bindReference(
-          output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
-          partitionStatistics.schema)
-
-      val sizeInBytes =
-        batchStats.value.asScala.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
-      Statistics(sizeInBytes = sizeInBytes)
+      Statistics(sizeInBytes = batchStats.sizeForOutput(output))
     }
   }
 
@@ -139,13 +203,13 @@ case class InMemoryRelation(
           rowCount += 1
         }
 
-        val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
-          .flatMap(_.values))
-
+        val stats = columnBuilders.map(_.columnStats)
         batchStats.add(stats)
+
+        val statsRow = InternalRow.fromSeq(stats.map(_.collectedStatistics).flatMap(_.values))
         CachedBatch(rowCount, columnBuilders.map { builder =>
           JavaUtils.bufferToArray(builder.build())
-        }, stats)
+        }, statsRow)
       }
 
       def hasNext: Boolean = rowIterator.hasNext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 937839644ad5f..96281dedf379e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -232,4 +232,29 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     val columnTypes2 = List.fill(length2)(IntegerType)
     val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2)
   }
+
+  test("SPARK-17549: cached table size should be correctly calculated") {
+    val data = spark.sparkContext.parallelize(1 to 10, 5).map { i => (i, i.toLong) }
+      .toDF("col1", "col2")
+    val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
+    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None)
+
+    // Materialize the data.
+    val expectedAnswer = data.collect()
+    checkAnswer(cached, expectedAnswer)
+
+    // Check that the right size was calculated.
+    val expectedColSizes = expectedAnswer.size * (INT.defaultSize + LONG.defaultSize)
+    assert(cached.statistics.sizeInBytes === expectedColSizes)
+
+    // Create a projection of the cached data and make sure the statistics are correct.
+    val projected = cached.withOutput(Seq(plan.output.last))
+    assert(projected.statistics.sizeInBytes === expectedAnswer.size * LONG.defaultSize)
+
+    // Create a silly projection that repeats columns of the first cached relation, and
+    // check that the size is calculated correctly.
+    val projected2 = cached.withOutput(Seq(plan.output.last, plan.output.last))
+    assert(projected2.statistics.sizeInBytes === 2 * expectedAnswer.size * LONG.defaultSize)
+  }
 }
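
For reviewers less familiar with the AccumulatorV2 API this patch switches to, below is a minimal standalone sketch of the contract (isZero/copy/reset/add/merge/value) and the registration step. `LongSumAcc` and its usage are illustrative only, not part of the patch; the sketch assumes a running SparkContext `sc`. A custom accumulator must be registered with the SparkContext before tasks update it, which is why InMemoryRelation registers the freshly created ColStatsAccumulator as soon as it builds one.

import org.apache.spark.util.AccumulatorV2

class LongSumAcc extends AccumulatorV2[Long, Long] {
  private var sum = 0L
  override def isZero: Boolean = sum == 0L
  override def copy(): LongSumAcc = { val acc = new LongSumAcc; acc.sum = sum; acc }
  override def reset(): Unit = sum = 0L
  override def add(v: Long): Unit = sum += v                    // executor-side update
  override def merge(other: AccumulatorV2[Long, Long]): Unit =  // driver-side combine
    sum += other.value
  override def value: Long = sum
}

// Usage sketch: sc.register(acc, "totalBytes"); rdd.foreach(n => acc.add(n))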
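
And here is a minimal sketch, with no Spark dependencies, that mirrors the size bookkeeping the new test exercises. `PlainColStats` and `SizeAccumulator` are hypothetical stand-ins for ColumnStats and ColStatsAccumulator; the asserts reproduce the test's arithmetic: 10 rows of (Int, Long) give 10 * (4 + 8) = 120 bytes, a Long-only projection gives 80 bytes, and a column duplicated in the output is counted twice.

object SizeAccumulatorSketch {
  final case class PlainColStats(sizeInBytes: Long)

  final class SizeAccumulator {
    private var stats: Array[Long] = null
    def isZero: Boolean = stats == null

    // Per-batch update: element-wise sum of each column's byte size.
    def add(update: Seq[PlainColStats]): Unit = {
      val base = if (isZero) Array.fill(update.size)(0L) else stats
      stats = base.zip(update).map { case (acc, s) => acc + s.sizeInBytes }
    }

    // Size of a projection, given the indices of the requested output columns;
    // a column appearing twice in the output is counted twice.
    def sizeForOutput(outputIdx: Seq[Int]): Long =
      stats.zipWithIndex.map { case (s, i) => s * outputIdx.count(_ == i) }.sum
  }

  def main(args: Array[String]): Unit = {
    val acc = new SizeAccumulator
    // 5 batches of 2 rows of (Int, Long): each batch adds (2*4, 2*8) bytes.
    (1 to 5).foreach(_ => acc.add(Seq(PlainColStats(8), PlainColStats(16))))
    assert(acc.sizeForOutput(Seq(0, 1)) == 120L)  // 10 * (4 + 8), as in the test
    assert(acc.sizeForOutput(Seq(1)) == 80L)      // Long column only
    assert(acc.sizeForOutput(Seq(1, 1)) == 160L)  // duplicated column counts twice
  }
}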