diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index f30b5d816703..0d05d9808b40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
  * An interface for looking up relations by name. Used by an [[Analyzer]].
  */
 trait Catalog {
+
+  def caseSensitive: Boolean
+
   def lookupRelation(
     databaseName: Option[String],
     tableName: String,
@@ -35,22 +38,44 @@ trait Catalog {
   def unregisterTable(databaseName: Option[String], tableName: String): Unit
 
   def unregisterAllTables(): Unit
+
+  protected def processDatabaseAndTableName(
+      databaseName: Option[String],
+      tableName: String): (Option[String], String) = {
+    if (!caseSensitive) {
+      (databaseName.map(_.toLowerCase), tableName.toLowerCase)
+    } else {
+      (databaseName, tableName)
+    }
+  }
+
+  protected def processDatabaseAndTableName(
+      databaseName: String,
+      tableName: String): (String, String) = {
+    if (!caseSensitive) {
+      (databaseName.toLowerCase, tableName.toLowerCase)
+    } else {
+      (databaseName, tableName)
+    }
+  }
 }
 
-class SimpleCatalog extends Catalog {
+class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
   val tables = new mutable.HashMap[String, LogicalPlan]()
 
   override def registerTable(
       databaseName: Option[String],
       tableName: String,
       plan: LogicalPlan): Unit = {
-    tables += ((tableName, plan))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    tables += ((tblName, plan))
   }
 
   override def unregisterTable(
       databaseName: Option[String],
       tableName: String) = {
-    tables -= tableName
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    tables -= tblName
   }
 
   override def unregisterAllTables() = {
@@ -61,12 +86,13 @@ class SimpleCatalog extends Catalog {
       databaseName: Option[String],
       tableName: String,
       alias: Option[String] = None): LogicalPlan = {
-    val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
-    val tableWithQualifiers = Subquery(tableName, table)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+    val tableWithQualifiers = Subquery(tblName, table)
 
     // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
     // properly qualified with this alias.
-    alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+    alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
   }
 }
 
@@ -85,26 +111,28 @@ trait OverrideCatalog extends Catalog {
       databaseName: Option[String],
       tableName: String,
       alias: Option[String] = None): LogicalPlan = {
-
-    val overriddenTable = overrides.get((databaseName, tableName))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val overriddenTable = overrides.get((dbName, tblName))
 
     // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
     // properly qualified with this alias.
     val withAlias =
-      overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
+      overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))
 
-    withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
+    withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
   }
 
   override def registerTable(
       databaseName: Option[String],
       tableName: String,
       plan: LogicalPlan): Unit = {
-    overrides.put((databaseName, tableName), plan)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    overrides.put((dbName, tblName), plan)
   }
 
   override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
-    overrides.remove((databaseName, tableName))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    overrides.remove((dbName, tblName))
   }
 
   override def unregisterAllTables(): Unit = {
@@ -117,6 +145,9 @@ trait OverrideCatalog extends Catalog {
  * relations are already filled in and the analyser needs only to resolve attribute references.
  */
 object EmptyCatalog extends Catalog {
+
+  val caseSensitive: Boolean = true
+
   def lookupRelation(
     databaseName: Option[String],
     tableName: String,
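The heart of the change is the pair of processDatabaseAndTableName overloads above: every catalog implementation funnels database and table names through them, and a single caseSensitive flag decides whether names are lowercased. A minimal sketch of the resulting behavior (hypothetical driver code, not part of the patch; somePlan stands in for any LogicalPlan):

    // With caseSensitive = false, names are lowercased on both registration and
    // lookup, so differently-cased spellings meet at the key "table".
    val insensitive = new SimpleCatalog(caseSensitive = false)
    insensitive.registerTable(None, "TaBlE", somePlan)  // stored under "table"
    insensitive.lookupRelation(None, "tAbLe")           // found: somePlan, wrapped in Subquery("table", ...)

    // With caseSensitive = true, names are kept verbatim.
    val sensitive = new SimpleCatalog(caseSensitive = true)
    sensitive.registerTable(None, "TaBlE", somePlan)    // stored under "TaBlE"
    sensitive.lookupRelation(None, "tAbLe")             // fails: "Table Not Found: tAbLe"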
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f14df8137683..0a4fde3de775 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,28 +17,81 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfter, FunSuite}
 
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.IntegerType
 
-/* Implicit conversions */
-import org.apache.spark.sql.catalyst.dsl.expressions._
+class AnalysisSuite extends FunSuite with BeforeAndAfter {
+  val caseSensitiveCatalog = new SimpleCatalog(true)
+  val caseInsensitiveCatalog = new SimpleCatalog(false)
+  val caseSensitiveAnalyze =
+    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
+  val caseInsensitiveAnalyze =
+    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
 
-class AnalysisSuite extends FunSuite {
-  val analyze = SimpleAnalyzer
+  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
 
-  val testRelation = LocalRelation('a.int)
+  before {
+    caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+    caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+  }
 
   test("analyze project") {
     assert(
-      analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+      caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+        Project(testRelation.output, testRelation))
+
+    assert(
+      caseSensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("TbL.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+        Project(testRelation.output, testRelation))
+
+    val e = intercept[TreeNodeException[_]] {
+      caseSensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("tBl.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL"))))
+    }
+    assert(e.getMessage().toLowerCase.contains("unresolved"))
+
+    assert(
+      caseInsensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("TbL.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
         Project(testRelation.output, testRelation))
+
+    assert(
+      caseInsensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("tBl.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+        Project(testRelation.output, testRelation))
+  }
+
+  test("resolve relations") {
+    val e = intercept[RuntimeException] {
+      caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
+    }
+    assert(e.getMessage === "Table Not Found: tAbLe")
+
+    assert(
+      caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+        testRelation)
+
+    assert(
+      caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
+        testRelation)
+
+    assert(
+      caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+        testRelation)
   }
 
   test("throw errors for unresolved attributes during analysis") {
     val e = intercept[TreeNodeException[_]] {
-      analyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
+      caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
     }
     assert(e.getMessage().toLowerCase.contains("unresolved"))
   }
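Note that case sensitivity is controlled at two separate points, and the suite deliberately keeps them in agreement: the Catalog flag governs database and table names, while the Analyzer's own caseSensitive parameter governs attribute and alias resolution (which is why tBl.a only resolves under the case-insensitive analyzer). A condensed sketch of the wiring, reusing the suite's testRelation (illustration only, not part of the patch):

    // Both knobs set to case-insensitive; table and attribute casing may then differ freely.
    val catalog = new SimpleCatalog(false)
    val analyze = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
    catalog.registerTable(None, "TaBlE", testRelation)
    analyze(UnresolvedRelation(None, "tAbLe", None))  // resolves to testRelation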
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7edb548678c3..4abd89955bd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -57,7 +57,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   self =>
 
   @transient
-  protected[sql] lazy val catalog: Catalog = new SimpleCatalog
+  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
   @transient
   protected[sql] lazy val analyzer: Analyzer =
     new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true)
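SQLContext stays fully case-sensitive on both axes: the new SimpleCatalog(true) plus the existing Analyzer with caseSensitive = true. If one wanted case-insensitive SQL on top of this patch, both flags would have to flip together; a hypothetical sketch (not part of the patch, and placed in package org.apache.spark.sql so it can override the protected[sql] members):

    package org.apache.spark.sql

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.catalyst.analysis._

    class CaseInsensitiveSQLContext(sc: SparkContext) extends SQLContext(sc) {
      @transient
      override protected[sql] lazy val catalog: Catalog = new SimpleCatalog(false)
      @transient
      override protected[sql] lazy val analyzer: Analyzer =
        new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
    }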
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 7c24b5cabf61..f83068860701 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -45,12 +45,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
 
   val client = Hive.get(hive.hiveconf)
 
+  val caseSensitive: Boolean = false
+
   def lookupRelation(
       db: Option[String],
       tableName: String,
       alias: Option[String]): LogicalPlan = {
-    val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
-    val table = client.getTable(databaseName, tableName)
+    val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
+    val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
+    val table = client.getTable(databaseName, tblName)
     val partitions: Seq[Partition] =
       if (table.isPartitioned) {
         client.getAllPartitionsForPruner(table).toSeq
@@ -60,8 +63,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
 
     // Since HiveQL is case insensitive for table names we make them all lowercase.
     MetastoreRelation(
-      databaseName.toLowerCase,
-      tableName.toLowerCase,
+      databaseName,
+      tblName,
       alias)(table.getTTable, partitions.map(part => part.getTPartition))
   }
 
@@ -70,7 +73,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
       tableName: String,
       schema: Seq[Attribute],
       allowExisting: Boolean = false): Unit = {
-    val table = new Table(databaseName, tableName)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val table = new Table(dbName, tblName)
     val hiveSchema =
       schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
     table.setFields(hiveSchema)
@@ -85,7 +89,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
     sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
     sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
     val serDeInfo = new SerDeInfo()
-    serDeInfo.setName(tableName)
+    serDeInfo.setName(tblName)
     serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
     serDeInfo.setParameters(Map[String, String]())
     sd.setSerdeInfo(serDeInfo)
@@ -104,13 +108,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
   object CreateTables extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case InsertIntoCreatedTable(db, tableName, child) =>
-        val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
+        val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
+        val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
 
-        createTable(databaseName, tableName, child.output)
+        createTable(databaseName, tblName, child.output)
 
         InsertIntoTable(
           EliminateAnalysisOperators(
-            lookupRelation(Some(databaseName), tableName, None)),
+            lookupRelation(Some(databaseName), tblName, None)),
           Map.empty,
           child,
           overwrite = false)
diff --git a/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a b/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a
new file mode 100644
index 000000000000..4d7127c0faab
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a
@@ -0,0 +1,14 @@
+0 val_0
+4 val_4
+12 val_12
+8 val_8
+0 val_0
+0 val_0
+10 val_10
+5 val_5
+11 val_11
+5 val_5
+2 val_2
+12 val_12
+5 val_5
+9 val_9
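The golden file above is the expected answer for the new createQueryTest below: the 14 rows of the src test table with key < 15, in the order Hive returns them. createQueryTest, inherited from HiveComparisonTest, records Hive's answer for a query under sql/hive/src/test/resources/golden and thereafter checks Spark SQL's result against it, which is why a data file accompanies the test change.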
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 9f1cd703103e..a623d29b5397 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -210,6 +210,22 @@ class HiveQuerySuite extends HiveComparisonTest {
     }
   }
 
+  createQueryTest("case sensitivity: Hive table",
+    "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15")
+
+  test("case sensitivity: registered table") {
+    val testData: SchemaRDD =
+      TestHive.sparkContext.parallelize(
+        TestData(1, "str1") ::
+        TestData(2, "str2") :: Nil)
+    testData.registerAsTable("REGisteredTABle")
+
+    assertResult(Array(Array(2, "str2"))) {
+      hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
+        "WHERE TableAliaS.a > 1").collect()
+    }
+  }
+
   def isExplanation(result: SchemaRDD) = {
     val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
     explanation.size > 1 && explanation.head.startsWith("Physical execution plan")
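Both new Hive tests exercise the same property end to end: HiveMetastoreCatalog fixes caseSensitive to false, so every database and table name is normalized to lowercase once, in processDatabaseAndTableName, before it reaches the metastore client or the override map. A condensed sketch of the effect (assuming TestHive is initialized as in HiveQuerySuite):

    import org.apache.spark.sql.hive.test.TestHive._

    // Any casing of the table (and, with the case-insensitive Hive analyzer,
    // of the columns) names the same relation once identifiers are normalized.
    hql("SELECT key FROM SRC WHERE KEY < 15").collect()
    hql("SELECT key FROM src WHERE key < 15").collect()  // same rows either way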