Commit 20ca208

maropu authored and cloud-fan committed
[SPARK-23880][SQL] Do not trigger any jobs for caching data
## What changes were proposed in this pull request?

This PR fixes the code so that `cache` does not trigger any jobs. For example, in the current master, the operation below triggers an actual job:

```
val df = spark.range(10000000000L)
  .filter('id > 1000)
  .orderBy('id.desc)
  .cache()
```

This runs a job even though the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, such as a sampling job for a range partitioner or a broadcast job.

This PR removes the code that builds the cached `RDD` in the constructor of `InMemoryRelation` and adds `CachedRDDBuilder` to lazily build the `RDD` inside `InMemoryRelation`. The first call of `CachedRDDBuilder.cachedColumnBuffers` then triggers a job that materializes the cache in `InMemoryTableScanExec`.

## How was this patch tested?

Added tests in `CachedTableSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #21018 from maropu/SPARK-23880.
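For illustration, a minimal way to observe the new behavior end to end (a sketch, not part of the patch; assumes a local `SparkSession`):

```scala
import org.apache.spark.sql.SparkSession

object LazyCacheDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("lazy-cache-demo").getOrCreate()
  import spark.implicits._

  // With this patch, cache() only registers the InMemoryRelation; no sampling
  // or shuffle job runs here, even though orderBy needs a range partitioner.
  val df = spark.range(1002).filter('id > 1000).orderBy('id.desc).cache()

  // The first action materializes the cached column buffers.
  assert(df.collect().toSeq == Seq(1001L))
  spark.stop()
}
```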
1 parent 64e8408 commit 20ca208

File tree

8 files changed, +133 -94 lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -2933,7 +2933,7 @@ class Dataset[T] private[sql](
    */
   def storageLevel: StorageLevel = {
     sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
-      cachedData.cachedRepresentation.storageLevel
+      cachedData.cachedRepresentation.cacheBuilder.storageLevel
     }.getOrElse(StorageLevel.NONE)
   }
```
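As a usage note, `storageLevel` is the user-facing probe of this state; a quick sketch (assuming an active session `spark`):

```scala
import org.apache.spark.storage.StorageLevel

val df = spark.range(10).cache()    // Dataset default level is MEMORY_AND_DISK
assert(df.storageLevel == StorageLevel.MEMORY_AND_DISK)
df.unpersist()
assert(df.storageLevel == StorageLevel.NONE)  // no cache entry maps to NONE
```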

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 6 additions & 8 deletions
```diff
@@ -71,7 +71,7 @@ class CacheManager extends Logging {
 
   /** Clears all cached tables. */
   def clearCache(): Unit = writeLock {
-    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
     cachedData.clear()
   }
 
@@ -119,7 +119,7 @@ class CacheManager extends Logging {
     while (it.hasNext) {
       val cd = it.next()
       if (cd.plan.find(_.sameResult(plan)).isDefined) {
-        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
         it.remove()
       }
     }
@@ -138,16 +138,14 @@ class CacheManager extends Logging {
     while (it.hasNext) {
       val cd = it.next()
       if (condition(cd.plan)) {
-        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        cd.cachedRepresentation.cacheBuilder.clearCache()
         // Remove the cache entry before we create a new one, so that we can have a different
         // physical plan.
         it.remove()
+        val plan = spark.sessionState.executePlan(cd.plan).executedPlan
         val newCache = InMemoryRelation(
-          useCompression = cd.cachedRepresentation.useCompression,
-          batchSize = cd.cachedRepresentation.batchSize,
-          storageLevel = cd.cachedRepresentation.storageLevel,
-          child = spark.sessionState.executePlan(cd.plan).executedPlan,
-          tableName = cd.cachedRepresentation.tableName,
+          cacheBuilder = cd.cachedRepresentation
+            .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null),
           logicalPlan = cd.plan)
         needToRecache += cd.copy(cachedRepresentation = newCache)
       }
```
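The recache path above leans on a Scala idiom: in a case class with two parameter lists, only the first list participates in equality and `copy`, so `copy(cachedPlan = plan)(_cachedColumnBuffers = null)` keeps the configuration while discarding stale state. A minimal standalone sketch of the idiom (illustrative names, not Spark code):

```scala
// Only the first parameter list takes part in equals/hashCode; the second
// list carries resettable state, passed explicitly to copy(...)(...).
case class Builder(name: String)(private var _cached: Option[String] = None) {
  def cached: String = {
    if (_cached.isEmpty) _cached = Some(s"materialized:$name") // built lazily
    _cached.get
  }
}

object CopyIdiomDemo extends App {
  val b = Builder("t1")()
  println(b.cached)                               // materialized:t1

  // Recache: keep the config, drop the stale state, swap in a new name.
  val fresh = b.copy(name = "t2")(_cached = None)
  println(fresh.cached)                           // materialized:t2

  // Equality ignores the second parameter list entirely.
  assert(Builder("x")() == Builder("x")(Some("stale")))
}
```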

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 85 additions & 70 deletions
```diff
@@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
 
 
-object InMemoryRelation {
-  def apply(
-      useCompression: Boolean,
-      batchSize: Int,
-      storageLevel: StorageLevel,
-      child: SparkPlan,
-      tableName: Option[String],
-      logicalPlan: LogicalPlan): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
-      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
-}
-
-
 /**
  * CachedBatch is a cached batch of rows.
  *
@@ -55,58 +42,41 @@ object InMemoryRelation {
 private[columnar]
 case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
 
-case class InMemoryRelation(
-    output: Seq[Attribute],
+case class CachedRDDBuilder(
     useCompression: Boolean,
     batchSize: Int,
     storageLevel: StorageLevel,
-    @transient child: SparkPlan,
+    @transient cachedPlan: SparkPlan,
     tableName: Option[String])(
-    @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
-    val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics,
-    override val outputOrdering: Seq[SortOrder])
-  extends logical.LeafNode with MultiInstanceRelation {
-
-  override protected def innerChildren: Seq[SparkPlan] = Seq(child)
-
-  override def doCanonicalize(): logical.LogicalPlan =
-    copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)),
-      storageLevel = StorageLevel.NONE,
-      child = child.canonicalized,
-      tableName = None)(
-      _cachedColumnBuffers,
-      sizeInBytesStats,
-      statsOfPlanToCache,
-      outputOrdering)
+    @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) {
 
-  override def producedAttributes: AttributeSet = outputSet
-
-  @transient val partitionStatistics = new PartitionStatistics(output)
+  val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
 
-  override def computeStats(): Statistics = {
-    if (sizeInBytesStats.value == 0L) {
-      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
-      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
-      // node. When we lookup the cache with a semantically same plan without hint info, the plan
-      // returned by cache lookup should not have hint info. If we lookup the cache with a
-      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
-      // care of it and retain the hint info in the lookup input plan.
-      statsOfPlanToCache.copy(hints = HintInfo())
-    } else {
-      Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
+  def cachedColumnBuffers: RDD[CachedBatch] = {
+    if (_cachedColumnBuffers == null) {
+      synchronized {
+        if (_cachedColumnBuffers == null) {
+          _cachedColumnBuffers = buildBuffers()
+        }
+      }
     }
+    _cachedColumnBuffers
   }
 
-  // If the cached column buffers were not passed in, we calculate them in the constructor.
-  // As in Spark, the actual work of caching is lazy.
-  if (_cachedColumnBuffers == null) {
-    buildBuffers()
+  def clearCache(blocking: Boolean = true): Unit = {
+    if (_cachedColumnBuffers != null) {
+      synchronized {
+        if (_cachedColumnBuffers != null) {
+          _cachedColumnBuffers.unpersist(blocking)
+          _cachedColumnBuffers = null
+        }
+      }
+    }
   }
 
-  private def buildBuffers(): Unit = {
-    val output = child.output
-    val cached = child.execute().mapPartitionsInternal { rowIterator =>
+  private def buildBuffers(): RDD[CachedBatch] = {
+    val output = cachedPlan.output
+    val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
       new Iterator[CachedBatch] {
         def next(): CachedBatch = {
           val columnBuilders = output.map { attribute =>
@@ -154,32 +124,77 @@ case class InMemoryRelation(
 
     cached.setName(
       tableName.map(n => s"In-memory table $n")
-        .getOrElse(StringUtils.abbreviate(child.toString, 1024)))
-    _cachedColumnBuffers = cached
+        .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024)))
+    cached
+  }
+}
+
+object InMemoryRelation {
+
+  def apply(
+      useCompression: Boolean,
+      batchSize: Int,
+      storageLevel: StorageLevel,
+      child: SparkPlan,
+      tableName: Option[String],
+      logicalPlan: LogicalPlan): InMemoryRelation = {
+    val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)()
+    new InMemoryRelation(child.output, cacheBuilder)(
+      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+  }
+
+  def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = {
+    new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)(
+      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+  }
+}
+
+case class InMemoryRelation(
+    output: Seq[Attribute],
+    @transient cacheBuilder: CachedRDDBuilder)(
+    statsOfPlanToCache: Statistics,
+    override val outputOrdering: Seq[SortOrder])
+  extends logical.LeafNode with MultiInstanceRelation {
+
+  override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
+
+  override def doCanonicalize(): logical.LogicalPlan =
+    copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)),
+      cacheBuilder)(
+      statsOfPlanToCache,
+      outputOrdering)
+
+  override def producedAttributes: AttributeSet = outputSet
+
+  @transient val partitionStatistics = new PartitionStatistics(output)
+
+  def cachedPlan: SparkPlan = cacheBuilder.cachedPlan
+
+  override def computeStats(): Statistics = {
+    if (cacheBuilder.sizeInBytesStats.value == 0L) {
+      // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+      // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+      // node. When we lookup the cache with a semantically same plan without hint info, the plan
+      // returned by cache lookup should not have hint info. If we lookup the cache with a
+      // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+      // care of it and retain the hint info in the lookup input plan.
+      statsOfPlanToCache.copy(hints = HintInfo())
+    } else {
+      Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
+    }
   }
 
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
-    InMemoryRelation(
-      newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-      _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)
+    InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering)
   }
 
   override def newInstance(): this.type = {
     new InMemoryRelation(
       output.map(_.newInstance()),
-      useCompression,
-      batchSize,
-      storageLevel,
-      child,
-      tableName)(
-      _cachedColumnBuffers,
-      sizeInBytesStats,
+      cacheBuilder)(
       statsOfPlanToCache,
       outputOrdering).asInstanceOf[this.type]
   }
 
-  def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
-
-  override protected def otherCopyArgs: Seq[AnyRef] =
-    Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
+  override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache)
 }
```
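The heart of the new `CachedRDDBuilder` is the lazily materialized, thread-safe `cachedColumnBuffers`. A standalone sketch of the same double-checked-locking shape (plain Scala, illustrative names; this sketch marks the field `@volatile`, which the classic idiom requires for safe publication):

```scala
// Fast path reads without locking; the slow path locks, re-checks, and
// builds at most once. clear() mirrors CachedRDDBuilder.clearCache().
class LazyRef[T <: AnyRef](build: () => T) {
  @volatile private var _value: T = null.asInstanceOf[T]

  def get: T = {
    if (_value == null) {
      synchronized {
        if (_value == null) {     // re-check: another thread may have built it
          _value = build()
        }
      }
    }
    _value
  }

  def clear(): Unit = synchronized { _value = null.asInstanceOf[T] }
}

object LazyRefDemo extends App {
  var builds = 0
  val ref = new LazyRef[String](() => { builds += 1; "buffers" })
  ref.get; ref.get
  assert(builds == 1)             // built exactly once across repeated gets
  ref.clear(); ref.get
  assert(builds == 2)             // rebuilt after clear, like recaching
}
```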

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -154,7 +154,7 @@ case class InMemoryTableScanExec(
   private def updateAttribute(expr: Expression): Expression = {
     // attributes can be pruned so using relation's output.
     // E.g., relation.output is [id, item] but this scan's output can be [item] only.
-    val attrMap = AttributeMap(relation.child.output.zip(relation.output))
+    val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
     expr.transform {
       case attr: Attribute => attrMap.getOrElse(attr, attr)
     }
@@ -163,16 +163,16 @@ case class InMemoryTableScanExec(
   // The cached version does not change the outputPartitioning of the original SparkPlan.
   // But the cached version could alias output, so we need to replace output.
   override def outputPartitioning: Partitioning = {
-    relation.child.outputPartitioning match {
+    relation.cachedPlan.outputPartitioning match {
       case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
-      case _ => relation.child.outputPartitioning
+      case _ => relation.cachedPlan.outputPartitioning
     }
   }
 
   // The cached version does not change the outputOrdering of the original SparkPlan.
   // But the cached version could alias output, so we need to replace output.
   override def outputOrdering: Seq[SortOrder] =
-    relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
+    relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
 
   // Keeps relation's partition statistics because we don't serialize relation.
   private val stats = relation.partitionStatistics
@@ -252,7 +252,7 @@ case class InMemoryTableScanExec(
     // within the map Partitions closure.
     val schema = stats.schema
     val schemaIndex = schema.zipWithIndex
-    val buffers = relation.cachedColumnBuffers
+    val buffers = relation.cacheBuilder.cachedColumnBuffers
 
     buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
       val partitionFilter = newPredicate(
```
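`updateAttribute` above rewrites expressions from the cached plan's attribute IDs to the scan's possibly aliased output. A toy sketch of the same remap-with-fallback idea over a simplified expression tree (illustrative types, not Spark's expression API):

```scala
// Zip old output against new output to build the remap table, then rewrite
// references bottom-up, keeping unmapped ones as-is — the same
// getOrElse(attr, attr) fallback the scan uses.
sealed trait Expr
final case class Ref(id: String) extends Expr
final case class Desc(child: Expr) extends Expr

def remap(e: Expr, m: Map[String, String]): Expr = e match {
  case Ref(id) => Ref(m.getOrElse(id, id))
  case Desc(c) => Desc(remap(c, m))
}

val oldOut = Seq("id#1", "item#2")  // like relation.cachedPlan.output
val newOut = Seq("id#7", "item#8")  // like relation.output (aliased)
val table = oldOut.zip(newOut).toMap
assert(remap(Desc(Ref("id#1")), table) == Desc(Ref("id#7")))
```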

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 31 additions & 5 deletions
```diff
@@ -22,6 +22,7 @@ import scala.concurrent.duration._
 import scala.language.postfixOps
 
 import org.apache.spark.CleanerListener
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
@@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     val plan = spark.table(tableName).queryExecution.sparkPlan
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>
-        relation.cachedColumnBuffers.id
+        relation.cacheBuilder.cachedColumnBuffers.id
       case _ =>
         fail(s"Table $tableName is not cached\n" + plan)
     }.head
@@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
     plan.collect {
       case InMemoryTableScanExec(_, _, relation) =>
-        getNumInMemoryTablesRecursively(relation.child) + 1
+        getNumInMemoryTablesRecursively(relation.cachedPlan) + 1
     }.sum
   }
 
@@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     spark.catalog.cacheTable("testData")
     assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
       spark.table("testData").queryExecution.withCachedData.collect {
-        case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r
+        case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r
       }.size
     }
 
@@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     val toBeCleanedAccIds = new HashSet[Long]
 
     val accId1 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.sizeInBytesStats.id
+      case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId1
 
     val accId2 = spark.table("t1").queryExecution.withCachedData.collect {
-      case i: InMemoryRelation => i.sizeInBytesStats.id
+      case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id
     }.head
     toBeCleanedAccIds += accId2
 
@@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       }
     }
   }
+
+  private def checkIfNoJobTriggered[T](f: => T): T = {
+    var numJobTrigered = 0
+    val jobListener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numJobTrigered += 1
+      }
+    }
+    sparkContext.addSparkListener(jobListener)
+    try {
+      val result = f
+      sparkContext.listenerBus.waitUntilEmpty(10000L)
+      assert(numJobTrigered === 0)
+      result
+    } finally {
+      sparkContext.removeSparkListener(jobListener)
+    }
+  }
+
+  test("SPARK-23880 table cache should be lazy and don't trigger any jobs") {
+    val cachedData = checkIfNoJobTriggered {
+      spark.range(1002).filter('id > 1000).orderBy('id.desc).cache()
+    }
+    assert(cachedData.collect === Seq(1001))
+  }
 }
```
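The `checkIfNoJobTriggered` helper relies on Spark-private APIs (`listenerBus.waitUntilEmpty`); outside the test suite, a similar check can be sketched with only public listener APIs, a crude sleep standing in for the event-bus drain (assumes a local session; names are illustrative):

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession

object NoJobOnCacheCheck extends App {
  val spark = SparkSession.builder().master("local[2]").appName("no-job-on-cache").getOrCreate()
  import spark.implicits._

  @volatile var jobs = 0
  spark.sparkContext.addSparkListener(new SparkListener {
    override def onJobStart(start: SparkListenerJobStart): Unit = jobs += 1
  })

  val df = spark.range(1002).filter('id > 1000).orderBy('id.desc).cache()
  Thread.sleep(2000)                        // crude stand-in for waitUntilEmpty
  println(s"jobs after cache(): $jobs")     // expected 0 with this patch

  df.collect()                              // first action materializes the cache
  println(s"jobs after collect(): $jobs")   // now greater than 0
  spark.stop()
}
```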

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext {
   test("CollectLimit can appear in the middle of a plan when caching is used") {
     val query = testData.select('key, 'value).limit(2).cache()
     val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
-    assert(planned.child.isInstanceOf[CollectLimitExec])
+    assert(planned.cachedPlan.isInstanceOf[CollectLimitExec])
   }
 
   test("SPARK-23375: Cached sorted data doesn't need to be re-sorted")
```

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None,
       data.logicalPlan)
 
-    assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
-    inMemoryRelation.cachedColumnBuffers.collect().head match {
+    assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel)
+    inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match {
       case _: CachedBatch =>
       case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
     }
@@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(cached, expectedAnswer)
 
     // Check that the right size was calculated.
-    assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
+    assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize)
   }
 
   test("access primitive-type columns in CachedBatch without whole stage codegen")
```
