
Commit b6a36be

address the comments
1 parent 4650307 commit b6a36be

2 files changed: 23 additions, 24 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 14 additions & 9 deletions
@@ -22,11 +22,10 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.fs.{FileSystem, Path}
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.storage.StorageLevel
@@ -80,6 +79,14 @@ class CacheManager extends Logging {
     cachedData.isEmpty
   }
 
+  private def extractStatsOfPlanForCache(plan: LogicalPlan): Option[Statistics] = {
+    if (plan.conf.cboEnabled && plan.stats.rowCount.isDefined) {
+      Some(plan.stats)
+    } else {
+      None
+    }
+  }
+
   /**
    * Caches the data produced by the logical representation of the given [[Dataset]].
    * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
@@ -96,13 +103,10 @@ class CacheManager extends Logging {
       val sparkSession = query.sparkSession
       val inMemoryRelation = InMemoryRelation(
         sparkSession.sessionState.conf.useCompression,
-        sparkSession.sessionState.conf.columnBatchSize,
-        storageLevel,
+        sparkSession.sessionState.conf.columnBatchSize, storageLevel,
         sparkSession.sessionState.executePlan(planToCache).executedPlan,
-        tableName)
-      if (planToCache.conf.cboEnabled && planToCache.stats.rowCount.isDefined) {
-        inMemoryRelation.setStatsFromCachedPlan(planToCache)
-      }
+        tableName,
+        extractStatsOfPlanForCache(planToCache))
       cachedData.add(CachedData(planToCache, inMemoryRelation))
     }
   }
@@ -150,7 +154,8 @@ class CacheManager extends Logging {
         batchSize = cd.cachedRepresentation.batchSize,
         storageLevel = cd.cachedRepresentation.storageLevel,
         child = spark.sessionState.executePlan(cd.plan).executedPlan,
-        tableName = cd.cachedRepresentation.tableName)
+        tableName = cd.cachedRepresentation.tableName,
+        stats = extractStatsOfPlanForCache(cd.plan))
       needToRecache += cd.copy(cachedRepresentation = newCache)
     }
   }
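
The new extractStatsOfPlanForCache helper centralizes the CBO guard that was previously inlined at the call site, so both the cacheQuery path and the recache path apply the same check: the plan's statistics are propagated only when cost-based optimization is enabled and a row count was actually computed; otherwise the caller falls back to a size-only default. A minimal self-contained sketch of that guard, using hypothetical stand-in types (Plan, Stats) rather than Spark's LogicalPlan/Statistics API:

// Hypothetical stand-ins, just to show the guard in isolation.
case class Stats(sizeInBytes: BigInt, rowCount: Option[BigInt])
case class Plan(cboEnabled: Boolean, stats: Stats)

object StatsExtraction {
  // Mirrors extractStatsOfPlanForCache: only trust the plan's stats when
  // CBO is enabled AND a row count exists; otherwise return None so the
  // caller falls back to a configured default size.
  def extractStatsForCache(plan: Plan): Option[Stats] =
    if (plan.cboEnabled && plan.stats.rowCount.isDefined) Some(plan.stats) else None

  def main(args: Array[String]): Unit = {
    val withCbo = Plan(cboEnabled = true, Stats(sizeInBytes = 1024, rowCount = Some(42)))
    val noCbo   = Plan(cboEnabled = false, Stats(sizeInBytes = 1024, rowCount = Some(42)))
    println(extractStatsForCache(withCbo)) // Some(Stats(1024,Some(42)))
    println(extractStatsForCache(noCbo))   // None
  }
}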

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 9 additions & 15 deletions
@@ -25,22 +25,22 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
 
 
 object InMemoryRelation {
-
   def apply(
       useCompression: Boolean,
       batchSize: Int,
       storageLevel: StorageLevel,
       child: SparkPlan,
-      tableName: Option[String]): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+      tableName: Option[String],
+      stats: Option[Statistics]): InMemoryRelation =
+    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
+      statsOfPlanToCache = stats)
 }
 
 
@@ -62,7 +62,8 @@ case class InMemoryRelation(
     @transient child: SparkPlan,
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
+    val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
+    statsOfPlanToCache: Option[Statistics] = None)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -73,20 +74,13 @@ case class InMemoryRelation(
 
   override def computeStats(): Statistics = {
     if (batchStats.value == 0L) {
-      inheritedStats.getOrElse(Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes))
+      statsOfPlanToCache.getOrElse(Statistics(sizeInBytes =
+        child.sqlContext.conf.defaultSizeInBytes))
     } else {
       Statistics(sizeInBytes = batchStats.value.longValue)
    }
   }
 
-  private var inheritedStats: Option[Statistics] = None
-
-  private[execution] def setStatsFromCachedPlan(planToCache: LogicalPlan): Unit = {
-    require(planToCache.conf.cboEnabled, "you cannot use the stats of cached plan in" +
-      " InMemoryRelation without cbo enabled")
-    inheritedStats = Some(planToCache.stats)
-  }
-
   // If the cached column buffers were not passed in, we calculate them in the constructor.
   // As in Spark, the actual work of caching is lazy.
   if (_cachedColumnBuffers == null) {
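
Replacing the removed setStatsFromCachedPlan mutator with the statsOfPlanToCache constructor parameter means a caller can no longer forget to set the stats, or set them after they have already been read: the value is immutable and present from construction. A toy sketch of the resulting precedence in computeStats, with hypothetical names (batchBytes stands in for the batchStats accumulator; this is not the Spark API):

case class Stats(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

// Toy relation: stats are fixed at construction instead of set later via a mutator.
class CachedRelation(
    defaultSizeInBytes: Long,
    statsOfPlanToCache: Option[Stats] = None) {

  // Stand-in for the batchStats accumulator: stays 0 until the cache is materialized.
  var batchBytes: Long = 0L

  // Same three-way precedence as the patched computeStats():
  // materialized size > pre-cache CBO stats > configured default size.
  def computeStats(): Stats =
    if (batchBytes == 0L) {
      statsOfPlanToCache.getOrElse(Stats(sizeInBytes = defaultSizeInBytes))
    } else {
      Stats(sizeInBytes = batchBytes)
    }
}

object CachedRelationDemo extends App {
  val rel = new CachedRelation(defaultSizeInBytes = 1L << 20,
    statsOfPlanToCache = Some(Stats(sizeInBytes = 2048, rowCount = Some(100))))
  println(rel.computeStats()) // Stats(2048,Some(100)): plan stats before materialization
  rel.batchBytes = 4096
  println(rel.computeStats()) // Stats(4096,None): actual cached size wins afterwards
}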
