@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartit
import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, TableCapability, TableCatalog, TableChange}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, Table, TableCapability, TableCatalog, TableChange}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -78,6 +78,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
}

+private def invalidateCache(catalog: TableCatalog, table: Table, ident: Identifier): Unit = {
+val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+}
+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters,
relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
@@ -161,10 +166,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
catalog match {
case staging: StagingTableCatalog =>
AtomicReplaceTableExec(
-staging, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+staging, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+invalidateCache) :: Nil
case _ =>
ReplaceTableExec(
-catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate) :: Nil
+catalog, ident, schema, parts, propsWithOwner, orCreate = orCreate,
+invalidateCache) :: Nil
}

case ReplaceTableAsSelect(catalog, ident, parts, query, props, options, orCreate) =>
@@ -173,26 +180,26 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
catalog match {
case staging: StagingTableCatalog =>
AtomicReplaceTableAsSelectExec(
-session,
staging,
ident,
parts,
query,
planLater(query),
propsWithOwner,
writeOptions,
-orCreate = orCreate) :: Nil
+orCreate = orCreate,
+invalidateCache) :: Nil
case _ =>
ReplaceTableAsSelectExec(
-session,
catalog,
ident,
parts,
query,
planLater(query),
propsWithOwner,
writeOptions,
-orCreate = orCreate) :: Nil
+orCreate = orCreate,
+invalidateCache) :: Nil
}

case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -33,10 +33,13 @@ case class ReplaceTableExec(
tableSchema: StructType,
partitioning: Seq[Transform],
tableProperties: Map[String, String],
-orCreate: Boolean) extends V2CommandExec {
+orCreate: Boolean,
+invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {

override protected def run(): Seq[InternalRow] = {
if (catalog.tableExists(ident)) {
+val table = catalog.loadTable(ident)
+invalidateCache(catalog, table, ident)
catalog.dropTable(ident)
} else if (!orCreate) {
throw new CannotReplaceMissingTableException(ident)
@@ -54,9 +57,14 @@ case class AtomicReplaceTableExec(
tableSchema: StructType,
partitioning: Seq[Transform],
tableProperties: Map[String, String],
-orCreate: Boolean) extends V2CommandExec {
+orCreate: Boolean,
+invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {

override protected def run(): Seq[InternalRow] = {
+if (catalog.tableExists(identifier)) {
+val table = catalog.loadTable(identifier)
+invalidateCache(catalog, table, identifier)
+}
val staged = if (orCreate) {
catalog.stageCreateOrReplace(
identifier, tableSchema, partitioning.toArray, tableProperties.asJava)
@@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -131,15 +130,15 @@ case class AtomicCreateTableAsSelectExec(
* ReplaceTableAsSelectStagingExec.
*/
case class ReplaceTableAsSelectExec(
-session: SparkSession,
catalog: TableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
-orCreate: Boolean) extends TableWriteExecHelper {
+orCreate: Boolean,
+invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {

override protected def run(): Seq[InternalRow] = {
// Note that this operation is potentially unsafe, but these are the strict semantics of
@@ -152,7 +151,7 @@ case class ReplaceTableAsSelectExec(
// 3. The table returned by catalog.createTable doesn't support writing.
if (catalog.tableExists(ident)) {
val table = catalog.loadTable(ident)
-uncacheTable(session, catalog, table, ident)
+invalidateCache(catalog, table, ident)
catalog.dropTable(ident)
} else if (!orCreate) {
throw new CannotReplaceMissingTableException(ident)
@@ -177,21 +176,21 @@
* is left untouched.
*/
case class AtomicReplaceTableAsSelectExec(
-session: SparkSession,
catalog: StagingTableCatalog,
ident: Identifier,
partitioning: Seq[Transform],
plan: LogicalPlan,
query: SparkPlan,
properties: Map[String, String],
writeOptions: CaseInsensitiveStringMap,
-orCreate: Boolean) extends TableWriteExecHelper {
+orCreate: Boolean,
+invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper {

override protected def run(): Seq[InternalRow] = {
val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable
if (catalog.tableExists(ident)) {
val table = catalog.loadTable(ident)
-uncacheTable(session, catalog, table, ident)
+invalidateCache(catalog, table, ident)
}
val staged = if (orCreate) {
catalog.stageCreateOrReplace(
@@ -393,15 +392,6 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {

Nil
}

-protected def uncacheTable(
-session: SparkSession,
-catalog: TableCatalog,
-table: Table,
-ident: Identifier): Unit = {
-val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
-session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true)
-}
}

object DataWritingSparkTask extends Logging {
@@ -785,6 +785,23 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-34039: ReplaceTable (atomic or non-atomic) should invalidate cache") {
Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
val view = "view"
withTable(t) {
withTempView(view) {
sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source")
sql(s"CACHE TABLE $view AS SELECT id FROM $t")
checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source"))
checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id"))

sql(s"REPLACE TABLE $t (a bigint) USING foo")
assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty)
}
}
}
}

test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") {
Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t =>
val view = "view"
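
The pattern is the same across all four files: instead of each exec node holding a SparkSession and calling the old `uncacheTable` helper, `DataSourceV2Strategy` builds a single `invalidateCache` callback and passes it into the `ReplaceTable*` exec nodes, which invoke it on the existing table before dropping or staging its replacement. Below is a minimal standalone sketch of that callback-injection style; `Catalog`, `Table`, `Ident`, `CacheManager`, and `ReplaceTableSketch` are simplified stand-ins for illustration, not Spark's real classes.

```scala
// Illustrative sketch only: simplified stand-in types, not Spark's actual API.
case class Ident(name: String)
case class Table(ident: Ident)

trait Catalog {
  def loadTable(ident: Ident): Table
  def tableExists(ident: Ident): Boolean
  def dropTable(ident: Ident): Unit
}

class CacheManager {
  private var cached = Set.empty[Ident]
  def cache(ident: Ident): Unit = cached += ident
  def uncache(ident: Ident): Unit = cached -= ident
  def isCached(ident: Ident): Boolean = cached.contains(ident)
}

// The exec node receives the cache-invalidation behaviour as a function,
// so it never needs a handle on the session that owns the cache.
case class ReplaceTableSketch(
    catalog: Catalog,
    ident: Ident,
    invalidateCache: (Catalog, Table, Ident) => Unit) {

  def run(): Unit = {
    if (catalog.tableExists(ident)) {
      val table = catalog.loadTable(ident)
      invalidateCache(catalog, table, ident) // drop cached plans first
      catalog.dropTable(ident)
    }
    // ... create the replacement table here ...
  }
}

object Demo extends App {
  val cacheManager = new CacheManager
  val catalog = new Catalog {
    private var tables = Map(Ident("t") -> Table(Ident("t")))
    def loadTable(ident: Ident): Table = tables(ident)
    def tableExists(ident: Ident): Boolean = tables.contains(ident)
    def dropTable(ident: Ident): Unit = tables -= ident
  }
  cacheManager.cache(Ident("t"))
  // The planner-side closure plays the role of DataSourceV2Strategy.invalidateCache.
  val exec = ReplaceTableSketch(catalog, Ident("t"),
    (_, _, ident) => cacheManager.uncache(ident))
  exec.run()
  assert(!cacheManager.isCached(Ident("t")))
}
```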