@@ -624,7 +624,7 @@ class KafkaRelationSuiteV2 extends KafkaRelationSuiteBase {
val topic = newTopic()
val df = createDF(topic)
assert(df.logicalPlan.collect {
case DataSourceV2Relation(_, _, _) => true
case _: DataSourceV2Relation => true
}.nonEmpty)
}
}
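
A toy illustration (made-up types, not Spark's) of why this test, like the CacheManager and FallBackFileSourceV2 matches further down, benefits from a type pattern: it keeps compiling when DataSourceV2Relation gains fields such as the new catalog and identifier, while a positional extractor has to be touched at every call site.

object ArityFreePatternSketch extends App {
  // Toy stand-in for a case class that later grows two extra fields.
  case class Rel(table: String, output: Seq[String], catalog: Option[String], ident: Option[String])

  def containsRel(plans: Seq[Any]): Boolean = plans.exists {
    case _: Rel => true        // type pattern: unaffected by the added fields
    // case Rel(_, _) => true  // old positional form: stops compiling once fields are added
    case _ => false
  }

  assert(containsRel(Seq(Rel("t", Nil, None, None), "something else")))
}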
@@ -817,8 +817,8 @@ class Analyzer(

case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
CatalogV2Util.loadRelation(u.catalog, u.tableName)
.map(rel => alter.copy(table = rel))
.getOrElse(alter)
.map(rel => alter.copy(table = rel))
.getOrElse(alter)

case u: UnresolvedV2Relation =>
CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u)
@@ -831,7 +831,8 @@ class Analyzer(
expandRelationName(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) => Some(DataSourceV2Relation.create(table))
case Some(table) =>
Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
case None => None
}
case _ => None
@@ -923,7 +924,7 @@ class Analyzer(
AnalysisContext.get.relationCache.getOrElseUpdate(
key, v1SessionCatalog.getRelation(v1Table.v1Table))
case table =>
DataSourceV2Relation.create(table)
DataSourceV2Relation.create(table, Some(catalog), Some(ident))
}
case _ => None
}
@@ -257,7 +257,7 @@ private[sql] object CatalogV2Util {
}

def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[NamedRelation] = {
loadTable(catalog, ident).map(DataSourceV2Relation.create)
loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident)))
}

def isSessionCatalog(catalog: CatalogPlugin): Boolean = {
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, Statistics => V2Statistics, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.connector.write.WriteBuilder
@@ -32,12 +32,17 @@ import org.apache.spark.util.Utils
* A logical plan representing a data source v2 table.
*
* @param table The table that this relation represents.
* @param output the output attributes of this relation.
* @param catalog catalogPlugin for the table. None if no catalog is specified.
* @param identifier the identifier for the table. None if no identifier is defined.
* @param options The options for this table operation. It's used to create fresh [[ScanBuilder]]
* and [[WriteBuilder]].
*/
case class DataSourceV2Relation(
table: Table,
output: Seq[AttributeReference],
catalog: Option[CatalogPlugin],
[inline review comment] Contributor: You probably want the catalog identifier here too, not the plugin.

identifier: Option[Identifier],
options: CaseInsensitiveStringMap)
extends LeafNode with MultiInstanceRelation with NamedRelation {

@@ -137,12 +142,20 @@ case class StreamingDataSourceV2Relation(
}

object DataSourceV2Relation {
def create(table: Table, options: CaseInsensitiveStringMap): DataSourceV2Relation = {
def create(
table: Table,
catalog: Option[CatalogPlugin],
identifier: Option[Identifier],
options: CaseInsensitiveStringMap): DataSourceV2Relation = {
val output = table.schema().toAttributes
DataSourceV2Relation(table, output, options)
DataSourceV2Relation(table, output, catalog, identifier, options)
}

def create(table: Table): DataSourceV2Relation = create(table, CaseInsensitiveStringMap.empty)
def create(
table: Table,
catalog: Option[CatalogPlugin],
identifier: Option[Identifier]): DataSourceV2Relation =
create(table, catalog, identifier, CaseInsensitiveStringMap.empty)

/**
* This is used to transform data source v2 statistics to logical.Statistics.
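
A minimal usage sketch, separate from the diff itself, of the two new create overloads; it mirrors the mock-based setup of the CatalogV2UtilSuite added below and assumes Mockito plus the Spark catalyst/core classes on the classpath.

import org.mockito.Mockito.{mock, when}

import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object DataSourceV2RelationCreateSketch extends App {
  val catalog = mock(classOf[TableCatalog])
  val ident = mock(classOf[Identifier])
  val table = mock(classOf[Table])
  when(table.schema()).thenReturn(new StructType().add("id", "long"))

  // Catalog-backed tables keep their catalog and identifier on the relation.
  val rel = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
  assert(rel.catalog.contains(catalog) && rel.identifier.contains(ident))

  // Anonymous (e.g. path-based) tables pass None for both, which matches what the old
  // single-argument create(table) used to produce.
  val anon = DataSourceV2Relation.create(table, None, None, CaseInsensitiveStringMap.empty)
  assert(anon.catalog.isEmpty && anon.identifier.isEmpty)
}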
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import org.mockito.Mockito.{mock, when}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.StructType

class CatalogV2UtilSuite extends SparkFunSuite {
test("Load relation should encode the identifiers for V2Relations") {
val testCatalog = mock(classOf[TableCatalog])
val ident = mock(classOf[Identifier])
val table = mock(classOf[Table])
when(table.schema()).thenReturn(mock(classOf[StructType]))
when(testCatalog.loadTable(ident)).thenReturn(table)
val r = CatalogV2Util.loadRelation(testCatalog, ident)
assert(r.isDefined)
assert(r.get.isInstanceOf[DataSourceV2Relation])
val v2Relation = r.get.asInstanceOf[DataSourceV2Relation]
assert(v2Relation.catalog.exists(_ == testCatalog))
assert(v2Relation.identifier.exists(_ == ident))
}
}
16 changes: 10 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -195,6 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}

DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).map { provider =>
val catalogManager = sparkSession.sessionState.catalogManager
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = sparkSession.sessionState.conf)
val pathsOption = if (paths.isEmpty) {
@@ -206,27 +207,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
val table = provider match {
val (table, catalog, ident) = provider match {
case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
throw new IllegalArgumentException(
s"$source does not support user specified schema. Please don't specify the schema.")
case hasCatalog: SupportsCatalogOptions =>
val ident = hasCatalog.extractIdentifier(dsOptions)
val catalog = CatalogV2Util.getTableProviderCatalog(
hasCatalog,
sparkSession.sessionState.catalogManager,
catalogManager,
dsOptions)
catalog.loadTable(ident)
(catalog.loadTable(ident), Some(catalog), Some(ident))
case _ =>
// TODO: Non-catalog paths for DSV2 are currently not well defined.
userSpecifiedSchema match {
case Some(schema) => provider.getTable(dsOptions, schema)
case _ => provider.getTable(dsOptions)
case Some(schema) => (provider.getTable(dsOptions, schema), None, None)
case _ => (provider.getTable(dsOptions), None, None)
}
}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
case _: SupportsRead if table.supports(BATCH_READ) =>
Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions))
Dataset.ofRows(
sparkSession,
DataSourceV2Relation.create(table, catalog, ident, dsOptions))

case _ => loadV1Source(paths: _*)
}
22 changes: 11 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -258,20 +258,20 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val dsOptions = new CaseInsensitiveStringMap(options.asJava)

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val catalogManager = df.sparkSession.sessionState.catalogManager
mode match {
case SaveMode.Append | SaveMode.Overwrite =>
val table = provider match {
val (table, catalog, ident) = provider match {
case supportsExtract: SupportsCatalogOptions =>
val ident = supportsExtract.extractIdentifier(dsOptions)
val sessionState = df.sparkSession.sessionState
val catalog = CatalogV2Util.getTableProviderCatalog(
supportsExtract, sessionState.catalogManager, dsOptions)
supportsExtract, catalogManager, dsOptions)

catalog.loadTable(ident)
(catalog.loadTable(ident), Some(catalog), Some(ident))
case tableProvider: TableProvider =>
val t = tableProvider.getTable(dsOptions)
if (t.supports(BATCH_WRITE)) {
t
(t, None, None)
} else {
// Streaming also uses the data source V2 API. So it may be that the data source
// implements v2, but has no v2 implementation for batch writes. In that case, we
@@ -280,7 +280,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}

val relation = DataSourceV2Relation.create(table, dsOptions)
val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions)
checkPartitioningMatchesV2Table(table)
if (mode == SaveMode.Append) {
runCommand(df.sparkSession, "save") {
@@ -299,9 +299,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
provider match {
case supportsExtract: SupportsCatalogOptions =>
val ident = supportsExtract.extractIdentifier(dsOptions)
val sessionState = df.sparkSession.sessionState
val catalog = CatalogV2Util.getTableProviderCatalog(
supportsExtract, sessionState.catalogManager, dsOptions)
supportsExtract, catalogManager, dsOptions)

val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _)

@@ -419,7 +418,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
case _: V1Table =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
case t =>
DataSourceV2Relation.create(t)
DataSourceV2Relation.create(t, Some(catalog), Some(ident))
}

val command = mode match {
@@ -554,12 +553,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}

val command = (mode, tableOpt) match {
case (_, Some(table: V1Table)) =>
case (_, Some(_: V1Table)) =>
return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))

case (SaveMode.Append, Some(table)) =>
checkPartitioningMatchesV2Table(table)
AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap)
val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
AppendData.byName(v2Relation, df.logicalPlan, extraOptions.toMap)

case (SaveMode.Overwrite, _) =>
[inline review thread]
Contributor Author: A little curious why Overwrite doesn't need to create a DataSourceV2Relation?
Contributor: It drops and recreates a table. It's a DDL operation instead of DML.
Contributor Author: Still probably a dumb question: why does the DDL/DML distinction affect how we generate the query plan? I'm asking because in DataFrameWriter.save() we do generate a DataSourceV2Relation for Overwrite mode, so I'm curious why there is such a difference here.
Contributor: In save() we don't update table information in a catalog. saveAsTable updates the catalog, so it is doing different work and needs a separate operation.

ReplaceTableAsSelect(
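
To make the DDL-versus-DML point in the thread above concrete, here is a rough sketch written as if it were another test body in the DataFrameWriterV2Suite further down; it assumes that suite's setup (Spark's own test classpath, a V2 catalog registered as "testcat", an existing testcat.table_name, and a "source" temp view), none of which is defined here.

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

var plan: LogicalPlan = null
spark.listenerManager.register(new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    plan = qe.analyzed
  }
  override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
})

// DML: Append writes into the existing table, so the analyzed plan should be an AppendData
// whose table is a DataSourceV2Relation now carrying Some(catalog) and Some(identifier).
spark.table("source").write.mode(SaveMode.Append).saveAsTable("testcat.table_name")
sparkContext.listenerBus.waitUntilEmpty()
println(plan.getClass.getSimpleName)  // expected: AppendData

// DDL: Overwrite drops and recreates the table through the catalog, so the command becomes
// a ReplaceTableAsSelect and no DataSourceV2Relation is needed for the target table.
spark.table("source").write.mode(SaveMode.Overwrite).saveAsTable("testcat.table_name")
sparkContext.listenerBus.waitUntilEmpty()
println(plan.getClass.getSimpleName)  // expected: ReplaceTableAsSelect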
@@ -158,7 +158,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
def append(): Unit = {
val append = loadTable(catalog, identifier) match {
case Some(t) =>
AppendData.byName(DataSourceV2Relation.create(t), logicalPlan, options.toMap)
AppendData.byName(
DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
logicalPlan, options.toMap)
case _ =>
throw new NoSuchTableException(identifier)
}
@@ -181,7 +183,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
val overwrite = loadTable(catalog, identifier) match {
case Some(t) =>
OverwriteByExpression.byName(
DataSourceV2Relation.create(t), logicalPlan, condition.expr, options.toMap)
DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
logicalPlan, condition.expr, options.toMap)
case _ =>
throw new NoSuchTableException(identifier)
}
@@ -207,7 +210,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
val dynamicOverwrite = loadTable(catalog, identifier) match {
case Some(t) =>
OverwritePartitionsDynamic.byName(
DataSourceV2Relation.create(t), logicalPlan, options.toMap)
DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
logicalPlan, options.toMap)
case _ =>
throw new NoSuchTableException(identifier)
}
@@ -270,7 +270,7 @@ class CacheManager extends Logging {
case _ => false
}

case DataSourceV2Relation(fileTable: FileTable, _, _) =>
case DataSourceV2Relation(fileTable: FileTable, _, _, _, _) =>
refreshFileIndexIfNecessary(fileTable.fileIndex, fs, qualifiedPath)

case _ => false
@@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, File
*/
class FallBackFileSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(d @ DataSourceV2Relation(table: FileTable, _, _), _, _, _, _) =>
case i @
InsertIntoStatement(d @ DataSourceV2Relation(table: FileTable, _, _, _, _), _, _, _, _) =>
val v1FileFormat = table.fallbackFileFormat.newInstance()
val relation = HadoopFsRelation(
table.fileIndex,
@@ -232,11 +232,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case desc @ DescribeNamespace(ResolvedNamespace(catalog, ns), extended) =>
DescribeNamespaceExec(desc.output, catalog, ns, extended) :: Nil

case desc @ DescribeRelation(ResolvedTable(_, _, table), partitionSpec, isExtended) =>
case desc @ DescribeRelation(r: ResolvedTable, partitionSpec, isExtended) =>
if (partitionSpec.nonEmpty) {
throw new AnalysisException("DESCRIBE does not support partition for v2 tables.")
}
DescribeTableExec(desc.output, table, isExtended) :: Nil
DescribeTableExec(desc.output, r.table, isExtended) :: Nil

case DropTable(catalog, ident, ifExists) =>
DropTableExec(catalog, ident, ifExists) :: Nil
@@ -284,8 +284,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case r: ShowCurrentNamespace =>
ShowCurrentNamespaceExec(r.output, r.catalogManager) :: Nil

case r @ ShowTableProperties(ResolvedTable(_, _, table), propertyKey) =>
ShowTablePropertiesExec(r.output, table, propertyKey) :: Nil
case r @ ShowTableProperties(rt: ResolvedTable, propertyKey) =>
ShowTablePropertiesExec(r.output, rt.table, propertyKey) :: Nil

case _ => Nil
}
@@ -22,11 +22,15 @@ import scala.collection.JavaConverters._
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
import org.apache.spark.sql.connector.InMemoryTableCatalog
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.Utils

class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -54,6 +58,45 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.sessionState.conf.clear()
}

test("DataFrameWriteV2 encode identifiers correctly") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")

var plan: LogicalPlan = null
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
plan = qe.analyzed

}
override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
}
spark.listenerManager.register(listener)

spark.table("source").writeTo("testcat.table_name").append()
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
checkV2Identifiers(plan.asInstanceOf[AppendData].table)

spark.table("source").writeTo("testcat.table_name").overwrite(lit(true))
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[OverwriteByExpression])
checkV2Identifiers(plan.asInstanceOf[OverwriteByExpression].table)

spark.table("source").writeTo("testcat.table_name").overwritePartitions()
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[OverwritePartitionsDynamic])
checkV2Identifiers(plan.asInstanceOf[OverwritePartitionsDynamic].table)
}

private def checkV2Identifiers(
plan: LogicalPlan,
identifier: String = "table_name",
catalogPlugin: TableCatalog = catalog("testcat")): Unit = {
assert(plan.isInstanceOf[DataSourceV2Relation])
val v2 = plan.asInstanceOf[DataSourceV2Relation]
assert(v2.identifier.exists(_.name() == identifier))
assert(v2.catalog.exists(_ == catalogPlugin))
}

test("Append: basic append") {
spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
