diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 289a976c6db9e..edf0963e71e81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -253,11 +253,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 
     val maybeV2Provider = lookupV2Provider()
     if (maybeV2Provider.isDefined) {
-      if (partitioningColumns.nonEmpty) {
-        throw new AnalysisException(
-          "Cannot write data to TableProvider implementation if partition columns are specified.")
-      }
-
       val provider = maybeV2Provider.get
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         provider, df.sparkSession.sessionState.conf)
@@ -267,6 +262,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
       provider.getTable(dsOptions) match {
         case table: SupportsWrite if table.supports(BATCH_WRITE) =>
+          if (partitioningColumns.nonEmpty) {
+            throw new AnalysisException("Cannot write data to TableProvider implementation " +
+              "if partition columns are specified.")
+          }
           lazy val relation = DataSourceV2Relation.create(table, dsOptions)
           modeForDSV2 match {
             case SaveMode.Append =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
index 45ca3dfb9cb93..509a5f7139cca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
 import org.apache.spark.sql.types.BooleanType
@@ -32,6 +33,10 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
 
   private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
 
+  private def supportsBatchWrite(table: Table): Boolean = {
+    table.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
+  }
+
   override def apply(plan: LogicalPlan): Unit = {
     plan foreach {
       case r: DataSourceV2Relation if !r.table.supports(BATCH_READ) =>
@@ -43,8 +48,7 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
 
       // TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
       // a logical plan for streaming write.
-
-      case AppendData(r: DataSourceV2Relation, _, _, _) if !r.table.supports(BATCH_WRITE) =>
+      case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
         failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")
 
       case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
@@ -54,13 +58,13 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
       case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
         expr match {
           case Literal(true, BooleanType) =>
-            if (!r.table.supports(BATCH_WRITE) ||
-                !r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
+            if (!supportsBatchWrite(r.table) ||
+                !r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
               failAnalysis(
                 s"Table ${r.table.name()} does not support truncate in batch mode.")
             }
           case _ =>
-            if (!r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_BY_FILTER)) {
+            if (!supportsBatchWrite(r.table) || !r.table.supports(OVERWRITE_BY_FILTER)) {
               failAnalysis(s"Table ${r.table.name()} does not support " +
                 "overwrite by filter in batch mode.")
             }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index 39f4085a9baf9..ce6d56cf84df1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -98,16 +98,19 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }
 
   test("AppendData: check correct capabilities") {
-    val plan = AppendData.byName(
-      DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty),
-      TestRelation)
+    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
+      val plan = AppendData.byName(
+        DataSourceV2Relation.create(CapabilityTable(write), CaseInsensitiveStringMap.empty),
+        TestRelation)
 
-    TableCapabilityCheck.apply(plan)
+      TableCapabilityCheck.apply(plan)
+    }
   }
 
   test("Truncate: check missing capabilities") {
     Seq(CapabilityTable(),
       CapabilityTable(BATCH_WRITE),
+      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(TRUNCATE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
 
@@ -125,7 +128,9 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
 
   test("Truncate: check correct capabilities") {
     Seq(CapabilityTable(BATCH_WRITE, TRUNCATE),
-      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
+      CapabilityTable(V1_BATCH_WRITE, TRUNCATE),
+      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER),
+      CapabilityTable(V1_BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
       val plan = OverwriteByExpression.byName(
         DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty),
         TestRelation,
@@ -137,6 +142,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
 
   test("OverwriteByExpression: check missing capabilities") {
     Seq(CapabilityTable(),
+      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(BATCH_WRITE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
 
@@ -153,12 +159,14 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }
 
   test("OverwriteByExpression: check correct capabilities") {
-    val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)
-    val plan = OverwriteByExpression.byName(
-      DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
-      EqualTo(AttributeReference("x", LongType)(), Literal(5)))
+    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
+      val table = CapabilityTable(write, OVERWRITE_BY_FILTER)
+      val plan = OverwriteByExpression.byName(
+        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        EqualTo(AttributeReference("x", LongType)(), Literal(5)))
 
-    TableCapabilityCheck.apply(plan)
+      TableCapabilityCheck.apply(plan)
+    }
   }
 
   test("OverwritePartitionsDynamic: check missing capabilities") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 7cd6ba21b56ec..de843ba4375d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -24,11 +24,12 @@ import scala.collection.mutable
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
-import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -52,7 +53,11 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
   test("append fallback") {
     val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
     df.write.mode("append").option("name", "t1").format(v2Format).save()
+
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+    assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
+    assert(InMemoryV1Provider.tables("t1").partitioning.isEmpty)
+
     df.write.mode("append").option("name", "t1").format(v2Format).save()
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
   }
@@ -65,6 +70,59 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Befor
     df2.write.mode("overwrite").option("name", "t1").format(v2Format).save()
     checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2)
   }
+
+  SaveMode.values().foreach { mode =>
+    test(s"save: new table creations with partitioning for table - mode: $mode") {
+      val format = classOf[InMemoryV1Provider].getName
+      val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+      df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save()
+
+      checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+      assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
+      assert(InMemoryV1Provider.tables("t1").partitioning.sameElements(
+        Array(IdentityTransform(FieldReference(Seq("a"))))))
+    }
+  }
+
+  test("save: default mode is ErrorIfExists") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    // default is ErrorIfExists, and since a table already exists we throw an exception
+    val e = intercept[AnalysisException] {
+      df.write.option("name", "t1").format(format).partitionBy("a").save()
+    }
+    assert(e.getMessage.contains("already exists"))
+  }
+
+  test("save: Ignore mode") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    // no-op
+    df.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save()
+
+    checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
+  }
+
+  test("save: tables can perform schema and partitioning checks if they already exist") {
+    val format = classOf[InMemoryV1Provider].getName
+    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+
+    df.write.option("name", "t1").format(format).partitionBy("a").save()
+    val e2 = intercept[IllegalArgumentException] {
+      df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save()
+    }
+    assert(e2.getMessage.contains("partitioning"))
+
+    val e3 = intercept[IllegalArgumentException] {
+      Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format)
+        .save()
+    }
+    assert(e3.getMessage.contains("schema"))
+  }
 }
 
 class V1WriteFallbackSessionCatalogSuite
@@ -114,26 +172,83 @@ private object InMemoryV1Provider {
   }
 }
 
-class InMemoryV1Provider extends TableProvider with DataSourceRegister {
+class InMemoryV1Provider
+  extends TableProvider
+  with DataSourceRegister
+  with CreatableRelationProvider {
   override def getTable(options: CaseInsensitiveStringMap): Table = {
-    InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {
+
+    InMemoryV1Provider.tables.getOrElse(options.get("name"), {
       new InMemoryTableWithV1Fallback(
         "InMemoryTableWithV1Fallback",
-        new StructType().add("a", IntegerType).add("b", StringType),
-        Array(IdentityTransform(FieldReference(Seq("a")))),
+        new StructType(),
+        Array.empty,
         options.asCaseSensitiveMap()
       )
     })
   }
 
   override def shortName(): String = "in-memory"
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val _sqlContext = sqlContext
+
+    val partitioning = parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
+      DataSourceUtils.decodePartitioningColumns(value).map { partitioningColumn =>
+        IdentityTransform(FieldReference(partitioningColumn))
+      }
+    }.getOrElse(Nil)
+
+    val tableName = parameters("name")
+    val tableOpt = InMemoryV1Provider.tables.get(tableName)
+    val table = tableOpt.getOrElse(new InMemoryTableWithV1Fallback(
+      "InMemoryTableWithV1Fallback",
+      data.schema.asNullable,
+      partitioning.toArray,
+      Map.empty[String, String].asJava
+    ))
+    if (tableOpt.isEmpty) {
+      InMemoryV1Provider.tables.put(tableName, table)
+    } else {
+      if (data.schema.asNullable != table.schema) {
+        throw new IllegalArgumentException("Wrong schema provided")
+      }
+      if (!partitioning.sameElements(table.partitioning)) {
+        throw new IllegalArgumentException("Wrong partitioning provided")
+      }
+    }
+
+    def getRelation: BaseRelation = new BaseRelation {
+      override def sqlContext: SQLContext = _sqlContext
+      override def schema: StructType = table.schema
+    }
+
+    if (mode == SaveMode.ErrorIfExists && tableOpt.isDefined) {
+      throw new AnalysisException("Table already exists")
+    } else if (mode == SaveMode.Ignore && tableOpt.isDefined) {
+      // do nothing
+      return getRelation
+    }
+    val writer = table.newWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
+    if (mode == SaveMode.Overwrite) {
+      writer.asInstanceOf[SupportsTruncate].truncate()
+    }
+    writer.asInstanceOf[V1WriteBuilder].buildForV1Write().insert(data, overwrite = false)
+    getRelation
+  }
 }
 
 class InMemoryTableWithV1Fallback(
     override val name: String,
     override val schema: StructType,
     override val partitioning: Array[Transform],
-    override val properties: util.Map[String, String]) extends Table with SupportsWrite {
+    override val properties: util.Map[String, String])
+  extends Table
+  with SupportsWrite {
 
   partitioning.foreach { t =>
     if (!t.isInstanceOf[IdentityTransform]) {
@@ -142,7 +257,6 @@ class InMemoryTableWithV1Fallback(
   }
 
   override def capabilities: util.Set[TableCapability] = Set(
-    TableCapability.BATCH_WRITE,
     TableCapability.V1_BATCH_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.TRUNCATE).asJava
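
For orientation, a minimal usage sketch (not part of the patch) of the behavior these changes enable, drawn from the tests above: a TableProvider that declares only V1_BATCH_WRITE, such as the InMemoryV1Provider test source in this diff, now accepts partitionBy on the DataFrameWriter save path instead of failing analysis. Names come from the test suite; the snippet assumes it runs inside a suite with spark.implicits._ in scope, as V1WriteFallbackSuite does.

  // Hypothetical sketch, mirroring the tests in V1WriteFallbackSuite above.
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
  df.write
    .mode("append")
    .option("name", "t1")
    .format(classOf[InMemoryV1Provider].getName)
    .partitionBy("a")
    .save()
  // The provider's CreatableRelationProvider.createRelation reads the partition columns
  // from the parameters map under DataSourceUtils.PARTITIONING_COLUMNS_KEY, as exercised
  // by the "save: new table creations with partitioning" tests.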