@@ -253,11 +253,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

val maybeV2Provider = lookupV2Provider()
if (maybeV2Provider.isDefined) {
if (partitioningColumns.nonEmpty) {
throw new AnalysisException(
"Cannot write data to TableProvider implementation if partition columns are specified.")
}

val provider = maybeV2Provider.get
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
provider, df.sparkSession.sessionState.conf)
@@ -267,6 +262,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
provider.getTable(dsOptions) match {
case table: SupportsWrite if table.supports(BATCH_WRITE) =>
if (partitioningColumns.nonEmpty) {
Contributor: good catch! Even if the format is a TableProvider, we may still fall back to v1. It's better to check the partition columns only when we are really going to do a v2 write.

Contributor: BTW, technically we only need to assert no partition columns for append and overwrite. Since the v2 write here only supports append and overwrite, we can revisit this later.

throw new AnalysisException("Cannot write data to TableProvider implementation " +
"if partition columns are specified.")
}
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
modeForDSV2 match {
case SaveMode.Append =>
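
As the reviewer's comment above points out, a TableProvider may still fall back to the v1 write path, so the partition-column check now only fires once the v2 BATCH_WRITE branch is actually taken. The following is a minimal, self-contained Scala sketch of that decision flow; Table, Capability, and save are hypothetical stand-ins for illustration, not the real DataFrameWriter API.

// Hypothetical, simplified sketch (not the real DataFrameWriter code): the partition-column
// check only applies once we know the write really takes the v2 BATCH_WRITE path; a table
// that only offers V1_BATCH_WRITE falls back to v1, which may honour partitionBy on its own.
object SaveFlowSketch {
  sealed trait Capability
  case object BatchWrite extends Capability
  case object V1BatchWrite extends Capability

  final case class Table(name: String, capabilities: Set[Capability])

  def save(table: Table, partitionColumns: Seq[String]): String = {
    if (table.capabilities.contains(BatchWrite)) {
      // Real v2 write: partition columns are still rejected here.
      if (partitionColumns.nonEmpty) {
        throw new IllegalArgumentException(
          "Cannot write data to TableProvider implementation if partition columns are specified.")
      }
      s"v2 write to ${table.name}"
    } else {
      // V1 fallback path: the underlying v1 relation decides what to do with partitions.
      s"v1 fallback write to ${table.name}, partitions = ${partitionColumns.mkString(",")}"
    }
  }

  def main(args: Array[String]): Unit = {
    println(save(Table("v1FallbackTable", Set(V1BatchWrite)), Seq("a"))) // allowed: falls back to v1
    println(save(Table("pureV2Table", Set(BatchWrite)), Nil))            // allowed: plain v2 write
    // save(Table("pureV2Table", Set(BatchWrite)), Seq("a"))             // would throw, as in the diff above
  }
}
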
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.types.BooleanType
@@ -32,6 +33,10 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)

private def supportsBatchWrite(table: Table): Boolean = {
table.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
}

override def apply(plan: LogicalPlan): Unit = {
plan foreach {
case r: DataSourceV2Relation if !r.table.supports(BATCH_READ) =>
@@ -43,8 +48,7 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

// TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
// a logical plan for streaming write.

case AppendData(r: DataSourceV2Relation, _, _, _) if !r.table.supports(BATCH_WRITE) =>
case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")

case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
@@ -54,13 +58,13 @@
case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
expr match {
case Literal(true, BooleanType) =>
if (!r.table.supports(BATCH_WRITE) ||
!r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
if (!supportsBatchWrite(r.table) ||
!r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
failAnalysis(
s"Table ${r.table.name()} does not support truncate in batch mode.")
}
case _ =>
if (!r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_BY_FILTER)) {
if (!supportsBatchWrite(r.table) || !r.table.supports(OVERWRITE_BY_FILTER)) {
failAnalysis(s"Table ${r.table.name()} does not support " +
"overwrite by filter in batch mode.")
}
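
For context, the new supportsBatchWrite helper above means a table exposing only V1_BATCH_WRITE now passes the batch-write part of the analysis check, while the TRUNCATE / OVERWRITE_BY_FILTER requirements are unchanged. A minimal standalone sketch of the loosened rule follows; the model types here are hypothetical, not Spark's TableCapabilityCheck.

object CapabilityCheckSketch {
  sealed trait TableCapability
  case object BATCH_WRITE extends TableCapability
  case object V1_BATCH_WRITE extends TableCapability
  case object TRUNCATE extends TableCapability
  case object OVERWRITE_BY_FILTER extends TableCapability

  // Mirrors the new helper: either capability is enough to attempt a batch write.
  def supportsBatchWrite(caps: Set[TableCapability]): Boolean =
    caps.contains(BATCH_WRITE) || caps.contains(V1_BATCH_WRITE)

  // A truncate-style overwrite still additionally needs TRUNCATE or OVERWRITE_BY_FILTER.
  def canTruncate(caps: Set[TableCapability]): Boolean =
    supportsBatchWrite(caps) && (caps.contains(TRUNCATE) || caps.contains(OVERWRITE_BY_FILTER))

  def main(args: Array[String]): Unit = {
    assert(supportsBatchWrite(Set(V1_BATCH_WRITE)))     // newly accepted by the check
    assert(!canTruncate(Set(V1_BATCH_WRITE)))           // still rejected without TRUNCATE/OVERWRITE_BY_FILTER
    assert(canTruncate(Set(V1_BATCH_WRITE, TRUNCATE)))  // matches the added test cases below
  }
}
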
@@ -98,16 +98,19 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
}

test("AppendData: check correct capabilities") {
val plan = AppendData.byName(
DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty),
TestRelation)
Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
val plan = AppendData.byName(
DataSourceV2Relation.create(CapabilityTable(write), CaseInsensitiveStringMap.empty),
TestRelation)

TableCapabilityCheck.apply(plan)
TableCapabilityCheck.apply(plan)
}
}

test("Truncate: check missing capabilities") {
Seq(CapabilityTable(),
CapabilityTable(BATCH_WRITE),
CapabilityTable(V1_BATCH_WRITE),
CapabilityTable(TRUNCATE),
CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>

@@ -125,7 +128,9 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

test("Truncate: check correct capabilities") {
Seq(CapabilityTable(BATCH_WRITE, TRUNCATE),
CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
CapabilityTable(V1_BATCH_WRITE, TRUNCATE),
CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER),
CapabilityTable(V1_BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>

val plan = OverwriteByExpression.byName(
DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
@@ -137,6 +142,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

test("OverwriteByExpression: check missing capabilities") {
Seq(CapabilityTable(),
CapabilityTable(V1_BATCH_WRITE),
CapabilityTable(BATCH_WRITE),
CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>

@@ -153,12 +159,14 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
}

test("OverwriteByExpression: check correct capabilities") {
val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)
val plan = OverwriteByExpression.byName(
DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
EqualTo(AttributeReference("x", LongType)(), Literal(5)))
Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
val table = CapabilityTable(write, OVERWRITE_BY_FILTER)
val plan = OverwriteByExpression.byName(
DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
EqualTo(AttributeReference("x", LongType)(), Literal(5)))

TableCapabilityCheck.apply(plan)
TableCapabilityCheck.apply(plan)
}
}

test("OverwritePartitionsDynamic: check missing capabilities") {
@@ -24,11 +24,12 @@ import scala.collection.mutable

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode, SparkSession}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation}
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -52,7 +53,11 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
test("append fallback") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode("append").option("name", "t1").format(v2Format).save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
assert(InMemoryV1Provider.tables("t1").partitioning.isEmpty)

df.write.mode("append").option("name", "t1").format(v2Format).save()
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df))
}
@@ -65,6 +70,59 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
df2.write.mode("overwrite").option("name", "t1").format(v2Format).save()
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2)
}

SaveMode.values().foreach { mode =>
test(s"save: new table creations with partitioning for table - mode: $mode") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
df.write.mode(mode).option("name", "t1").format(format).partitionBy("a").save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
assert(InMemoryV1Provider.tables("t1").schema === df.schema.asNullable)
assert(InMemoryV1Provider.tables("t1").partitioning.sameElements(
Array(IdentityTransform(FieldReference(Seq("a"))))))
}
}

test("save: default mode is ErrorIfExists") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
// default is ErrorIfExists, and since a table already exists we throw an exception
val e = intercept[AnalysisException] {
df.write.option("name", "t1").format(format).partitionBy("a").save()
}
assert(e.getMessage.contains("already exists"))
}

test("save: Ignore mode") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
// no-op
df.write.option("name", "t1").format(format).mode("ignore").partitionBy("a").save()

checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df)
}

test("save: tables can perform schema and partitioning checks if they already exist") {
val format = classOf[InMemoryV1Provider].getName
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

df.write.option("name", "t1").format(format).partitionBy("a").save()
val e2 = intercept[IllegalArgumentException] {
df.write.mode("append").option("name", "t1").format(format).partitionBy("b").save()
}
assert(e2.getMessage.contains("partitioning"))

val e3 = intercept[IllegalArgumentException] {
Seq((1, "x")).toDF("c", "d").write.mode("append").option("name", "t1").format(format)
.save()
}
assert(e3.getMessage.contains("schema"))
}
}

class V1WriteFallbackSessionCatalogSuite
@@ -114,26 +172,83 @@ private object InMemoryV1Provider {
}
}

class InMemoryV1Provider extends TableProvider with DataSourceRegister {
class InMemoryV1Provider
extends TableProvider
with DataSourceRegister
with CreatableRelationProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), {

InMemoryV1Provider.tables.getOrElse(options.get("name"), {
new InMemoryTableWithV1Fallback(
"InMemoryTableWithV1Fallback",
new StructType().add("a", IntegerType).add("b", StringType),
Array(IdentityTransform(FieldReference(Seq("a")))),
new StructType(),
Array.empty,
options.asCaseSensitiveMap()
)
})
}

override def shortName(): String = "in-memory"

override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val _sqlContext = sqlContext

val partitioning = parameters.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).map { value =>
DataSourceUtils.decodePartitioningColumns(value).map { partitioningColumn =>
IdentityTransform(FieldReference(partitioningColumn))
}
}.getOrElse(Nil)

val tableName = parameters("name")
val tableOpt = InMemoryV1Provider.tables.get(tableName)
val table = tableOpt.getOrElse(new InMemoryTableWithV1Fallback(
"InMemoryTableWithV1Fallback",
data.schema.asNullable,
partitioning.toArray,
Map.empty[String, String].asJava
))
if (tableOpt.isEmpty) {
InMemoryV1Provider.tables.put(tableName, table)
} else {
if (data.schema.asNullable != table.schema) {
throw new IllegalArgumentException("Wrong schema provided")
}
if (!partitioning.sameElements(table.partitioning)) {
throw new IllegalArgumentException("Wrong partitioning provided")
}
}

def getRelation: BaseRelation = new BaseRelation {
override def sqlContext: SQLContext = _sqlContext
override def schema: StructType = table.schema
}

if (mode == SaveMode.ErrorIfExists && tableOpt.isDefined) {
throw new AnalysisException("Table already exists")
} else if (mode == SaveMode.Ignore && tableOpt.isDefined) {
// do nothing
return getRelation
}
val writer = table.newWriteBuilder(new CaseInsensitiveStringMap(parameters.asJava))
if (mode == SaveMode.Overwrite) {
writer.asInstanceOf[SupportsTruncate].truncate()
}
writer.asInstanceOf[V1WriteBuilder].buildForV1Write().insert(data, overwrite = false)
getRelation
}
}

class InMemoryTableWithV1Fallback(
override val name: String,
override val schema: StructType,
override val partitioning: Array[Transform],
override val properties: util.Map[String, String]) extends Table with SupportsWrite {
override val properties: util.Map[String, String])
extends Table
with SupportsWrite {

partitioning.foreach { t =>
if (!t.isInstanceOf[IdentityTransform]) {
@@ -142,7 +257,6 @@ class InMemoryTableWithV1Fallback(
}

override def capabilities: util.Set[TableCapability] = Set(
TableCapability.BATCH_WRITE,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
TableCapability.TRUNCATE).asJava