From 12d8006f738e299a08621c382bef4a0a23a72b6f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 7 Jul 2014 09:55:59 -0700 Subject: [PATCH] Handling case sensitivity correctly. This patch introduces three changes. 1. If a table has an alias, the catalog will not lowercase the alias. If a lowercase alias is needed, the analyzer will do the work. 2. A catalog has a new val caseSensitive that indicates if this catalog is case sensitive or not. For example, a SimpleCatalog is case sensitive, but 3. Corresponding unit tests. With this patch, case sensitivity of database names and table names is handled by the catalog. Case sensitivity of other identifiers is handled by the analyzer. --- .../spark/sql/catalyst/analysis/Catalog.scala | 55 +++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 69 ++++++++++++++++--- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 23 ++++--- ...e table-0-5d14d21a239daa42b086cc895215009a | 14 ++++ .../sql/hive/execution/HiveQuerySuite.scala | 16 +++++ 6 files changed, 149 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index f30b5d816703..0d05d9808b40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} * An interface for looking up relations by name. Used by an [[Analyzer]]. */ trait Catalog { + + def caseSensitive: Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -35,22 +38,44 @@ trait Catalog { def unregisterTable(databaseName: Option[String], tableName: String): Unit def unregisterAllTables(): Unit + + protected def processDatabaseAndTableName( + databaseName: Option[String], + tableName: String): (Option[String], String) = { + if (!caseSensitive) { + (databaseName.map(_.toLowerCase), tableName.toLowerCase) + } else { + (databaseName, tableName) + } + } + + protected def processDatabaseAndTableName( + databaseName: String, + tableName: String): (String, String) = { + if (!caseSensitive) { + (databaseName.toLowerCase, tableName.toLowerCase) + } else { + (databaseName, tableName) + } + } } -class SimpleCatalog extends Catalog { +class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { val tables = new mutable.HashMap[String, LogicalPlan]() override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { - tables += ((tableName, plan)) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + tables += ((tblName, plan)) } override def unregisterTable( databaseName: Option[String], tableName: String) = { - tables -= tableName + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + tables -= tblName } override def unregisterAllTables() = { @@ -61,12 +86,13 @@ class SimpleCatalog extends Catalog { databaseName: Option[String], tableName: String, alias: Option[String] = None): LogicalPlan = { - val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName")) - val tableWithQualifiers = Subquery(tableName, table) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName")) + val tableWithQualifiers = Subquery(tblName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. - alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } } @@ -85,26 +111,28 @@ trait OverrideCatalog extends Catalog { databaseName: Option[String], tableName: String, alias: Option[String] = None): LogicalPlan = { - - val overriddenTable = overrides.get((databaseName, tableName)) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + val overriddenTable = overrides.get((dbName, tblName)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. val withAlias = - overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r)) + overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) - withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias)) + withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias)) } override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { - overrides.put((databaseName, tableName), plan) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + overrides.put((dbName, tblName), plan) } override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { - overrides.remove((databaseName, tableName)) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + overrides.remove((dbName, tblName)) } override def unregisterAllTables(): Unit = { @@ -117,6 +145,9 @@ trait OverrideCatalog extends Catalog { * relations are already filled in and the analyser needs only to resolve attribute references. */ object EmptyCatalog extends Catalog { + + val caseSensitive: Boolean = true + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f14df8137683..0a4fde3de775 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,28 +17,81 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.IntegerType -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ +class AnalysisSuite extends FunSuite with BeforeAndAfter { + val caseSensitiveCatalog = new SimpleCatalog(true) + val caseInsensitiveCatalog = new SimpleCatalog(false) + val caseSensitiveAnalyze = + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) + val caseInsensitiveAnalyze = + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) -class AnalysisSuite extends FunSuite { - val analyze = SimpleAnalyzer + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - val testRelation = LocalRelation('a.int) + before { + caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) + caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation) + } test("analyze project") { assert( - analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + Project(testRelation.output, testRelation)) + + assert( + caseSensitiveAnalyze( + Project(Seq(UnresolvedAttribute("TbL.a")), + UnresolvedRelation(None, "TaBlE", Some("TbL")))) === + Project(testRelation.output, testRelation)) + + val e = intercept[TreeNodeException[_]] { + caseSensitiveAnalyze( + Project(Seq(UnresolvedAttribute("tBl.a")), + UnresolvedRelation(None, "TaBlE", Some("TbL")))) + } + assert(e.getMessage().toLowerCase.contains("unresolved")) + + assert( + caseInsensitiveAnalyze( + Project(Seq(UnresolvedAttribute("TbL.a")), + UnresolvedRelation(None, "TaBlE", Some("TbL")))) === Project(testRelation.output, testRelation)) + + assert( + caseInsensitiveAnalyze( + Project(Seq(UnresolvedAttribute("tBl.a")), + UnresolvedRelation(None, "TaBlE", Some("TbL")))) === + Project(testRelation.output, testRelation)) + } + + test("resolve relations") { + val e = intercept[RuntimeException] { + caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) + } + assert(e.getMessage === "Table Not Found: tAbLe") + + assert( + caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === + testRelation) + + assert( + caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) === + testRelation) + + assert( + caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === + testRelation) } test("throw errors for unresolved attributes during analysis") { val e = intercept[TreeNodeException[_]] { - analyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) + caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) } assert(e.getMessage().toLowerCase.contains("unresolved")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7edb548678c3..4abd89955bd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -57,7 +57,7 @@ class SQLContext(@transient val sparkContext: SparkContext) self => @transient - protected[sql] lazy val catalog: Catalog = new SimpleCatalog + protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7c24b5cabf61..f83068860701 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -45,12 +45,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val client = Hive.get(hive.hiveconf) + val caseSensitive: Boolean = false + def lookupRelation( db: Option[String], tableName: String, alias: Option[String]): LogicalPlan = { - val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase) - val table = client.getTable(databaseName, tableName) + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + val table = client.getTable(databaseName, tblName) val partitions: Seq[Partition] = if (table.isPartitioned) { client.getAllPartitionsForPruner(table).toSeq @@ -60,8 +63,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Since HiveQL is case insensitive for table names we make them all lowercase. MetastoreRelation( - databaseName.toLowerCase, - tableName.toLowerCase, + databaseName, + tblName, alias)(table.getTTable, partitions.map(part => part.getTPartition)) } @@ -70,7 +73,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with tableName: String, schema: Seq[Attribute], allowExisting: Boolean = false): Unit = { - val table = new Table(databaseName, tableName) + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) + val table = new Table(dbName, tblName) val hiveSchema = schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) table.setFields(hiveSchema) @@ -85,7 +89,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat") sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat") val serDeInfo = new SerDeInfo() - serDeInfo.setName(tableName) + serDeInfo.setName(tblName) serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") serDeInfo.setParameters(Map[String, String]()) sd.setSerdeInfo(serDeInfo) @@ -104,13 +108,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with object CreateTables extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case InsertIntoCreatedTable(db, tableName, child) => - val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase) + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - createTable(databaseName, tableName, child.output) + createTable(databaseName, tblName, child.output) InsertIntoTable( EliminateAnalysisOperators( - lookupRelation(Some(databaseName), tableName, None)), + lookupRelation(Some(databaseName), tblName, None)), Map.empty, child, overwrite = false) diff --git a/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a b/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a new file mode 100644 index 000000000000..4d7127c0faab --- /dev/null +++ b/sql/hive/src/test/resources/golden/case sensitivity: Hive table-0-5d14d21a239daa42b086cc895215009a @@ -0,0 +1,14 @@ +0 val_0 +4 val_4 +12 val_12 +8 val_8 +0 val_0 +0 val_0 +10 val_10 +5 val_5 +11 val_11 +5 val_5 +2 val_2 +12 val_12 +5 val_5 +9 val_9 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 9f1cd703103e..a623d29b5397 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -210,6 +210,22 @@ class HiveQuerySuite extends HiveComparisonTest { } } + createQueryTest("case sensitivity: Hive table", + "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15") + + test("case sensitivity: registered table") { + val testData: SchemaRDD = + TestHive.sparkContext.parallelize( + TestData(1, "str1") :: + TestData(2, "str2") :: Nil) + testData.registerAsTable("REGisteredTABle") + + assertResult(Array(Array(2, "str2"))) { + hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + + "WHERE TableAliaS.a > 1").collect() + } + } + def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.size > 1 && explanation.head.startsWith("Physical execution plan")