@@ -19,9 +19,12 @@ package org.apache.spark.sql.execution

import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -44,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {

@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
private val cachedData = new java.util.LinkedList[CachedData]

@transient
private val cacheLock = new ReentrantReadWriteLock
@@ -69,7 +72,7 @@ class CacheManager {

/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}

@@ -87,92 +90,109 @@ class CacheManager {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
cachedData +=
CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName))
cachedData.add(CachedData(
planToCache,
InMemoryRelation(
sparkSession.sessionState.conf.useCompression,
sparkSession.sessionState.conf.columnBatchSize,
storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName)))
}
}
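
For context, a minimal sketch of how this method is reached from the public API (a local session with illustrative names, not part of the patch): Dataset.cache() hands the Dataset to CacheManager.cacheQuery, and later queries are matched against the stored plan via sameResult.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("cache-sketch").getOrCreate()

// Dataset.cache() ends up in CacheManager.cacheQuery, keyed by this Dataset's logical plan.
val evens = spark.range(100).filter("id % 2 = 0")
evens.cache()
evens.count()          // the first action builds the InMemoryRelation's column buffers

// A semantically identical query matches the entry through plan.sameResult(...),
// so its physical plan reads the cached buffers instead of recomputing.
val sameQuery = spark.range(100).filter("id % 2 = 0")
sameQuery.explain()    // should show an InMemoryTableScan over the cached relation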

/**
* Tries to remove the data for the given [[Dataset]] from the cache.
* No operation, if it's already uncached.
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
}

/**
* Un-cache all the cache entries that refer to the given plan.
*/
def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
val found = dataIndex >= 0
if (found) {
cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
cachedData.remove(dataIndex)
def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
val it = cachedData.iterator()
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
it.remove()
}
}
found
}
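
A hedged sketch of the new behavior (reusing the local `spark` session from the sketch above; view names are illustrative): un-caching a plan now drops every cache entry whose plan merely refers to it, not only the exact match.

// Reuses the local `spark` session from the first sketch.
spark.range(10).toDF("id").createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")

val derived = spark.sql("SELECT id * 2 AS doubled FROM t1")
derived.cache()
derived.count()                  // two cache entries now exist: t1 and the derived query

// With this patch, un-caching t1 walks all entries and also removes the derived one,
// because its plan contains a subtree that sameResult-matches t1's plan.
spark.catalog.uncacheTable("t1")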

/**
* Tries to re-cache all the cache entries that refer to the given plan.
*/
def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
}

private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
val it = cachedData.iterator()
val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
cd.cachedRepresentation.cachedColumnBuffers.unpersist()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
val newCache = InMemoryRelation(
useCompression = cd.cachedRepresentation.useCompression,
batchSize = cd.cachedRepresentation.batchSize,
storageLevel = cd.cachedRepresentation.storageLevel,
child = spark.sessionState.executePlan(cd.plan).executedPlan,
tableName = cd.cachedRepresentation.tableName)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}

needToRecache.foreach(cachedData.add)
}
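
The remove-while-iterating pattern above is also why cachedData switched from a Scala ArrayBuffer to java.util.LinkedList in this patch. A tiny self-contained sketch of that pattern (plain JVM collections, no Spark needed):

import scala.collection.JavaConverters._

val entries = new java.util.LinkedList[String]()
entries.add("plan-A")
entries.add("plan-B")
entries.add("plan-C")

// LinkedList's iterator supports cheap in-place removal while scanning,
// which a foreach over an ArrayBuffer does not allow safely.
val it = entries.iterator()
while (it.hasNext) {
  if (it.next() == "plan-B") it.remove()
}
assert(entries.asScala.toList == List("plan-A", "plan-C"))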

/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
lookupCachedData(query.logicalPlan)
}

/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
cachedData.find(cd => plan.sameResult(cd.plan))
cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}

/** Replaces segments of the given logical plan with cached versions where possible. */
def useCachedData(plan: LogicalPlan): LogicalPlan = {
plan transformDown {
val newPlan = plan transformDown {
case currentFragment =>
lookupCachedData(currentFragment)
.map(_.cachedRepresentation.withOutput(currentFragment.output))
.getOrElse(currentFragment)
}
}

/**
* Invalidates the cache of any data that contains `plan`. Note that it is possible that this
* function will over invalidate.
*/
def invalidateCache(plan: LogicalPlan): Unit = writeLock {
cachedData.foreach {
case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
data.cachedRepresentation.recache()
case _ =>
newPlan transformAllExpressions {
case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
}
}
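
A small sketch of what the added SubqueryExpression case buys (reusing the local `spark` session; table names are illustrative): cached data is now substituted inside subquery expressions as well, not only in the outer plan.

// Reuses the local `spark` session from the first sketch.
spark.range(10).toDF("id").createOrReplaceTempView("orders")
spark.range(5).toDF("id").createOrReplaceTempView("customers")
spark.catalog.cacheTable("orders")

// The IN-subquery over `orders` is a SubqueryExpression; with this change its plan is
// also rewritten to read the cached relation.
val q = spark.sql("SELECT * FROM customers WHERE id IN (SELECT id FROM orders)")
q.explain()   // the subquery side should now scan the cached `orders` relation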

/**
* Invalidates the cache of any data that contains `resourcePath` in one or more
* Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
def invalidateCachedPath(
sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
(fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
(fs, fs.makeQualified(path))
}

cachedData.foreach {
case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined =>
val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
if (dataIndex >= 0) {
data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
cachedData.remove(dataIndex)
}
sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
case _ => // Do Nothing
}
recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
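
The public entry point for this method is Catalog.refreshByPath (its call site is updated further down in this diff). A hedged end-to-end sketch, with an illustrative local path:

// Reuses the local `spark` session from the first sketch; the path is illustrative.
val path = "/tmp/recache-by-path-demo"
spark.range(10).write.mode("overwrite").parquet(path)

val df = spark.read.parquet(path)
df.cache()
df.count()                                 // caches 10 rows

// The files change underneath the cache, e.g. rewritten by another job or process.
spark.range(20).write.mode("overwrite").parquet(path)

// refreshByPath routes to recacheByPath: matching entries are dropped and re-registered,
// so the next action should rebuild the buffers from the new files.
spark.catalog.refreshByPath(path)
df.count()                                 // should now reflect the rewritten data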

/**
@@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}

def recache(): Unit = {
_cachedColumnBuffers.unpersist()
_cachedColumnBuffers = null
buildBuffers()
}

private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
@@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession.table(tableName.quotedString))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite.enabled)

// Invalidate the cache.
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
// Re-cache all cached plans (including this relation itself, if it's cached) that refer to this
// data source relation.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)

Seq.empty[Row]
}
@@ -373,8 +373,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
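
A short sketch of the effect of this change (reusing the local `spark` session; the view name is illustrative): the view's cached data is now dropped by plan, without first wrapping the view definition in a Dataset.

// Reuses the local `spark` session from the first sketch.
spark.range(10).createOrReplaceTempView("tmp_view")
spark.catalog.cacheTable("tmp_view")

// dropTempView now calls uncacheQuery(spark, viewDef, blocking = true) directly, which
// removes every cache entry whose plan refers to the view's definition.
spark.catalog.dropTempView("tmp_view")
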
@@ -389,7 +389,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@@ -434,7 +434,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}

/**
@@ -472,17 +472,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {

// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent)
// Use lookupCachedData directly since RefreshTable also takes databaseName.
val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
if (isCached) {
// Create a data frame to represent the table.
// TODO: Use uncacheTable once it supports database name.
val df = Dataset.ofRows(sparkSession, logicalPlan)
val table = sparkSession.table(tableIdent)
if (isCached(table)) {
// Uncache the logicalPlan.
sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}
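
A minimal sketch of the refresh path above through the public API (reusing the local `spark` session; the table name is illustrative): the cached entry is dropped and re-registered lazily, so no job runs until the table is used again.

// Reuses the local `spark` session from the first sketch.
spark.range(10).write.mode("overwrite").saveAsTable("events")
spark.catalog.cacheTable("events")
spark.table("events").count()            // materializes the cached buffers

// refreshTable un-caches `events` and immediately re-caches it lazily.
spark.catalog.refreshTable("events")
spark.table("events").count()            // the first action after refresh rebuilds the cache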

@@ -494,7 +489,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}
