
Commit 12d8006

Handling case sensitivity correctly.

This patch introduces three changes:
1. If a table has an alias, the catalog no longer lowercases the alias; if a lowercase alias is needed, the analyzer does that work.
2. A catalog has a new val caseSensitive that indicates whether it is case sensitive. For example, the SimpleCatalog created by SQLContext is case sensitive, but the HiveMetastoreCatalog is not.
3. Corresponding unit tests.

With this patch, case sensitivity of database names and table names is handled by the catalog; case sensitivity of other identifiers is handled by the analyzer.

1 parent 9d5ecf8 commit 12d8006
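
To make the new contract concrete, here is a minimal sketch of the two SimpleCatalog modes. It is not part of the diff and uses only APIs that appear in the changes below:

import org.apache.spark.sql.catalyst.analysis.SimpleCatalog
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types.IntegerType

val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

// Case-insensitive: database and table names are lowercased on both
// registration and lookup, so any spelling of the name resolves.
val insensitive = new SimpleCatalog(caseSensitive = false)
insensitive.registerTable(None, "TaBlE", relation)
insensitive.lookupRelation(None, "table")   // resolves

// Case-sensitive: names are stored verbatim, so only the exact registered
// spelling resolves; looking up "table" here would fail with
// "Table Not Found: table".
val sensitive = new SimpleCatalog(caseSensitive = true)
sensitive.registerTable(None, "TaBlE", relation)
sensitive.lookupRelation(None, "TaBlE")     // resolves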

6 files changed: +149 -30 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala

Lines changed: 43 additions & 12 deletions
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
  * An interface for looking up relations by name.  Used by an [[Analyzer]].
  */
 trait Catalog {
+
+  def caseSensitive: Boolean
+
   def lookupRelation(
     databaseName: Option[String],
     tableName: String,
@@ -35,22 +38,44 @@ trait Catalog {
   def unregisterTable(databaseName: Option[String], tableName: String): Unit

   def unregisterAllTables(): Unit
+
+  protected def processDatabaseAndTableName(
+      databaseName: Option[String],
+      tableName: String): (Option[String], String) = {
+    if (!caseSensitive) {
+      (databaseName.map(_.toLowerCase), tableName.toLowerCase)
+    } else {
+      (databaseName, tableName)
+    }
+  }
+
+  protected def processDatabaseAndTableName(
+      databaseName: String,
+      tableName: String): (String, String) = {
+    if (!caseSensitive) {
+      (databaseName.toLowerCase, tableName.toLowerCase)
+    } else {
+      (databaseName, tableName)
+    }
+  }
 }

-class SimpleCatalog extends Catalog {
+class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
   val tables = new mutable.HashMap[String, LogicalPlan]()

   override def registerTable(
       databaseName: Option[String],
       tableName: String,
       plan: LogicalPlan): Unit = {
-    tables += ((tableName, plan))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    tables += ((tblName, plan))
   }

   override def unregisterTable(
       databaseName: Option[String],
       tableName: String) = {
-    tables -= tableName
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    tables -= tblName
   }

   override def unregisterAllTables() = {
@@ -61,12 +86,13 @@ class SimpleCatalog extends Catalog {
       databaseName: Option[String],
       tableName: String,
       alias: Option[String] = None): LogicalPlan = {
-    val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
-    val tableWithQualifiers = Subquery(tableName, table)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+    val tableWithQualifiers = Subquery(tblName, table)

     // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
     // properly qualified with this alias.
-    alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+    alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
   }
 }
@@ -85,26 +111,28 @@ trait OverrideCatalog extends Catalog {
       databaseName: Option[String],
       tableName: String,
       alias: Option[String] = None): LogicalPlan = {
-
-    val overriddenTable = overrides.get((databaseName, tableName))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val overriddenTable = overrides.get((dbName, tblName))

     // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
     // properly qualified with this alias.
     val withAlias =
-      overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
+      overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))

-    withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
+    withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
   }

   override def registerTable(
       databaseName: Option[String],
       tableName: String,
       plan: LogicalPlan): Unit = {
-    overrides.put((databaseName, tableName), plan)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    overrides.put((dbName, tblName), plan)
   }

   override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
-    overrides.remove((databaseName, tableName))
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    overrides.remove((dbName, tblName))
   }

   override def unregisterAllTables(): Unit = {
@@ -117,6 +145,9 @@ trait OverrideCatalog extends Catalog {
  * relations are already filled in and the analyser needs only to resolve attribute references.
  */
 object EmptyCatalog extends Catalog {
+
+  val caseSensitive: Boolean = true
+
   def lookupRelation(
     databaseName: Option[String],
     tableName: String,
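
Beyond the built-in catalogs changed above, the two protected processDatabaseAndTableName overloads define the pattern any Catalog implementation is expected to follow: route every incoming name through the helper before touching the backing store. A hypothetical sketch of that pattern (the class name and backing map are illustrative, not part of this patch; it assumes the org.apache.spark.sql.catalyst.analysis package):

import scala.collection.mutable

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}

// Hypothetical custom catalog (illustrative only): every entry point routes
// names through processDatabaseAndTableName, so flipping caseSensitive is
// the only change needed to alter matching behavior.
class MapBackedCatalog(val caseSensitive: Boolean) extends Catalog {
  private val tables = new mutable.HashMap[(Option[String], String), LogicalPlan]

  override def registerTable(
      databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    tables((dbName, tblName)) = plan
  }

  override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    tables -= ((dbName, tblName))
  }

  override def unregisterAllTables(): Unit = tables.clear()

  override def lookupRelation(
      databaseName: Option[String],
      tableName: String,
      alias: Option[String] = None): LogicalPlan = {
    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    val table = tables.getOrElse((dbName, tblName), sys.error(s"Table Not Found: $tableName"))
    // As in SimpleCatalog, the alias is used as-is; lowercasing aliases is
    // now the analyzer's job.
    alias.map(a => Subquery(a, table)).getOrElse(Subquery(tblName, table))
  }
}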

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 61 additions & 8 deletions
@@ -17,28 +17,81 @@

 package org.apache.spark.sql.catalyst.analysis

-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfter, FunSuite}

+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.IntegerType

-/* Implicit conversions */
-import org.apache.spark.sql.catalyst.dsl.expressions._
+class AnalysisSuite extends FunSuite with BeforeAndAfter {
+  val caseSensitiveCatalog = new SimpleCatalog(true)
+  val caseInsensitiveCatalog = new SimpleCatalog(false)
+  val caseSensitiveAnalyze =
+    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
+  val caseInsensitiveAnalyze =
+    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)

-class AnalysisSuite extends FunSuite {
-  val analyze = SimpleAnalyzer
+  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

-  val testRelation = LocalRelation('a.int)
+  before {
+    caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+    caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
+  }

   test("analyze project") {
     assert(
-      analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+      caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+        Project(testRelation.output, testRelation))
+
+    assert(
+      caseSensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("TbL.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+        Project(testRelation.output, testRelation))
+
+    val e = intercept[TreeNodeException[_]] {
+      caseSensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("tBl.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL"))))
+    }
+    assert(e.getMessage().toLowerCase.contains("unresolved"))
+
+    assert(
+      caseInsensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("TbL.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
         Project(testRelation.output, testRelation))
+
+    assert(
+      caseInsensitiveAnalyze(
+        Project(Seq(UnresolvedAttribute("tBl.a")),
+          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
+        Project(testRelation.output, testRelation))
+  }
+
+  test("resolve relations") {
+    val e = intercept[RuntimeException] {
+      caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
+    }
+    assert(e.getMessage === "Table Not Found: tAbLe")
+
+    assert(
+      caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+        testRelation)
+
+    assert(
+      caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
+        testRelation)
+
+    assert(
+      caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
+        testRelation)
   }

   test("throw errors for unresolved attributes during analysis") {
     val e = intercept[TreeNodeException[_]] {
-      analyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
+      caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
     }
     assert(e.getMessage().toLowerCase.contains("unresolved"))
   }

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   self =>

   @transient
-  protected[sql] lazy val catalog: Catalog = new SimpleCatalog
+  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
   @transient
   protected[sql] lazy val analyzer: Analyzer =
     new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true)
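
SQLContext thus pairs a case-sensitive catalog with a case-sensitive analyzer. The practical effect, sketched below under the assumption of an existing SQLContext named sqlContext and a registered SchemaRDD named people (both hypothetical), is that table names must be spelled exactly as registered:

// Hypothetical usage: SimpleCatalog(true) stores names verbatim.
people.registerAsTable("People")
sqlContext.sql("SELECT * FROM People")  // resolves
sqlContext.sql("SELECT * FROM people")  // fails: Table Not Found: people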

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 14 additions & 9 deletions
@@ -45,12 +45,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with

   val client = Hive.get(hive.hiveconf)

+  val caseSensitive: Boolean = false
+
   def lookupRelation(
       db: Option[String],
       tableName: String,
       alias: Option[String]): LogicalPlan = {
-    val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
-    val table = client.getTable(databaseName, tableName)
+    val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
+    val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
+    val table = client.getTable(databaseName, tblName)
     val partitions: Seq[Partition] =
       if (table.isPartitioned) {
         client.getAllPartitionsForPruner(table).toSeq
@@ -60,8 +63,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with

     // Since HiveQL is case insensitive for table names we make them all lowercase.
     MetastoreRelation(
-      databaseName.toLowerCase,
-      tableName.toLowerCase,
+      databaseName,
+      tblName,
       alias)(table.getTTable, partitions.map(part => part.getTPartition))
   }

@@ -70,7 +73,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
       tableName: String,
       schema: Seq[Attribute],
       allowExisting: Boolean = false): Unit = {
-    val table = new Table(databaseName, tableName)
+    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
+    val table = new Table(dbName, tblName)
     val hiveSchema =
       schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
     table.setFields(hiveSchema)
@@ -85,7 +89,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
     sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
     sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
     val serDeInfo = new SerDeInfo()
-    serDeInfo.setName(tableName)
+    serDeInfo.setName(tblName)
     serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
     serDeInfo.setParameters(Map[String, String]())
     sd.setSerdeInfo(serDeInfo)
@@ -104,13 +108,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
   object CreateTables extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case InsertIntoCreatedTable(db, tableName, child) =>
-        val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
+        val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
+        val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)

-        createTable(databaseName, tableName, child.output)
+        createTable(databaseName, tblName, child.output)

         InsertIntoTable(
           EliminateAnalysisOperators(
-            lookupRelation(Some(databaseName), tableName, None)),
+            lookupRelation(Some(databaseName), tblName, None)),
           Map.empty,
           child,
           overwrite = false)
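
Because HiveQL treats database and table names as case insensitive, this catalog fixes caseSensitive to false, so the lowercasing that previously happened at each call site now happens once in processDatabaseAndTableName. A sketch of the effect, assuming a HiveMetastoreCatalog instance named catalog (hypothetical):

// Illustrative only: both calls normalize to ("default", "src") before
// hitting the metastore, preserving HiveQL's case-insensitive naming.
catalog.lookupRelation(Some("DEFAULT"), "SRC", None)
catalog.lookupRelation(Some("default"), "src", None)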
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+0	val_0
+4	val_4
+12	val_12
+8	val_8
+0	val_0
+0	val_0
+10	val_10
+5	val_5
+11	val_11
+5	val_5
+2	val_2
+12	val_12
+5	val_5
+9	val_9
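
These fourteen rows appear to be the golden answer for the new "case sensitivity: Hive table" query test added in HiveQuerySuite below: they match the rows of the src test table with key < 15.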

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 16 additions & 0 deletions
@@ -210,6 +210,22 @@ class HiveQuerySuite extends HiveComparisonTest {
     }
   }

+  createQueryTest("case sensitivity: Hive table",
+    "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15")
+
+  test("case sensitivity: registered table") {
+    val testData: SchemaRDD =
+      TestHive.sparkContext.parallelize(
+        TestData(1, "str1") ::
+        TestData(2, "str2") :: Nil)
+    testData.registerAsTable("REGisteredTABle")
+
+    assertResult(Array(Array(2, "str2"))) {
+      hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
+        "WHERE TableAliaS.a > 1").collect()
+    }
+  }
+
   def isExplanation(result: SchemaRDD) = {
     val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
     explanation.size > 1 && explanation.head.startsWith("Physical execution plan")
