@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec

object ExternalCatalogUtils {
@@ -133,4 +135,39 @@ object CatalogUtils {
case o => o
}
}

def normalizePartCols(
tableName: String,
tableCols: Seq[String],
partCols: Seq[String],
resolver: Resolver): Seq[String] = {
partCols.map(normalizeColumnName(tableName, tableCols, _, "partition", resolver))
}

def normalizeBucketSpec(
tableName: String,
tableCols: Seq[String],
bucketSpec: BucketSpec,
resolver: Resolver): BucketSpec = {
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec
val normalizedBucketCols = bucketColumnNames.map { colName =>
normalizeColumnName(tableName, tableCols, colName, "bucket", resolver)
}
val normalizedSortCols = sortColumnNames.map { colName =>
normalizeColumnName(tableName, tableCols, colName, "sort", resolver)
}
BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)
}

private def normalizeColumnName(
tableName: String,
tableCols: Seq[String],
colName: String,
colType: String,
resolver: Resolver): String = {
tableCols.find(resolver(_, colName)).getOrElse {
throw new AnalysisException(s"$colType column $colName is not defined in table $tableName, " +
s"defined table columns are: ${tableCols.mkString(", ")}")
}
}
}
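
A minimal sketch of how these helpers behave, assuming a case-insensitive resolver (here caseInsensitiveResolution from the analysis package, which is what spark.sql.caseSensitive=false gives you); the table and column names are invented for illustration:

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogUtils}

object CatalogUtilsSketch {
  def main(args: Array[String]): Unit = {
    val tableCols = Seq("id", "eventDate", "country")

    // User-specified partition columns are mapped back to the table's canonical names.
    val partCols = CatalogUtils.normalizePartCols(
      tableName = "events",
      tableCols = tableCols,
      partCols = Seq("EVENTDATE"),
      resolver = caseInsensitiveResolution)
    println(partCols) // List(eventDate)

    // Bucket and sort columns are normalized the same way.
    val spec = CatalogUtils.normalizeBucketSpec(
      "events", tableCols, BucketSpec(4, Seq("ID"), Seq("Country")), caseInsensitiveResolution)
    println(spec.bucketColumnNames) // List(id)
    println(spec.sortColumnNames)   // List(country)

    // A name that is not defined in the table throws the AnalysisException above:
    // CatalogUtils.normalizePartCols("events", tableCols, Seq("missing"), caseInsensitiveResolution)
  }
}
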
@@ -133,6 +133,16 @@ case class BucketSpec(
if (numBuckets <= 0) {
throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.")
}

override def toString: String = {
Member: Since we implement toString here, we can simplify our logic in describeBucketingInfo.

Contributor Author: How? toString returns a single line, while describeBucketingInfo generates 3 result lines.

Member: :) If we want to keep the existing format (3 lines), then we are unable to do it.

val bucketString = s"bucket columns: [${bucketColumnNames.mkString(", ")}]"
val sortString = if (sortColumnNames.nonEmpty) {
s", sort columns: [${sortColumnNames.mkString(", ")}]"
} else {
""
}
s"$numBuckets buckets, $bucketString$sortString"
}
}
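
For reference, a small sketch of the strings this toString produces (names invented):

import org.apache.spark.sql.catalyst.catalog.BucketSpec

println(BucketSpec(8, Seq("user_id"), Seq("ts")))
// 8 buckets, bucket columns: [user_id], sort columns: [ts]

println(BucketSpec(8, Seq("user_id"), Nil))
// 8 buckets, bucket columns: [user_id]    (no sort columns, so the trailing part is omitted)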

/**
@@ -18,13 +18,11 @@
package org.apache.spark.sql.execution.command

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.BaseRelation

/**
* A command used to create a data source table.
@@ -143,8 +141,9 @@ case class CreateDataSourceTableAsSelectCommand(
val tableName = tableIdentWithDB.unquotedString

var createMetastoreTable = false
var existingSchema = Option.empty[StructType]
if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) {
// We may need to reorder the columns of the query to match the existing table.
var reorderedColumns = Option.empty[Seq[NamedExpression]]
if (sessionState.catalog.tableExists(tableIdentWithDB)) {
// Check if we need to throw an exception or just return.
mode match {
case SaveMode.ErrorIfExists =>
@@ -157,39 +156,76 @@ case class CreateDataSourceTableAsSelectCommand(
// Since the table already exists and the save mode is Ignore, we will just return.
return Seq.empty[Row]
case SaveMode.Append =>
val existingTable = sessionState.catalog.getTableMetadata(tableIdentWithDB)
Member: Is it possible to directly use the input parameter `table`?

Contributor Author: (reply not captured)

Member: Uh, I see. Will do the final review today.


if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " +
"not supported yet. Please use the insertInto() API as an alternative.")
}

// Check if the specified data source matches the data source of the existing table.
gatorsmile (Member), Dec 19, 2016: Now the checking logic for CTAS on data source tables with Append mode is split across two places. Maybe we can improve the comment to explain that AnalyzeCreateTable verifies the consistency between the user-specified table schema/definition and the SELECT query, while here we verify the consistency between the user-specified table definition and the existing table definition, and between the existing table definition and the SELECT query.

val existingProvider = DataSource.lookupDataSource(provider)
val existingProvider = DataSource.lookupDataSource(existingTable.provider.get)
val specifiedProvider = DataSource.lookupDataSource(table.provider.get)
// TODO: Check that options from the resolved relation match the relation that we are
// inserting into (i.e. using the same compression).
if (existingProvider != specifiedProvider) {
throw new AnalysisException(s"The format of the existing table $tableName is " +
s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " +
s"`${specifiedProvider.getSimpleName}`.")
}

// Pass a table identifier with database part, so that `lookupRelation` won't get temp
// views unexpectedly.
EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match {
case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
// check if the file formats match
l.relation match {
case r: HadoopFsRelation if r.fileFormat.getClass != existingProvider =>
throw new AnalysisException(
s"The file format of the existing table $tableName is " +
s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " +
s"format `$provider`")
case _ =>
}
if (query.schema.size != l.schema.size) {
throw new AnalysisException(
s"The column number of the existing schema[${l.schema}] " +
s"doesn't match the data schema[${query.schema}]'s")
}
existingSchema = Some(l.schema)
case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) =>
existingSchema = Some(s.metadata.schema)
case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) =>
throw new AnalysisException("Saving data in the Hive serde table " +
s"${c.catalogTable.identifier} is not supported yet. Please use the " +
"insertInto() API as an alternative..")
case o =>
throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
if (query.schema.length != existingTable.schema.length) {
throw new AnalysisException(
s"The column number of the existing table $tableName" +
s"(${existingTable.schema.catalogString}) doesn't match the data schema" +
s"(${query.schema.catalogString})")
}

val resolver = sessionState.conf.resolver
val tableCols = existingTable.schema.map(_.name)

reorderedColumns = Some(existingTable.schema.map { f =>
query.resolve(Seq(f.name), resolver).getOrElse {
val inputColumns = query.schema.map(_.name).mkString(", ")
throw new AnalysisException(
s"cannot resolve '${f.name}' given input columns: [$inputColumns]")
}
})

// In `AnalyzeCreateTable`, we verified the consistency between the user-specified table
// definition (partition columns, bucketing) and the SELECT query; here we also need to
// verify the consistency between the user-specified table definition and the existing
// table definition.

// Check if the specified partition columns match the existing table.
val specifiedPartCols = CatalogUtils.normalizePartCols(
tableName, tableCols, table.partitionColumnNames, resolver)
if (specifiedPartCols != existingTable.partitionColumnNames) {
throw new AnalysisException(
s"""
|Specified partitioning does not match that of the existing table $tableName.
|Specified partition columns: [${specifiedPartCols.mkString(", ")}]
|Existing partition columns: [${existingTable.partitionColumnNames.mkString(", ")}]
""".stripMargin)
}

// Check if the specified bucketing match the existing table.
val specifiedBucketSpec = table.bucketSpec.map { bucketSpec =>
CatalogUtils.normalizeBucketSpec(tableName, tableCols, bucketSpec, resolver)
}
if (specifiedBucketSpec != existingTable.bucketSpec) {
val specifiedBucketString =
specifiedBucketSpec.map(_.toString).getOrElse("not bucketed")
val existingBucketString =
existingTable.bucketSpec.map(_.toString).getOrElse("not bucketed")
throw new AnalysisException(
s"""
|Specified bucketing does not match that of the existing table $tableName.
|Specified bucketing: $specifiedBucketString
|Existing bucketing: $existingBucketString
""".stripMargin)
}

case SaveMode.Overwrite =>
sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
// Need to create the table again.
@@ -201,9 +237,9 @@ case class CreateDataSourceTableAsSelectCommand(
}

val data = Dataset.ofRows(sparkSession, query)
val df = existingSchema match {
// If we are inserting into an existing table, just use the existing schema.
case Some(s) => data.selectExpr(s.fieldNames: _*)
val df = reorderedColumns match {
// Reorder the columns of the query to match the existing table.
case Some(cols) => data.select(cols.map(Column(_)): _*)
case None => data
}

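A minimal user-side sketch of why the reordering matters (assumes a local SparkSession and default warehouse settings; the table name t2 is invented). On append, the query's columns are matched to the existing table by name rather than by position:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("append-by-name").getOrCreate()
import spark.implicits._

Seq((1, "a")).toDF("i", "j").write.saveAsTable("t2")

// The column order differs from the table's (j before i), but each column resolves by name,
// so the values land in the right columns instead of being appended positionally.
Seq(("b", 2)).toDF("j", "i").write.mode("append").saveAsTable("t2")

spark.table("t2").show() // rows (1, a) and (2, b); row order may vary
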
@@ -17,14 +17,11 @@

package org.apache.spark.sql.execution.datasources

import java.util.regex.Pattern

import scala.util.control.NonFatal

import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogUtils, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
@@ -122,9 +119,12 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
}

private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
val normalizedPartitionCols = tableDesc.partitionColumnNames.map { colName =>
normalizeColumnName(tableDesc.identifier, schema, colName, "partition")
}
val normalizedPartitionCols = CatalogUtils.normalizePartCols(
tableName = tableDesc.identifier.unquotedString,
tableCols = schema.map(_.name),
partCols = tableDesc.partitionColumnNames,
resolver = sparkSession.sessionState.conf.resolver)

checkDuplication(normalizedPartitionCols, "partition")

if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) {
@@ -149,25 +149,21 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl

private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
tableDesc.bucketSpec match {
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
val normalizedBucketCols = bucketColumnNames.map { colName =>
normalizeColumnName(tableDesc.identifier, schema, colName, "bucket")
}
checkDuplication(normalizedBucketCols, "bucket")

val normalizedSortCols = sortColumnNames.map { colName =>
normalizeColumnName(tableDesc.identifier, schema, colName, "sort")
}
checkDuplication(normalizedSortCols, "sort")

schema.filter(f => normalizedSortCols.contains(f.name)).map(_.dataType).foreach {
case Some(bucketSpec) =>
val normalizedBucketing = CatalogUtils.normalizeBucketSpec(
tableName = tableDesc.identifier.unquotedString,
tableCols = schema.map(_.name),
bucketSpec = bucketSpec,
resolver = sparkSession.sessionState.conf.resolver)
checkDuplication(normalizedBucketing.bucketColumnNames, "bucket")
checkDuplication(normalizedBucketing.sortColumnNames, "sort")

normalizedBucketing.sortColumnNames.map(schema(_)).map(_.dataType).foreach {
case dt if RowOrdering.isOrderable(dt) => // OK
case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column")
}

tableDesc.copy(
bucketSpec = Some(BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols))
)
tableDesc.copy(bucketSpec = Some(normalizedBucketing))

case None => tableDesc
}
@@ -182,19 +178,6 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
}
}

private def normalizeColumnName(
tableIdent: TableIdentifier,
schema: StructType,
colName: String,
colType: String): String = {
val tableCols = schema.map(_.name)
val resolver = sparkSession.sessionState.conf.resolver
tableCols.find(resolver(_, colName)).getOrElse {
failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " +
s"defined table columns are: ${tableCols.mkString(", ")}")
}
}

private def failAnalysis(msg: String) = throw new AnalysisException(msg)
}

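A rough DDL-level sketch of what the normalization above means for CREATE TABLE (assumes a SparkSession named spark with the default spark.sql.caseSensitive=false; table and column names are invented):

// The partition column is given in a different case and is normalized to the table column eventDate.
spark.sql("CREATE TABLE logs(id INT, eventDate STRING) USING parquet PARTITIONED BY (EVENTDATE)")

// A partition or bucket column that is not part of the table definition fails analysis, e.g.
//   CREATE TABLE bad(id INT, v STRING) USING parquet PARTITIONED BY (missing)
// => AnalysisException: partition column missing is not defined in table bad,
//    defined table columns are: id, v

// A sort column with a non-orderable type (e.g. a map) is rejected by the
// RowOrdering.isOrderable check with "Cannot use map<...> for sorting column".
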
@@ -342,15 +342,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val e = intercept[AnalysisException] {
sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)")
}
assert(e.message == "partition column c is not defined in table `tbl`, " +
assert(e.message == "partition column c is not defined in table tbl, " +
"defined table columns are: a, b")
}

test("create table - bucket column names not in table definition") {
val e = intercept[AnalysisException] {
sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS")
}
assert(e.message == "bucket column c is not defined in table `tbl`, " +
assert(e.message == "bucket column c is not defined in table tbl, " +
"defined table columns are: a, b")
}

@@ -108,16 +108,14 @@ class DefaultSourceWithoutUserSpecifiedSchema
}

class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {

import testImplicits._

private val userSchema = new StructType().add("s", StringType)
private val textSchema = new StructType().add("value", StringType)
private val data = Seq("1", "2", "3")
private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
private implicit var enc: Encoder[String] = _

before {
enc = spark.implicits.newStringEncoder
Utils.deleteRecursively(new File(dir))
}

@@ -459,8 +457,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
}

test("column nullability and comment - write and then read") {
import testImplicits._

Seq("json", "parquet", "csv").foreach { format =>
val schema = StructType(
StructField("cl1", IntegerType, nullable = false).withComment("test") ::
@@ -576,7 +572,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be

test("SPARK-18510: use user specified types for partition columns in file sources") {
import org.apache.spark.sql.functions.udf
import testImplicits._
withTempDir { src =>
val createArray = udf { (length: Long) =>
for (i <- 1 to length.toInt) yield i.toString
@@ -609,4 +604,35 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
)
}
}

test("SPARK-18899: append to a bucketed table using DataFrameWriter with mismatched bucketing") {
withTable("t") {
Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.bucketBy(2, "i").saveAsTable("t")
val e = intercept[AnalysisException] {
Seq(3 -> "c").toDF("i", "j").write.bucketBy(3, "i").mode("append").saveAsTable("t")
}
assert(e.message.contains("Specified bucketing does not match that of the existing table"))
}
}

test("SPARK-18912: number of columns mismatch for non-file-based data source table") {
withTable("t") {
sql("CREATE TABLE t USING org.apache.spark.sql.test.DefaultSource")

val e = intercept[AnalysisException] {
Seq(1 -> "a").toDF("a", "b").write
.format("org.apache.spark.sql.test.DefaultSource")
.mode("append").saveAsTable("t")
}
assert(e.message.contains("The column number of the existing table"))
}
}

test("SPARK-18913: append to a table with special column names") {
withTable("t") {
Seq(1 -> "a").toDF("x.x", "y.y").write.saveAsTable("t")
Seq(2 -> "b").toDF("x.x", "y.y").write.mode("append").saveAsTable("t")
checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil)
}
}
}