Catalog.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
* An interface for looking up relations by name. Used by an [[Analyzer]].
*/
trait Catalog {

def caseSensitive: Boolean

def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -35,22 +38,44 @@
def unregisterTable(databaseName: Option[String], tableName: String): Unit

def unregisterAllTables(): Unit

protected def processDatabaseAndTableName(
databaseName: Option[String],
tableName: String): (Option[String], String) = {
if (!caseSensitive) {
(databaseName.map(_.toLowerCase), tableName.toLowerCase)
} else {
(databaseName, tableName)
}
}

protected def processDatabaseAndTableName(
databaseName: String,
tableName: String): (String, String) = {
if (!caseSensitive) {
(databaseName.toLowerCase, tableName.toLowerCase)
} else {
(databaseName, tableName)
}
}
}

class SimpleCatalog extends Catalog {
class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()

override def registerTable(
databaseName: Option[String],
tableName: String,
plan: LogicalPlan): Unit = {
tables += ((tableName, plan))
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
tables += ((tblName, plan))
}

override def unregisterTable(
databaseName: Option[String],
tableName: String) = {
tables -= tableName
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
tables -= tblName
}

override def unregisterAllTables() = {
@@ -61,12 +86,13 @@ class SimpleCatalog extends Catalog {
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {
val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
val tableWithQualifiers = Subquery(tableName, table)
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName"))
val tableWithQualifiers = Subquery(tblName, table)

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers)
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}
}

@@ -85,26 +111,28 @@ trait OverrideCatalog extends Catalog {
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {

val overriddenTable = overrides.get((databaseName, tableName))
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
val overriddenTable = overrides.get((dbName, tblName))

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
val withAlias =
overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r))
overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))

withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias))
withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
}

override def registerTable(
databaseName: Option[String],
tableName: String,
plan: LogicalPlan): Unit = {
overrides.put((databaseName, tableName), plan)
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
overrides.put((dbName, tblName), plan)
}

override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
overrides.remove((databaseName, tableName))
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
overrides.remove((dbName, tblName))
}

override def unregisterAllTables(): Unit = {
@@ -117,6 +145,9 @@ trait OverrideCatalog extends Catalog {
* relations are already filled in and the analyser needs only to resolve attribute references.
*/
object EmptyCatalog extends Catalog {

val caseSensitive: Boolean = true

def lookupRelation(
databaseName: Option[String],
tableName: String,
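To illustrate the new constructor flag end to end, here is a minimal sketch (assuming a Spark SQL build that includes this patch on the classpath; the demo object and its main method are only for the example, not part of the change): a case-insensitive SimpleCatalog resolves a registered table under any casing, while a case-sensitive one only resolves the exact spelling it was given.

import org.apache.spark.sql.catalyst.analysis.SimpleCatalog
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types.IntegerType

object SimpleCatalogCaseSensitivityDemo {
  def main(args: Array[String]): Unit = {
    // Same single-column relation the updated AnalysisSuite registers.
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

    val sensitive = new SimpleCatalog(true)     // what SQLContext constructs below
    val insensitive = new SimpleCatalog(false)  // names normalized to lowercase

    sensitive.registerTable(None, "TaBlE", relation)
    insensitive.registerTable(None, "TaBlE", relation)

    // Succeeds: the case-insensitive catalog stored the plan under "table".
    insensitive.lookupRelation(None, "tAbLe", None)

    // Succeeds only with the registered casing; looking up "tAbLe" here would
    // hit sys.error("Table Not Found: tAbLe").
    sensitive.lookupRelation(None, "TaBlE", None)
  }
}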
AnalysisSuite.scala
@@ -17,28 +17,81 @@

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.FunSuite
import org.scalatest.{BeforeAndAfter, FunSuite}

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types.IntegerType

/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
val caseSensitiveAnalyze =
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
val caseInsensitiveAnalyze =
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)

class AnalysisSuite extends FunSuite {
val analyze = SimpleAnalyzer
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

val testRelation = LocalRelation('a.int)
before {
caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
}

test("analyze project") {
assert(
analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))

assert(
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
Project(testRelation.output, testRelation))

val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL"))))
}
assert(e.getMessage().toLowerCase.contains("unresolved"))

assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
Project(testRelation.output, testRelation))

assert(
caseInsensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
Project(testRelation.output, testRelation))
}

test("resolve relations") {
val e = intercept[RuntimeException] {
caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
}
assert(e.getMessage === "Table Not Found: tAbLe")

assert(
caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
testRelation)

assert(
caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
testRelation)

assert(
caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
testRelation)
}

test("throw errors for unresolved attributes during analysis") {
val e = intercept[TreeNodeException[_]] {
analyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
}
assert(e.getMessage().toLowerCase.contains("unresolved"))
}
SQLContext.scala
@@ -57,7 +57,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
self =>

@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
@transient
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true)
HiveMetastoreCatalog.scala
@@ -45,12 +45,15 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with

val client = Hive.get(hive.hiveconf)

val caseSensitive: Boolean = false

def lookupRelation(
db: Option[String],
tableName: String,
alias: Option[String]): LogicalPlan = {
val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
val table = client.getTable(databaseName, tableName)
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
val table = client.getTable(databaseName, tblName)
val partitions: Seq[Partition] =
if (table.isPartitioned) {
client.getAllPartitionsForPruner(table).toSeq
@@ -60,8 +63,8 @@

// Since HiveQL is case insensitive for table names we make them all lowercase.
MetastoreRelation(
databaseName.toLowerCase,
tableName.toLowerCase,
databaseName,
tblName,
alias)(table.getTTable, partitions.map(part => part.getTPartition))
}

@@ -70,7 +73,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
tableName: String,
schema: Seq[Attribute],
allowExisting: Boolean = false): Unit = {
val table = new Table(databaseName, tableName)
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
val table = new Table(dbName, tblName)
val hiveSchema =
schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
table.setFields(hiveSchema)
@@ -85,7 +89,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
val serDeInfo = new SerDeInfo()
serDeInfo.setName(tableName)
serDeInfo.setName(tblName)
serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
serDeInfo.setParameters(Map[String, String]())
sd.setSerdeInfo(serDeInfo)
@@ -104,13 +108,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
object CreateTables extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case InsertIntoCreatedTable(db, tableName, child) =>
val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase)
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)

createTable(databaseName, tableName, child.output)
createTable(databaseName, tblName, child.output)

InsertIntoTable(
EliminateAnalysisOperators(
lookupRelation(Some(databaseName), tableName, None)),
lookupRelation(Some(databaseName), tblName, None)),
Map.empty,
child,
overwrite = false)
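The explicit .toLowerCase calls removed above are now handled by the shared helper: with caseSensitive = false, processDatabaseAndTableName lowercases the database and table name before they reach the metastore. Below is a dependency-free sketch of just that rule (the object name, standalone method, and sample names are illustrative, not part of the patch):

object NameNormalizationSketch {
  // Mirrors the helper added to the Catalog trait: lowercase both parts of the
  // name when the catalog is case insensitive, otherwise pass them through.
  def processDatabaseAndTableName(
      caseSensitive: Boolean,
      databaseName: Option[String],
      tableName: String): (Option[String], String) = {
    if (!caseSensitive) {
      (databaseName.map(_.toLowerCase), tableName.toLowerCase)
    } else {
      (databaseName, tableName)
    }
  }

  def main(args: Array[String]): Unit = {
    // HiveMetastoreCatalog sets caseSensitive = false, so "DeFault"."SrC"
    // is looked up as "default"."src".
    val normalized = processDatabaseAndTableName(false, Some("DeFault"), "SrC")
    assert(normalized == ((Some("default"), "src")))

    // A case-sensitive catalog (e.g. SQLContext's SimpleCatalog(true)) keeps
    // the names exactly as given.
    val preserved = processDatabaseAndTableName(true, Some("DeFault"), "SrC")
    assert(preserved == ((Some("DeFault"), "SrC")))
  }
}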
New golden output file for the "case sensitivity: Hive table" test added below (the 14 rows of src with key < 15)
@@ -0,0 +1,14 @@
0 val_0
4 val_4
12 val_12
8 val_8
0 val_0
0 val_0
10 val_10
5 val_5
11 val_11
5 val_5
2 val_2
12 val_12
5 val_5
9 val_9
HiveQuerySuite.scala
@@ -210,6 +210,22 @@ class HiveQuerySuite extends HiveComparisonTest {
}
}

createQueryTest("case sensitivity: Hive table",
"SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15")

test("case sensitivity: registered table") {
val testData: SchemaRDD =
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(2, "str2") :: Nil)
testData.registerAsTable("REGisteredTABle")

assertResult(Array(Array(2, "str2"))) {
hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
"WHERE TableAliaS.a > 1").collect()
}
}

def isExplanation(result: SchemaRDD) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
explanation.size > 1 && explanation.head.startsWith("Physical execution plan")