From dd6d8ca1c091d00bfc29363ebc0d518b12927325 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 25 Feb 2017 05:58:40 +0000
Subject: [PATCH 1/2] refreshByPath should clear all cached plans with the specified path.

---
 .../spark/sql/execution/CacheManager.scala  | 19 ++++++++++---------
 .../apache/spark/sql/CachedTableSuite.scala | 24 ++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 4ca1347008575..80138510dc9ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -168,15 +168,16 @@ class CacheManager extends Logging {
       (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
     }
 
-    cachedData.foreach {
-      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined =>
-        val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
-        if (dataIndex >= 0) {
-          data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
-          cachedData.remove(dataIndex)
-        }
-        sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
-      case _ => // Do Nothing
+    cachedData.filter {
+      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
+      case _ => false
+    }.foreach { data =>
+      val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
+      if (dataIndex >= 0) {
+        data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
+        cachedData.remove(dataIndex)
+      }
+      sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1af1a3652971c..f3386fb5fb548 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -634,4 +634,28 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       assert(getNumInMemoryRelations(cachedPlan2) == 4)
     }
   }
+
+  test("refreshByPath should refresh all cached plans with the specified path") {
+    def f(path: String, spark: SparkSession, dataCount: Int): DataFrame = {
+      spark.catalog.refreshByPath(path)
+      val data = spark.read.parquet(path)
+      assert(data.count == dataCount)
+      val df = data.filter("id > 10")
+      df.cache
+      assert(df.count == dataCount - 11)
+      val df1 = df.filter("id > 11")
+      df1.cache
+      assert(df1.count == dataCount - 12)
+      df1
+    }
+
+    withTempDir { dir =>
+      val path = dir.getPath()
+      spark.range(100).write.mode("overwrite").parquet(path)
+      assert(f(path, spark, 100).count == 88)
+
+      spark.range(1000).write.mode("overwrite").parquet(path)
+      assert(f(path, spark, 1000).count == 988)
+    }
+  }
 }
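Note on the CacheManager change above: the old code removed entries from cachedData (and appended rebuilt entries via cacheQuery) while foreach was still walking the same buffer, so entries after the first hit could shift past the iterator and keep their stale plans; only one cached plan per path was reliably refreshed. The patch snapshots the matching entries with filter before mutating the buffer. The following is a minimal, self-contained sketch of that pitfall using plain strings in an ArrayBuffer; the matchesPath helper and the "(fresh)" suffix are illustrative stand-ins, not Spark APIs.

import scala.collection.mutable.ArrayBuffer

object RefreshByPathSketch {
  // Toy predicate standing in for "this cached plan reads the refreshed path".
  private def matchesPath(e: String): Boolean = e.endsWith("@path")

  def main(args: Array[String]): Unit = {
    // Pre-patch shape: remove a matching entry and append its rebuilt
    // replacement while still iterating the buffer by index.
    val buggy = ArrayBuffer("a@path", "b@path", "c@other")
    var i = 0
    while (i < buggy.length) {
      val e = buggy(i)
      if (matchesPath(e)) {
        buggy.remove(i)         // "b@path" shifts left into slot i,
        buggy += e + "(fresh)"  // and the rebuilt entry goes to the end,
      }
      i += 1                    // but we advance anyway and skip "b@path"
    }
    println(buggy)  // ArrayBuffer(b@path, c@other, a@path(fresh)): b@path stays stale

    // Patched shape: materialize the matches first, then mutate safely.
    val fixed = ArrayBuffer("a@path", "b@path", "c@other")
    fixed.filter(matchesPath).foreach { e =>
      val idx = fixed.indexOf(e)
      if (idx >= 0) fixed.remove(idx)
      fixed += e + "(fresh)"
    }
    println(fixed)  // ArrayBuffer(c@other, a@path(fresh), b@path(fresh)): both refreshed
  }
}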
From a9fe0bee8cb136741e6e3f96b3cb9f5a6f051e94 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 1 Mar 2017 02:14:43 +0000
Subject: [PATCH 2/2] Address comments.

---
 .../apache/spark/sql/CachedTableSuite.scala | 28 ++++++++++------------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index f3386fb5fb548..2a0e088437fda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -636,26 +636,18 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
   }
 
   test("refreshByPath should refresh all cached plans with the specified path") {
-    def f(path: String, spark: SparkSession, dataCount: Int): DataFrame = {
-      spark.catalog.refreshByPath(path)
-      val data = spark.read.parquet(path)
-      assert(data.count == dataCount)
-      val df = data.filter("id > 10")
-      df.cache
-      assert(df.count == dataCount - 11)
-      val df1 = df.filter("id > 11")
-      df1.cache
-      assert(df1.count == dataCount - 12)
-      df1
-    }
-
     withTempDir { dir =>
-      val path = dir.getPath()
-      spark.range(100).write.mode("overwrite").parquet(path)
-      assert(f(path, spark, 100).count == 88)
+      val path = dir.getCanonicalPath()
+
+      spark.range(10).write.mode("overwrite").parquet(path)
+      spark.read.parquet(path).cache()
+      spark.read.parquet(path).filter($"id" > 4).cache()
+      assert(spark.read.parquet(path).filter($"id" > 4).count() == 5)
 
-      spark.range(1000).write.mode("overwrite").parquet(path)
-      assert(f(path, spark, 1000).count == 988)
+      spark.range(20).write.mode("overwrite").parquet(path)
+      spark.catalog.refreshByPath(path)
+      assert(spark.read.parquet(path).count() == 20)
+      assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
     }
   }
 }
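For reference, here is the behavior the revised test pins down, written as a standalone application rather than against the suite's shared session; the session setup, app name, and the /tmp/refresh-demo path are assumptions for illustration, not part of the patch.

import org.apache.spark.sql.SparkSession

object RefreshByPathDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("refresh-by-path-demo")  // illustrative app name
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._  // enables the $"id" column syntax used below

    val path = "/tmp/refresh-demo"  // illustrative location

    // Cache two plans that both read from the same path.
    spark.range(10).write.mode("overwrite").parquet(path)
    spark.read.parquet(path).cache()
    val filtered = spark.read.parquet(path).filter($"id" > 4)
    filtered.cache()
    assert(filtered.count() == 5)  // ids 5..9

    // Overwrite the files behind the cached plans, then refresh every
    // cached plan that reads from this path.
    spark.range(20).write.mode("overwrite").parquet(path)
    spark.catalog.refreshByPath(path)
    assert(spark.read.parquet(path).count() == 20)
    assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)  // ids 5..19

    spark.stop()
  }
}

With the CacheManager fix from the first patch, both cached plans (the full scan and the filtered one) are rebuilt against the new files; before it, only the first matching plan was refreshed.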