diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 526623a36d2a..0ea806d6cb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.execution
 
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -44,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
 class CacheManager extends Logging {
 
   @transient
-  private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+  private val cachedData = new java.util.LinkedList[CachedData]
 
   @transient
   private val cacheLock = new ReentrantReadWriteLock
@@ -69,7 +72,7 @@ class CacheManager extends Logging {
 
   /** Clears all cached tables. */
   def clearCache(): Unit = writeLock {
-    cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
     cachedData.clear()
   }
 
@@ -87,92 +90,109 @@ class CacheManager extends Logging {
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.logicalPlan
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData +=
-        CachedData(
-          planToCache,
-          InMemoryRelation(
-            sparkSession.sessionState.conf.useCompression,
-            sparkSession.sessionState.conf.columnBatchSize,
-            storageLevel,
-            sparkSession.sessionState.executePlan(planToCache).executedPlan,
-            tableName))
+      cachedData.add(CachedData(
+        planToCache,
+        InMemoryRelation(
+          sparkSession.sessionState.conf.useCompression,
+          sparkSession.sessionState.conf.columnBatchSize,
+          storageLevel,
+          sparkSession.sessionState.executePlan(planToCache).executedPlan,
+          tableName)))
     }
   }
 
   /**
-   * Tries to remove the data for the given [[Dataset]] from the cache.
-   * No operation, if it's already uncached.
+   * Un-cache all the cache entries that refer to the given plan.
+   */
+  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+    uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+  }
+
+  /**
+   * Un-cache all the cache entries that refer to the given plan.
    */
-  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
-    val planToCache = query.queryExecution.analyzed
-    val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
-    val found = dataIndex >= 0
-    if (found) {
-      cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
-      cachedData.remove(dataIndex)
+  def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+    val it = cachedData.iterator()
+    while (it.hasNext) {
+      val cd = it.next()
+      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        it.remove()
+      }
     }
-    found
+  }
+
+  /**
+   * Tries to re-cache all the cache entries that refer to the given plan.
+   */
+  def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+    recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+  }
+
+  private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+    val it = cachedData.iterator()
+    val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+    while (it.hasNext) {
+      val cd = it.next()
+      if (condition(cd.plan)) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        // Remove the cache entry before we create a new one, so that we can have a different
+        // physical plan.
+        it.remove()
+        val newCache = InMemoryRelation(
+          useCompression = cd.cachedRepresentation.useCompression,
+          batchSize = cd.cachedRepresentation.batchSize,
+          storageLevel = cd.cachedRepresentation.storageLevel,
+          child = spark.sessionState.executePlan(cd.plan).executedPlan,
+          tableName = cd.cachedRepresentation.tableName)
+        needToRecache += cd.copy(cachedRepresentation = newCache)
+      }
+    }
+
+    needToRecache.foreach(cachedData.add)
   }
 
   /** Optionally returns cached data for the given [[Dataset]] */
   def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
-    lookupCachedData(query.queryExecution.analyzed)
+    lookupCachedData(query.logicalPlan)
   }
 
   /** Optionally returns cached data for the given [[LogicalPlan]]. */
   def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
-    cachedData.find(cd => plan.sameResult(cd.plan))
+    cachedData.asScala.find(cd => plan.sameResult(cd.plan))
   }
 
   /** Replaces segments of the given logical plan with cached versions where possible. */
   def useCachedData(plan: LogicalPlan): LogicalPlan = {
-    plan transformDown {
+    val newPlan = plan transformDown {
       case currentFragment =>
         lookupCachedData(currentFragment)
           .map(_.cachedRepresentation.withOutput(currentFragment.output))
           .getOrElse(currentFragment)
     }
-  }
 
-  /**
-   * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
-   * function will over invalidate.
-   */
-  def invalidateCache(plan: LogicalPlan): Unit = writeLock {
-    cachedData.foreach {
-      case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
-        data.cachedRepresentation.recache()
-      case _ =>
+    newPlan transformAllExpressions {
+      case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
     }
   }
 
   /**
-   * Invalidates the cache of any data that contains `resourcePath` in one or more
+   * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
    * `HadoopFsRelation` node(s) as part of its logical plan.
    */
-  def invalidateCachedPath(
-      sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+  def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
     val (fs, qualifiedPath) = {
       val path = new Path(resourcePath)
-      val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
-      (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+      val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+      (fs, fs.makeQualified(path))
     }
 
-    cachedData.foreach {
-      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined =>
-        val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
-        if (dataIndex >= 0) {
-          data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
-          cachedData.remove(dataIndex)
-        }
-        sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
-      case _ => // Do Nothing
-    }
+    recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 03cc04659bd5..949f8b61f297 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -85,12 +85,6 @@ case class InMemoryRelation(
     buildBuffers()
   }
 
-  def recache(): Unit = {
-    _cachedColumnBuffers.unpersist()
-    _cachedColumnBuffers = null
-    buildBuffers()
-  }
-
   private def buildBuffers(): Unit = {
     val output = child.output
     val cached = child.execute().mapPartitionsInternal { rowIterator =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index b1bb56570cee..f9afe466d9f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -199,8 +199,7 @@ case class DropTableCommand(
       }
     }
     try {
-      sparkSession.sharedState.cacheManager.uncacheQuery(
-        sparkSession.table(tableName.quotedString))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
     } catch {
       case _: NoSuchTableException if ifExists =>
       case NonFatal(e) => log.warn(e.toString, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index 2eba1e9986ac..ac7e3bdfc32e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
     val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
     relation.insert(df, overwrite.enabled)
 
-    // Invalidate the cache.
-    sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+    // Re-cache all cached plans (including this relation itself, if it's cached) that refer to
+    // this data source relation.
+    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
 
     Seq.empty[Row]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 41ed9d71809e..9d0b2141d453 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -373,8 +373,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def dropTempView(viewName: String): Boolean = {
-    sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+    sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropTempView(viewName)
     }
   }
@@ -389,7 +389,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropGlobalTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropGlobalTempView(viewName)
     }
   }
@@ -434,7 +434,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
   }
 
   /**
@@ -472,17 +472,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
 
     // If this table is cached as an InMemoryRelation, drop the original
    // cached version and make the new version cached lazily.
-    val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent)
-    // Use lookupCachedData directly since RefreshTable also takes databaseName.
-    val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
-    if (isCached) {
-      // Create a data frame to represent the table.
-      // TODO: Use uncacheTable once it supports database name.
-      val df = Dataset.ofRows(sparkSession, logicalPlan)
+    val table = sparkSession.table(tableIdent)
+    if (isCached(table)) {
       // Uncache the logicalPlan.
-      sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
       // Cache it again.
-      sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+      sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
     }
   }
 
@@ -494,7 +489,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def refreshByPath(resourcePath: String): Unit = {
-    sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+    sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index f42402e1cc7d..5fc081c43113 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -24,19 +24,31 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.CleanerListener
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
-import org.apache.spark.util.AccumulatorContext
+import org.apache.spark.util.{AccumulatorContext, Utils}
 
 private case class BigData(s: String)
 
 class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext {
   import testImplicits._
 
+  setupTestData()
+
+  override def afterEach(): Unit = {
+    try {
+      spark.catalog.clearCache()
+    } finally {
+      super.afterEach()
+    }
+  }
+
   def rddIdOf(tableName: String): Int = {
     val plan = spark.table(tableName).queryExecution.sparkPlan
     plan.collect {
@@ -53,6 +65,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     maybeBlock.nonEmpty
   }
 
+  private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
+    val plan = ds.queryExecution.withCachedData
+    var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
+    plan.transformAllExpressions {
+      case e: SubqueryExpression =>
+        sum += getNumInMemoryRelations(e.plan)
+        e
+    }
+    sum
+  }
+
   test("withColumn doesn't invalidate cached dataframe") {
     var evalCount = 0
     val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -165,9 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     assertCached(spark.table("testData"))
 
     assertResult(1, "InMemoryRelation not found, testData should have been cached") {
-      spark.table("testData").queryExecution.withCachedData.collect {
-        case r: InMemoryRelation => r
-      }.size
+      getNumInMemoryRelations(spark.table("testData"))
     }
 
     spark.catalog.cacheTable("testData")
@@ -560,9 +581,93 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     localRelation.createOrReplaceTempView("localRelation")
     spark.catalog.cacheTable("localRelation")
 
-    assert(
-      localRelation.queryExecution.withCachedData.collect {
-        case i: InMemoryRelation => i
-      }.size == 1)
+    assert(getNumInMemoryRelations(localRelation) == 1)
+  }
+
+  test("SPARK-19093 Caching inside subquery") {
+    withTempView("t1") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      spark.catalog.cacheTable("t1")
+      val ds =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 2)
+    }
+  }
+
+  test("SPARK-19093 scalar and nested predicate query") {
+    withTempView("t1", "t2", "t3", "t4") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+      Seq(1).toDF("c1").createOrReplaceTempView("t3")
+      Seq(1).toDF("c1").createOrReplaceTempView("t4")
+      spark.catalog.cacheTable("t1")
+      spark.catalog.cacheTable("t2")
+      spark.catalog.cacheTable("t3")
+      spark.catalog.cacheTable("t4")
+
+      // Nested predicate subquery
+      val ds =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 3)
+
+      // Scalar subquery and predicate subquery
+      val ds2 =
+        sql(
+          """
+            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+            |WHERE
+            |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+            |OR
+            |EXISTS (SELECT c1 FROM t3)
+            |OR
+            |c1 IN (SELECT c1 FROM t4)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds2) == 4)
+    }
+  }
+
+  test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
+    withTable("t") {
+      withTempPath { path =>
+        Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+        sql(s"CREATE TABLE t USING parquet OPTIONS (PATH '${path.toURI}')")
+        spark.catalog.cacheTable("t")
+        spark.table("t").select($"i").cache()
+        checkAnswer(spark.table("t").select($"i"), Row(1))
+        assertCached(spark.table("t").select($"i"))
+
+        Utils.deleteRecursively(path)
+        spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
+        spark.catalog.uncacheTable("t")
+        assert(spark.table("t").select($"i").count() == 0)
+        assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0)
+      }
+    }
+  }
+
+  test("refreshByPath should refresh all cached plans with the specified path") {
+    withTempDir { dir =>
+      val path = dir.getCanonicalPath()
+
+      spark.range(10).write.mode("overwrite").parquet(path)
+      spark.read.parquet(path).cache()
+      spark.read.parquet(path).filter($"id" > 4).cache()
+      assert(spark.read.parquet(path).filter($"id" > 4).count() == 5)
+
+      spark.range(20).write.mode("overwrite").parquet(path)
+      spark.catalog.refreshByPath(path)
+      assert(spark.read.parquet(path).count() == 20)
+      assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
+    }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 09d1abfa8c7a..3b9c2fcb0ce1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -388,8 +388,8 @@ case class InsertIntoHiveTable(
       logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
     }
 
-    // Invalidate the cache.
-    sqlContext.sharedState.cacheManager.invalidateCache(table)
+    // Un-cache this table.
+    sqlContext.sparkSession.catalog.uncacheTable(table.catalogTable.identifier.quotedString)
     sqlContext.sessionState.catalog.refreshTable(table.catalogTable.identifier)
 
     // It would be nice to just return the childRdd unchanged so insert operations could be chained,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 3871b3d78588..9b24ad045d2a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -196,9 +196,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString)
     sql("DROP TABLE IF EXISTS refreshTable")
     sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet")
-    checkAnswer(
-      table("refreshTable"),
-      table("src").collect())
+    checkAnswer(table("refreshTable"), table("src"))
     // Cache the table.
     sql("CACHE TABLE refreshTable")
     assertCached(table("refreshTable"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index e8b81109e2a9..fbb228e0873e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -457,7 +457,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
       // Converted test_parquet should be cached.
       sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match {
         case null => fail("Converted test_parquet should be cached in the cache.")
-        case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK
+        case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
         case other =>
           fail(
             "The cached test_parquet should be a Parquet Relation. " +
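
Usage sketch (not part of the patch): the Scala snippet below illustrates the behaviour the reworked CacheManager is intended to produce, namely that un-caching a table now also drops every cached plan that refers to it, while recacheByPlan/recacheByPath rebuild such entries instead of dropping them. The view name `t` and the derived query are invented for the example; only public SparkSession/Catalog APIs are used.

// Minimal sketch of the cascading un-cache behaviour introduced by this patch.
import org.apache.spark.sql.SparkSession

object CacheCascadeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("cache-cascade-sketch")
      .getOrCreate()
    import spark.implicits._

    // A base view and a derived query that refers to it.
    Seq(1 -> "a", 2 -> "b").toDF("i", "j").createOrReplaceTempView("t")
    spark.catalog.cacheTable("t")
    val derived = spark.table("t").select($"i")
    derived.cache()
    derived.count() // materialize both cache entries

    // Previously only the entry for `t` itself was removed; with the new
    // CacheManager.uncacheQuery, the cached `derived` plan is dropped as well,
    // because its plan contains a node that sameResult-matches the un-cached plan.
    spark.catalog.uncacheTable("t")

    // After writes, InsertIntoDataSourceCommand calls recacheByPlan and
    // Catalog.refreshByPath calls recacheByPath, so dependent cache entries are
    // rebuilt rather than left stale.
    spark.stop()
  }
}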