diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
index 8c14b5e370736..691629e64956a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2

 import java.util.Locale

+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.LocalTempView
@@ -26,8 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
-import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.execution.command.{CreateViewCommand, DropTempViewCommand}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils

 trait BaseCacheTableExec extends LeafV2CommandExec {
   def relationName: String
@@ -60,7 +64,16 @@ trait BaseCacheTableExec extends LeafV2CommandExec {

     if (!isLazy) {
       // Performs eager caching.
-      dataFrameForCachedPlan.count()
+      try {
+        dataFrameForCachedPlan.count()
+      } catch {
+        case NonFatal(e) =>
+          // If the query fails, we should remove the cached table.
+          Utils.tryLogNonFatalError {
+            session.sharedState.cacheManager.uncacheQuery(session, planToCache, cascade = false)
+          }
+          throw e
+      }
     }

     Seq.empty
@@ -113,6 +126,18 @@ case class CacheTableAsSelectExec(
   override lazy val dataFrameForCachedPlan: DataFrame = {
     session.table(tempViewName)
   }
+
+  override def run(): Seq[InternalRow] = {
+    try {
+      super.run()
+    } catch {
+      case NonFatal(e) =>
+        Utils.tryLogNonFatalError {
+          DropTempViewCommand(Identifier.of(Array.empty, tempViewName)).run(session)
+        }
+        throw e
+    }
+  }
 }

 case class UncacheTableExec(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 9815cb816c994..e54947266951b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import scala.concurrent.duration._

 import org.apache.commons.io.FileUtils

-import org.apache.spark.CleanerListener
+import org.apache.spark.{CleanerListener, SparkException}
 import org.apache.spark.executor.DataReadMethod._
 import org.apache.spark.executor.DataReadMethod.DataReadMethod
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
@@ -1729,4 +1729,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils

     }
   }
+
+  test("SPARK-52684: Atomicity of cache table on error") {
+    withTempView("SPARK_52684") {
+      intercept[SparkException] {
+        spark.sql("CACHE TABLE SPARK_52684 AS SELECT raise_error('SPARK-52684') AS c1")
+      }
+      assert(!spark.catalog.tableExists("SPARK_52684"))
+    }
+  }
 }
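
For context, both hunks in CacheTableExec.scala apply the same cleanup-on-failure pattern: catch only NonFatal errors, run the compensating cleanup on a best-effort basis (so a failing cleanup cannot mask the original exception), and rethrow. Below is a minimal, self-contained sketch of that pattern; the object and method names (CleanupOnFailureSketch, runWithCleanup) are hypothetical illustrations, not Spark APIs, and tryLogNonFatalError here is only a stand-in for Spark's Utils.tryLogNonFatalError.

import scala.util.control.NonFatal

object CleanupOnFailureSketch {
  // Stand-in for Utils.tryLogNonFatalError: run a cleanup block and log
  // non-fatal failures instead of letting them replace the original error.
  private def tryLogNonFatalError(block: => Unit): Unit = {
    try block catch {
      case NonFatal(e) => Console.err.println(s"cleanup failed: ${e.getMessage}")
    }
  }

  // Attempt the action; on non-fatal failure, best-effort cleanup, then
  // rethrow the original exception so the caller still sees the real error.
  def runWithCleanup[A](action: => A)(cleanup: => Unit): A = {
    try action catch {
      case NonFatal(e) =>
        tryLogNonFatalError(cleanup)
        throw e
    }
  }

  def main(args: Array[String]): Unit = {
    // Simulates the temp view registered by CACHE TABLE ... AS SELECT.
    var tempViewRegistered = true
    try {
      runWithCleanup[Unit](throw new RuntimeException("query failed")) {
        tempViewRegistered = false // simulates DropTempViewCommand
      }
    } catch {
      case NonFatal(e) =>
        println(s"original error surfaced: ${e.getMessage}; temp view left behind: $tempViewRegistered")
    }
  }
}

The new test exercises exactly this: raise_error makes the eager count() fail, and the assertion checks that the temporary view created by CACHE TABLE ... AS SELECT was cleaned up rather than left registered.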