@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.catalog

-import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
+import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression


@@ -39,6 +39,12 @@ abstract class ExternalCatalog {
}
}

+  protected def requireTableExists(db: String, table: String): Unit = {
+    if (!tableExists(db, table)) {
+      throw new NoSuchTableException(db = db, table = table)
+    }
+  }

protected def requireFunctionExists(db: String, funcName: String): Unit = {
if (!functionExists(db, funcName)) {
throw new NoSuchFunctionException(db = db, func = funcName)
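Why hoist this check: every ExternalCatalog implementation needs the same existence guard before it reads or mutates a table. A minimal, self-contained sketch of the pattern (toy classes, not the actual Spark ones):

class NoSuchTableException(db: String, table: String)
  extends Exception(s"Table or view '$table' not found in database '$db'")

abstract class BaseCatalog {
  // Concrete catalogs only have to answer the lookup question...
  def tableExists(db: String, table: String): Boolean

  // ...and inherit the guard, analogous to ExternalCatalog.requireTableExists above.
  protected def requireTableExists(db: String, table: String): Unit = {
    if (!tableExists(db, table)) {
      throw new NoSuchTableException(db, table)
    }
  }
}

class ToyCatalog(tables: Set[(String, String)]) extends BaseCatalog {
  override def tableExists(db: String, table: String): Boolean = tables((db, table))

  def dropTable(db: String, table: String): Unit = {
    requireTableExists(db, table) // fails fast with NoSuchTableException
    // actual drop logic would follow
  }
}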
@@ -64,12 +64,6 @@ class InMemoryCatalog(
catalog(db).tables(table).partitions.contains(spec)
}

-  private def requireTableExists(db: String, table: String): Unit = {
-    if (!tableExists(db, table)) {
-      throw new NoSuchTableException(db = db, table = table)
-    }
-  }

private def requireTableNotExists(db: String, table: String): Unit = {
if (tableExists(db, table)) {
throw new TableAlreadyExistsException(db = db, table = table)
@@ -250,4 +250,28 @@ object DataType {
case (fromDataType, toDataType) => fromDataType == toDataType
}
}

+  /**
+   * Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case
+   * sensitivity of field names in StructType.
+   */
+  private[sql] def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = {
    [Review comment from a Contributor on the line above: It is really tricky to add this method in DataType.]
+    (from, to) match {
+      case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
+        equalsIgnoreCaseAndNullability(fromElement, toElement)
+
+      case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
+        equalsIgnoreCaseAndNullability(fromKey, toKey) &&
+          equalsIgnoreCaseAndNullability(fromValue, toValue)
+
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.length == toFields.length &&
+          fromFields.zip(toFields).forall { case (l, r) =>
+            l.name.equalsIgnoreCase(r.name) &&
+              equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
+          }
+
+      case (fromDataType, toDataType) => fromDataType == toDataType
+    }
+  }
}
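A rough usage sketch of what the new comparison tolerates. Note that equalsIgnoreCaseAndNullability is private[sql], so outside Spark's own sql packages it can only be exercised indirectly; the schemas below are illustrative:

import org.apache.spark.sql.types._

// Same logical schema written with different field-name casing and nullability.
val written = new StructType()
  .add("HelLo", IntegerType, nullable = false)
  .add("WoRLd", ArrayType(StringType, containsNull = false))

val readBack = new StructType()
  .add("hello", IntegerType, nullable = true)
  .add("world", ArrayType(StringType, containsNull = true))

// Plain equality is strict about both case and nullability:
assert(written != readBack)

// Inside org.apache.spark.sql code, the new helper treats them as the same type:
// DataType.equalsIgnoreCaseAndNullability(written, readBack) == true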
@@ -270,6 +270,26 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
assert(catalog.listTables("db2", "*1").toSet == Set("tbl1"))
}

test("column names should be case-preserving and column nullability should be retained") {
val catalog = newBasicCatalog()
val tbl = CatalogTable(
identifier = TableIdentifier("tbl", Some("db1")),
tableType = CatalogTableType.MANAGED,
storage = storageFormat,
schema = new StructType()
.add("HelLo", "int", nullable = false)
.add("WoRLd", "int", nullable = true),
provider = Some("hive"),
partitionColumnNames = Seq("WoRLd"),
bucketSpec = Some(BucketSpec(4, Seq("HelLo"), Nil)))
catalog.createTable(tbl, ignoreIfExists = false)

val readBack = catalog.getTable("db1", "tbl")
assert(readBack.schema == tbl.schema)
assert(readBack.partitionColumnNames == tbl.partitionColumnNames)
assert(readBack.bucketSpec == tbl.bucketSpec)
}

// --------------------------------------------------------------------------
// Partitions
// --------------------------------------------------------------------------
@@ -24,10 +24,10 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions, Union}
-import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand
-import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, CreateTable, DataSource, HadoopFsRelation}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions}
+import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, HadoopFsRelation}
import org.apache.spark.sql.types.StructType

/**
@@ -359,7 +359,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}

private def saveAsTable(tableIdent: TableIdentifier): Unit = {
-    if (source.toLowerCase == "hive") {
+    if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException("Cannot create hive serde table with saveAsTable API")
}

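A short usage sketch of the guarded path. As of this change the hive provider is rejected in saveAsTable, while any other provider still goes through the data source path (later Spark versions may relax this):

import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().master("local[1]").appName("saveAsTable-sketch").getOrCreate()

val df = spark.range(3).toDF("id")

// Non-hive providers still go through the normal data source path.
df.write.format("parquet").saveAsTable("t_parquet")

// The hive provider is rejected by the check shown above.
try {
  df.write.format("hive").saveAsTable("t_hive")
} catch {
  case e: AnalysisException =>
    // "Cannot create hive serde table with saveAsTable API"
    println(e.getMessage)
}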
@@ -331,7 +331,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
}
val options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)
val provider = ctx.tableProvider.qualifiedName.getText
-    if (provider.toLowerCase == "hive") {
+    if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING")
}
val schema = Option(ctx.colTypeList()).map(createSchema)
@@ -1034,7 +1034,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
tableType = tableType,
storage = storage,
schema = schema,
provider = Some("hive"),
provider = Some(DDLUtils.HIVE_PROVIDER),
partitionColumnNames = partitionCols.map(_.name),
properties = properties,
comment = comment)
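The same constant now guards the SQL entry point. A sketch of the observable behaviour, continuing with the spark session from the previous sketch and using the message text from the check above:

// Data source tables via CREATE TABLE ... USING are fine:
spark.sql("CREATE TABLE t_ds (id INT) USING parquet")

// ... but 'USING hive' is rejected by the parser-level check:
// spark.sql("CREATE TABLE t_hive (id INT) USING hive")
//   => AnalysisException: Cannot create hive serde table with CREATE TABLE USING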
@@ -415,7 +415,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case CreateTable(tableDesc, mode, None) if tableDesc.provider.get == "hive" =>
+      case CreateTable(tableDesc, mode, None)
+          if tableDesc.provider.get == DDLUtils.HIVE_PROVIDER =>
val cmd = CreateTableCommand(tableDesc, ifNotExists = mode == SaveMode.Ignore)
ExecutedCommandExec(cmd) :: Nil

@@ -427,7 +428,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// CREATE TABLE ... AS SELECT ... for hive serde table is handled in hive module, by rule
// `CreateTables`

-      case CreateTable(tableDesc, mode, Some(query)) if tableDesc.provider.get != "hive" =>
+      case CreateTable(tableDesc, mode, Some(query))
+          if tableDesc.provider.get != DDLUtils.HIVE_PROVIDER =>
val cmd =
CreateDataSourceTableAsSelectCommand(
tableDesc,
@@ -687,8 +687,10 @@ case class AlterTableSetLocationCommand(


object DDLUtils {
+  val HIVE_PROVIDER = "hive"

def isDatasourceTable(table: CatalogTable): Boolean = {
-    table.provider.isDefined && table.provider.get != "hive"
+    table.provider.isDefined && table.provider.get != HIVE_PROVIDER
}

/**
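A simplified sketch of the predicate the constant feeds into, re-implemented over a bare Option[String] rather than a CatalogTable, just to show the three cases:

// isDatasourceTable is false for hive serde tables and for tables with no
// recorded provider, and true for everything else.
def isDatasourceTable(provider: Option[String]): Boolean =
  provider.isDefined && provider.get != "hive" // i.e. DDLUtils.HIVE_PROVIDER

assert(!isDatasourceTable(Some("hive")))   // hive serde table
assert(!isDatasourceTable(None))           // no provider recorded
assert(isDatasourceTable(Some("parquet"))) // data source table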
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrd
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types.{AtomicType, StructType}
@@ -127,7 +128,7 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
checkDuplication(normalizedPartitionCols, "partition")

if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) {
-      if (tableDesc.provider.get == "hive") {
+      if (tableDesc.provider.get == DDLUtils.HIVE_PROVIDER) {
// When we hit this branch, it means users didn't specify schema for the table to be
// created, as we always include partition columns in table schema for hive serde tables.
// The real schema will be inferred at hive metastore by hive serde, plus the given
@@ -292,7 +293,7 @@ object HiveOnlyCheck extends (LogicalPlan => Unit) {
def apply(plan: LogicalPlan): Unit = {
plan.foreach {
case CreateTable(tableDesc, _, Some(_))
-          if tableDesc.provider.get == "hive" =>
+          if tableDesc.provider.get == DDLUtils.HIVE_PROVIDER =>
throw new AnalysisException("Hive support is required to use CREATE Hive TABLE AS SELECT")

case _ => // OK
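Finally, HiveOnlyCheck keeps rejecting hive-serde CTAS when the session was built without Hive support. A rough sketch of what a user would see, again continuing with the spark session from the earlier sketch (assumed to have no Hive support):

import org.apache.spark.sql.AnalysisException

// Hive-style CTAS (no USING clause) is parsed as a hive serde table, so the
// check above fires when Hive support is not enabled.
try {
  spark.sql("CREATE TABLE t_ctas AS SELECT 1 AS id")
} catch {
  case e: AnalysisException =>
    // "Hive support is required to use CREATE Hive TABLE AS SELECT"
    println(e.getMessage)
}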