From 94d6804bb466820aa83b18d1ab3c9b9cc3711757 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Tue, 11 Aug 2015 11:12:55 +0800
Subject: [PATCH] fix cache invalidation for HadoopFsRelation

---
 .../columnar/InMemoryColumnarTableScan.scala  | 10 +++++
 .../spark/sql/execution/CacheManager.scala    | 40 +++++++++++++++----
 .../InsertIntoHadoopFsRelation.scala          |  3 ++
 .../datasources/json/JSONRelation.scala       |  9 -----
 .../apache/spark/sql/sources/interfaces.scala |  1 +
 .../spark/sql/sources/InsertSuite.scala       | 19 ++++-----
 .../sql/sources/hadoopFsRelationSuites.scala  | 25 ++++++++++++
 7 files changed, 82 insertions(+), 25 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d553bb6169ecc..0a8db28f167eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -183,6 +183,16 @@ private[sql] case class InMemoryRelation(
       batchStats).asInstanceOf[this.type]
   }
 
+  private[sql] def withChild(newChild: SparkPlan): this.type = {
+    new InMemoryRelation(
+      output.map(_.newInstance()),
+      useCompression,
+      batchSize,
+      storageLevel,
+      newChild,
+      tableName)().asInstanceOf[this.type]
+  }
+
   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
 
   override protected def otherCopyArgs: Seq[AnyRef] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d3e5c378d037d..cbdb39a70720c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -27,7 +27,16 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
 
 /** Holds a cached logical plan and its data */
-private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
+private[sql] class CachedData(
+    val plan: LogicalPlan,
+    var cachedRepresentation: InMemoryRelation) {
+  private[sql] def recache(sqlContext: SQLContext): Unit = {
+    cachedRepresentation.uncache(true) // release the cache
+    // re-generate the spark plan and re-cache it
+    cachedRepresentation =
+      cachedRepresentation.withChild(sqlContext.executePlan(plan).executedPlan)
+  }
+}
 
 /**
  * Provides support in a SQLContext for caching query results and automatically using these cached
@@ -97,13 +106,13 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
       logWarning("Asked to cache already cached data.")
     } else {
       cachedData +=
-        CachedData(
+        new CachedData(
           planToCache,
           InMemoryRelation(
             sqlContext.conf.useCompression,
             sqlContext.conf.columnBatchSize,
             storageLevel,
-            sqlContext.executePlan(query.logicalPlan).executedPlan,
+            sqlContext.executePlan(planToCache).executedPlan,
             tableName))
     }
   }
@@ -156,10 +165,27 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
    * function will over invalidate.
    */
   private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock {
-    cachedData.foreach {
-      case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
-        data.cachedRepresentation.recache()
-      case _ =>
+    var i = 0
+    var locatedIdx = -1
+    // find the index of the cached data for the specified logical plan
+    while (i < cachedData.length && locatedIdx < 0) {
+      cachedData(i) match {
+        case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
+          locatedIdx = i
+        case _ =>
+      }
+      i += 1
+    }
+
+    if (locatedIdx >= 0) {
+      // if the cached data exists, remove it from the cache data list, as we need to
+      // re-generate the spark plan, and we don't want this entry to be used during the
+      // re-generation
+      val entry = cachedData.remove(locatedIdx) // TODO do we have to use ArrayBuffer?
+      // rebuild the cache
+      entry.recache(sqlContext)
+      // add it back to the cache data list
+      cachedData += entry
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index 735d52f808868..aec559f72c844 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -160,6 +160,9 @@ private[sql] case class InsertIntoHadoopFsRelation(
       logInfo("Skipping insertion into a relation that already exists.")
     }
 
+    // Invalidate the cache.
+    sqlContext.cacheManager.invalidateCache(LogicalRelation(relation))
+
     Seq.empty[Row]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 114c8b211891e..ab8ca5f748f24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -111,15 +111,6 @@ private[sql] class JSONRelation(
     jsonSchema
   }
 
-  override private[sql] def buildScan(
-      requiredColumns: Array[String],
-      filters: Array[Filter],
-      inputPaths: Array[String],
-      broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
-    refresh()
-    super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf)
-  }
-
   override def buildScan(
       requiredColumns: Array[String],
       filters: Array[Filter],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 2f8417a48d32e..261c1086fbb1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -565,6 +565,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
       filters: Array[Filter],
       inputPaths: Array[String],
       broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
+    refresh()
     val inputStatuses = inputPaths.flatMap { input =>
       val path = new Path(input)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index cdbfaf6455fe4..a65d357638c4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -231,17 +231,18 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
     sql(
       s"""
         |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt
       """.stripMargin)
+    // jsonTable should be recached.
     assertCached(sql("SELECT * FROM jsonTable"))
-    // TODO we need to invalidate the cached data in InsertIntoHadoopFsRelation
-//    // The cached data is the new data.
-//    checkAnswer(
-//      sql("SELECT a, b FROM jsonTable"),
-//      sql("SELECT a * 2, b FROM jt").collect())
-//
-//    // Verify uncaching
-//    caseInsensitiveContext.uncacheTable("jsonTable")
-//    assertCached(sql("SELECT * FROM jsonTable"), 0)
+
+    // The cached data is the new data.
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      sql("SELECT a * 2, b FROM jt").collect())
+
+    // Verify uncaching
+    caseInsensitiveContext.uncacheTable("jsonTable")
+    assertCached(sql("SELECT * FROM jsonTable"), 0)
   }
 
   test("it's not allowed to insert into a relation that is not an InsertableRelation") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 2a69d331b6e52..7428c3ce2f74e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -48,6 +48,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
       StructField("b", StringType, nullable = false)))
 
   lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b")
+  lazy val testDF2 = (5 to 8).map(i => (i, s"val_$i")).toDF("a", "b")
 
   lazy val partitionedTestDF1 = (for {
     i <- 1 to 3
@@ -269,6 +270,30 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("invalidate the cached table - non-partitioned table") {
+    withTempPath { file =>
+      withTempTable("temp_datasource") {
+        sql(
+          s"""
+            |CREATE TEMPORARY TABLE temp_datasource (a int, b string)
+            |USING $dataSourceName
+            |OPTIONS (
+            |  path '${file.toString}'
+            |)
+          """.stripMargin)
+
+        testDF.write.format(dataSourceName).mode(SaveMode.Overwrite).save(file.toString)
+        checkAnswer(sqlContext.table("temp_datasource"), testDF.orderBy("a").collect())
+
+        sqlContext.cacheTable("temp_datasource")
+        checkAnswer(sqlContext.table("temp_datasource"), testDF.orderBy("a").collect())
+
+        testDF2.write.format(dataSourceName).mode(SaveMode.Overwrite).save(file.toString)
+        checkAnswer(sqlContext.table("temp_datasource"), testDF2.orderBy("a").collect())
+      }
+    }
+  }
+
   test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") {
     Seq.empty[(Int, String)].toDF().registerTempTable("t")