diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 34bc80cf9026a..dffe0a838d80f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartit
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, TableCapability, TableCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, Table, TableCapability, TableCatalog, TableChange}
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
 import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -78,6 +78,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     }
   }
 
+  private def invalidateCache(catalog: TableCatalog, table: Table, ident: Identifier): Unit = {
+    val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+    session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+  }
+
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters,
         relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
@@ -161,10 +166,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableExec(
-            staging, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+            staging, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+            invalidateCache) :: Nil
         case _ =>
           ReplaceTableExec(
-            catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+            catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+            invalidateCache) :: Nil
       }
 
     case ReplaceTableAsSelect(catalog, ident, parts, query, props, options, orCreate) =>
@@ -173,7 +180,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableAsSelectExec(
-            session,
             staging,
             ident,
             parts,
@@ -181,10 +187,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
             planLater(query),
             propsWithOwner,
             writeOptions,
-            orCreate = orCreate) :: Nil
+            orCreate = orCreate,
+            invalidateCache) :: Nil
         case _ =>
           ReplaceTableAsSelectExec(
-            session,
             catalog,
             ident,
             parts,
@@ -192,7 +198,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
             planLater(query),
             propsWithOwner,
             writeOptions,
-            orCreate = orCreate) :: Nil
+            orCreate = orCreate,
+            invalidateCache) :: Nil
       }
 
     case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 1f3bcf2e3fe57..10c09f4be711f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -33,10 +33,13 @@ case class ReplaceTableExec(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
-    orCreate: Boolean) extends V2CommandExec {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
+      val table = catalog.loadTable(ident)
+      invalidateCache(catalog, table, ident)
       catalog.dropTable(ident)
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
@@ -54,9 +57,14 @@ case class AtomicReplaceTableExec(
     tableSchema: StructType,
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
-    orCreate: Boolean) extends V2CommandExec {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
+    if (catalog.tableExists(identifier)) {
+      val table = catalog.loadTable(identifier)
+      invalidateCache(catalog, table, identifier)
+    }
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
         identifier, tableSchema, partitioning.toArray, tableProperties.asJava)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index b0aff4a6b763e..a41b048f38cce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -131,7 +130,6 @@ case class AtomicCreateTableAsSelectExec(
  * ReplaceTableAsSelectStagingExec.
  */
 case class ReplaceTableAsSelectExec(
-    session: SparkSession,
     catalog: TableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
@@ -139,7 +137,8 @@ case class ReplaceTableAsSelectExec(
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends TableWriteExecHelper {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {
 
   override protected def run(): Seq[InternalRow] = {
     // Note that this operation is potentially unsafe, but these are the strict semantics of
@@ -152,7 +151,7 @@ case class ReplaceTableAsSelectExec(
     // 3. The table returned by catalog.createTable doesn't support writing.
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
-      uncacheTable(session, catalog, table, ident)
+      invalidateCache(catalog, table, ident)
       catalog.dropTable(ident)
     } else if (!orCreate) {
       throw new CannotReplaceMissingTableException(ident)
@@ -177,7 +176,6 @@ case class ReplaceTableAsSelectExec(
  * is left untouched.
 */
 case class AtomicReplaceTableAsSelectExec(
-    session: SparkSession,
     catalog: StagingTableCatalog,
     ident: Identifier,
     partitioning: Seq[Transform],
@@ -185,13 +183,14 @@ case class AtomicReplaceTableAsSelectExec(
     query: SparkPlan,
     properties: Map[String, String],
     writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends TableWriteExecHelper {
+    orCreate: Boolean,
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {
 
   override protected def run(): Seq[InternalRow] = {
     val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
-      uncacheTable(session, catalog, table, ident)
+      invalidateCache(catalog, table, ident)
     }
     val staged = if (orCreate) {
       catalog.stageCreateOrReplace(
@@ -393,15 +392,6 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
 
     Nil
   }
-
-  protected def uncacheTable(
-      session: SparkSession,
-      catalog: TableCatalog,
-      table: Table,
-      ident: Identifier): Unit = {
-    val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
-    session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true)
-  }
 }
 
 object DataWritingSparkTask extends Logging {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 92edd2b6209dc..f0f6e7cc25ad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -785,6 +785,23 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-34039: ReplaceTable (atomic or non-atomic) should invalidate cache") {
+    Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
+      val view = "view"
+      withTable(t) {
+        withTempView(view) {
+          sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source")
+          sql(s"CACHE TABLE $view AS SELECT id FROM $t")
+          checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source"))
+          checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id"))
+
+          sql(s"REPLACE TABLE $t (a bigint) USING foo")
+          assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty)
+        }
+      }
+    }
+  }
+
   test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") {
     Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
      val view = "view"
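
Note (illustration, not part of the patch): the change replaces the old uncacheTable helper on V2TableWriteExec, which needed a SparkSession, with an invalidateCache callback of type (TableCatalog, Table, Identifier) => Unit that DataSourceV2Strategy builds once and threads into ReplaceTableExec, AtomicReplaceTableExec, ReplaceTableAsSelectExec and AtomicReplaceTableAsSelectExec. A minimal sketch of that callback is below; the wrapper class is hypothetical and only the lambda body is taken from the patch.

// Sketch only: in the patch this logic is a private method on DataSourceV2Strategy and is
// passed to the exec nodes as a function value; CacheInvalidationSketch is a made-up wrapper.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

class CacheInvalidationSketch(session: SparkSession) {
  // Build a DSv2 relation for the table about to be replaced and drop every cached plan
  // that depends on it (cascade = true), mirroring invalidateCache in the diff above.
  val invalidateCache: (TableCatalog, Table, Identifier) => Unit = (catalog, table, ident) => {
    val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
    session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
  }
}

Passing a callback rather than the whole SparkSession keeps cache handling in the strategy, which is why the session constructor parameter disappears from ReplaceTableAsSelectExec and AtomicReplaceTableAsSelectExec in the diff.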