
Commit 4b977ff

gatorsmile authored and cloud-fan committed
[SPARK-19765][SPARK-18549][SPARK-19093][SPARK-19736][BACKPORT-2.1][SQL] Backport Three Cache-related PRs to Spark 2.1
### What changes were proposed in this pull request?

Backport a few cache-related PRs:

---

[[SPARK-19093][SQL] Cached tables are not used in SubqueryExpression](#16493)

Consider the plans inside subquery expressions while looking up the cache manager, to make use of cached data. Currently, `CacheManager.useCachedData` does not consider the subquery expressions in the plan.

---

[[SPARK-19736][SQL] refreshByPath should clear all cached plans with the specified path](#17064)

`Catalog.refreshByPath` can refresh the cache entry and the associated metadata for all DataFrames (if any) that contain the given data source path. However, `CacheManager.invalidateCachedPath` doesn't clear all cached plans with the specified path, which causes some of the strange behaviors reported in SPARK-15678.

---

[[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table](#17097)

When un-caching a table, we should not only remove the cache entry for this table but also un-cache any other cached plans that refer to it. The following commands trigger table un-caching: `DropTableCommand`, `TruncateTableCommand`, `AlterTableRenameCommand`, `UncacheTableCommand`, `RefreshTable` and `InsertIntoHiveTable`.

This PR also includes some refactoring:

- use `java.util.LinkedList` to store the cache entries, so that it is safer to remove elements while iterating
- rename `invalidateCache` to `recacheByPlan`, which is more obvious about what it does

### How was this patch tested?

N/A

Author: Xiao Li <[email protected]>

Closes #17319 from gatorsmile/backport-17097.
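To make the SPARK-19765 behavior concrete, here is a minimal spark-shell sketch (not part of the patch; the table and view names are hypothetical):

```scala
// Assumes a running SparkSession named `spark`.
spark.range(10).write.saveAsTable("t1")           // a source table
spark.sql("CACHE TABLE t1")                       // cache the table itself
spark.sql("CACHE TABLE v1 AS SELECT id FROM t1")  // cache a plan that refers to t1

spark.catalog.isCached("t1")  // true
spark.catalog.isCached("v1")  // true

// Before this patch, only t1's own cache entry was dropped; with it, every
// cached plan that refers to t1 (here, v1) is un-cached as well.
spark.sql("UNCACHE TABLE t1")
spark.catalog.isCached("v1")  // false once this patch is applied
```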
1 parent 9d032d0 · commit 4b977ff

9 files changed (+200, -88 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 70 additions & 50 deletions
@@ -19,9 +19,12 @@ package org.apache.spark.sql.execution
 
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -44,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
 class CacheManager extends Logging {
 
   @transient
-  private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+  private val cachedData = new java.util.LinkedList[CachedData]
 
   @transient
   private val cacheLock = new ReentrantReadWriteLock
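The switch from `ArrayBuffer` to `java.util.LinkedList` above is what makes the remove-while-iterating pattern in the new `uncacheQuery` and `recacheByCondition` below safe. A self-contained sketch of that pattern, using toy data rather than Spark code:

```scala
import java.util.LinkedList
import scala.collection.JavaConverters._

val entries = new LinkedList[String]()
Seq("keep", "drop", "keep").foreach(entries.add)

// java.util.Iterator.remove() deletes the element last returned by next(),
// which is well-defined mid-iteration; removing from a Scala ArrayBuffer by
// index while scanning it shifts the remaining elements under the scan.
val it = entries.iterator()
while (it.hasNext) {
  if (it.next() == "drop") it.remove()
}
assert(entries.asScala.toSeq == Seq("keep", "keep"))
```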
@@ -69,7 +72,7 @@ class CacheManager extends Logging {
 
   /** Clears all cached tables. */
   def clearCache(): Unit = writeLock {
-    cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
     cachedData.clear()
   }
 
@@ -87,92 +90,109 @@ class CacheManager extends Logging {
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.logicalPlan
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData +=
-        CachedData(
-          planToCache,
-          InMemoryRelation(
-            sparkSession.sessionState.conf.useCompression,
-            sparkSession.sessionState.conf.columnBatchSize,
-            storageLevel,
-            sparkSession.sessionState.executePlan(planToCache).executedPlan,
-            tableName))
+      cachedData.add(CachedData(
+        planToCache,
+        InMemoryRelation(
+          sparkSession.sessionState.conf.useCompression,
+          sparkSession.sessionState.conf.columnBatchSize,
+          storageLevel,
+          sparkSession.sessionState.executePlan(planToCache).executedPlan,
+          tableName)))
     }
   }
 
   /**
-   * Tries to remove the data for the given [[Dataset]] from the cache.
-   * No operation, if it's already uncached.
+   * Un-cache all the cache entries that refer to the given plan.
+   */
+  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+    uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+  }
+
+  /**
+   * Un-cache all the cache entries that refer to the given plan.
    */
-  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
-    val planToCache = query.queryExecution.analyzed
-    val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
-    val found = dataIndex >= 0
-    if (found) {
-      cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
-      cachedData.remove(dataIndex)
+  def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+    val it = cachedData.iterator()
+    while (it.hasNext) {
+      val cd = it.next()
+      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        it.remove()
+      }
     }
-    found
+  }
+
+  /**
+   * Tries to re-cache all the cache entries that refer to the given plan.
+   */
+  def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+    recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+  }
+
+  private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+    val it = cachedData.iterator()
+    val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+    while (it.hasNext) {
+      val cd = it.next()
+      if (condition(cd.plan)) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        // Remove the cache entry before we create a new one, so that we can have a different
+        // physical plan.
+        it.remove()
+        val newCache = InMemoryRelation(
+          useCompression = cd.cachedRepresentation.useCompression,
+          batchSize = cd.cachedRepresentation.batchSize,
+          storageLevel = cd.cachedRepresentation.storageLevel,
+          child = spark.sessionState.executePlan(cd.plan).executedPlan,
+          tableName = cd.cachedRepresentation.tableName)
+        needToRecache += cd.copy(cachedRepresentation = newCache)
+      }
+    }
+
+    needToRecache.foreach(cachedData.add)
   }
 
   /** Optionally returns cached data for the given [[Dataset]] */
   def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
-    lookupCachedData(query.queryExecution.analyzed)
+    lookupCachedData(query.logicalPlan)
   }
 
   /** Optionally returns cached data for the given [[LogicalPlan]]. */
   def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
-    cachedData.find(cd => plan.sameResult(cd.plan))
+    cachedData.asScala.find(cd => plan.sameResult(cd.plan))
   }
 
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
-    plan transformDown {
+    val newPlan = plan transformDown {
       case currentFragment =>
         lookupCachedData(currentFragment)
          .map(_.cachedRepresentation.withOutput(currentFragment.output))
          .getOrElse(currentFragment)
     }
-  }
 
-  /**
-   * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
-   * function will over invalidate.
-   */
-  def invalidateCache(plan: LogicalPlan): Unit = writeLock {
-    cachedData.foreach {
-      case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
-        data.cachedRepresentation.recache()
-      case _ =>
+    newPlan transformAllExpressions {
+      case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
     }
   }
 
   /**
-   * Invalidates the cache of any data that contains `resourcePath` in one or more
+   * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
    * `HadoopFsRelation` node(s) as part of its logical plan.
    */
-  def invalidateCachedPath(
-      sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+  def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
    val (fs, qualifiedPath) = {
      val path = new Path(resourcePath)
-      val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
-      (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+      val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+      (fs, fs.makeQualified(path))
    }
 
-    cachedData.foreach {
-      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined =>
-        val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
-        if (dataIndex >= 0) {
-          data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
-          cachedData.remove(dataIndex)
-        }
-        sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
-      case _ => // Do Nothing
-    }
+    recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
   }
 
   /**
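The `transformAllExpressions` step added to `useCachedData` above is the SPARK-19093 fix: cached plans are now substituted inside subquery expressions, not only in the outer plan. A hedged illustration (the view name is hypothetical):

```scala
// Assumes a running SparkSession named `spark`.
spark.range(100).createOrReplaceTempView("t1")
spark.table("t1").cache()
spark.table("t1").count()  // materialize the cache

val q = spark.sql("SELECT id FROM t1 WHERE id > (SELECT avg(id) FROM t1)")
// Previously only the outer scan of t1 was replaced by the InMemoryRelation;
// the scan inside the scalar subquery re-read the source. With this change,
// the subquery's plan should show the cached relation too.
q.explain()
```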

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 0 additions & 6 deletions
@@ -85,12 +85,6 @@ case class InMemoryRelation(
     buildBuffers()
   }
 
-  def recache(): Unit = {
-    _cachedColumnBuffers.unpersist()
-    _cachedColumnBuffers = null
-    buildBuffers()
-  }
-
   private def buildBuffers(): Unit = {
     val output = child.output
     val cached = child.execute().mapPartitionsInternal { rowIterator =>

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 1 addition & 2 deletions
@@ -199,8 +199,7 @@ case class DropTableCommand(
       }
     }
     try {
-      sparkSession.sharedState.cacheManager.uncacheQuery(
-        sparkSession.table(tableName.quotedString))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
     } catch {
       case _: NoSuchTableException if ifExists =>
       case NonFatal(e) => log.warn(e.toString, e)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala

Lines changed: 3 additions & 2 deletions
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
     val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
     relation.insert(df, overwrite.enabled)
 
-    // Invalidate the cache.
-    sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+    // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
+    // data source relation.
+    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
 
     Seq.empty[Row]
   }
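What switching to `recacheByPlan` buys here, sketched against a hypothetical cached table and a cached plan derived from it (the exact command an INSERT statement routes through depends on the data source, but the re-caching contract is the same):

```scala
spark.sql("CREATE TABLE src (i INT) USING parquet")
spark.sql("CACHE TABLE src")
spark.sql("CACHE TABLE agg AS SELECT count(*) AS c FROM src")

// recacheByPlan rebuilds every cache entry whose plan refers to `src`,
// including src's own entry, so neither read below returns stale data.
spark.sql("INSERT INTO src VALUES (1)")
spark.sql("SELECT * FROM src").show()
spark.sql("SELECT c FROM agg").show()
```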

sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala

Lines changed: 9 additions & 14 deletions
@@ -373,8 +373,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def dropTempView(viewName: String): Boolean = {
-    sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+    sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropTempView(viewName)
     }
   }
@@ -389,7 +389,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropGlobalTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropGlobalTempView(viewName)
     }
   }
@@ -434,7 +434,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
   }
 
   /**
@@ -472,17 +472,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
 
     // If this table is cached as an InMemoryRelation, drop the original
     // cached version and make the new version cached lazily.
-    val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent)
-    // Use lookupCachedData directly since RefreshTable also takes databaseName.
-    val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
-    if (isCached) {
-      // Create a data frame to represent the table.
-      // TODO: Use uncacheTable once it supports database name.
-      val df = Dataset.ofRows(sparkSession, logicalPlan)
+    val table = sparkSession.table(tableIdent)
+    if (isCached(table)) {
       // Uncache the logicalPlan.
-      sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
       // Cache it again.
-      sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+      sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
     }
   }
 
@@ -494,7 +489,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def refreshByPath(resourcePath: String): Unit = {
-    sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+    sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
   }
 }
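And the public-API view of the SPARK-19736 half of the backport: a minimal sketch of `Catalog.refreshByPath`, assuming a scratch path that is safe to overwrite:

```scala
val path = "/tmp/refresh_by_path_demo"  // hypothetical scratch location
spark.range(10).write.mode("overwrite").parquet(path)

val df = spark.read.parquet(path)
df.cache()
df.count()  // 10; the cache is now materialized

// Rewrite the files behind the DataFrame's back, then refresh by path:
// recacheByPath un-caches and lazily re-caches every plan over this path.
spark.range(20).write.mode("overwrite").parquet(path)
spark.catalog.refreshByPath(path)
df.count()  // 20; before the fix, cached plans over the path could stay stale
```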
