diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 9e6e2912e0622..94ca072d3d606 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -77,7 +77,8 @@ trait Catalog { } class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { - val tables = new mutable.HashMap[String, LogicalPlan]() + import scala.collection.mutable.SynchronizedMap + val tables = new mutable.HashMap[String, LogicalPlan]() with SynchronizedMap[String, LogicalPlan] override def registerTable( tableIdentifier: Seq[String], @@ -134,9 +135,11 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { * lost when the JVM exits. */ trait OverrideCatalog extends Catalog { + import scala.collection.mutable.SynchronizedMap // TODO: This doesn't work when the database changes... val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + with SynchronizedMap[(Option[String],String), LogicalPlan] abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) @@ -235,3 +238,5 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } } + +object SimpleCaseSensitiveCatalog extends SimpleCatalog(true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9f334f6d42ad1..5d75a413d3305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -93,3 +93,5 @@ class StringKeyHashMap[T](normalizer: (String) => String) { def iterator: Iterator[(String, T)] = base.toIterator } +object SimpleCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(true) +object SimpleInCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index ca4a127120b37..2714a5a06673f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -34,24 +34,30 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR * InMemoryRelation. This relation is automatically substituted query plans that return the * `sameResult` as the originally cached query. * + * TODO Cached Data (Global wide V.S. Catalog Instance wide) * Internal to Spark SQL. */ -private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { - +private[sql] object CacheManager extends Logging { @transient - private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + private[this] val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @transient - private val cacheLock = new ReentrantReadWriteLock + private[this] val cacheLock = new ReentrantReadWriteLock /** Returns true if the table is currently cached in-memory. 
*/ - def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty + private[sql] def isCached(sqlContext: SQLContext, tableName: String): Boolean = { + lookupCachedData(sqlContext.table(tableName)).nonEmpty + } /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) + private[sql] def cacheTable(sqlContext: SQLContext, tableName: String): Unit = { + cacheQuery(sqlContext.table(tableName), Some(tableName)) + } /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) + private[sql] def uncacheTable(sqlContext: SQLContext, tableName: String): Unit = { + uncacheQuery(sqlContext.table(tableName)) + } /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { @@ -99,8 +105,8 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { CachedData( planToCache, InMemoryRelation( - sqlContext.conf.useCompression, - sqlContext.conf.columnBatchSize, + query.sqlContext.conf.useCompression, + query.sqlContext.conf.columnBatchSize, storageLevel, query.queryExecution.executedPlan, tableName)) @@ -162,3 +168,4 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { } } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 810f7c77477bb..8bbc89792028f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -819,7 +819,7 @@ class DataFrame protected[sql]( * @group basic */ override def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) + CacheManager.cacheQuery(this) this } @@ -827,7 +827,7 @@ class DataFrame protected[sql]( * @group basic */ override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) + CacheManager.cacheQuery(this, None, newLevel) this } @@ -835,7 +835,7 @@ class DataFrame protected[sql]( * @group basic */ override def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) + CacheManager.tryUncacheQuery(this, blocking) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bdaa023914b8..9cad25926dda5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -104,10 +104,10 @@ class SQLContext(@transient val sparkContext: SparkContext) def getAllConfs: immutable.Map[String, String] = conf.getAllConfs @transient - protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + protected[sql] lazy val catalog: Catalog = SimpleCaseSensitiveCatalog @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) + protected[sql] lazy val functionRegistry: FunctionRegistry = SimpleCaseSentiveFunctionRegistry @transient protected[sql] lazy val analyzer: Analyzer = @@ -144,9 +144,6 @@ class SQLContext(@transient val sparkContext: SparkContext) case _ => } - @transient - protected[sql] val cacheManager = new CacheManager(this) - /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into @@ -203,24 +200,24 @@ class 
SQLContext(@transient val sparkContext: SparkContext) * Returns true if the table is currently cached in-memory. * @group cachemgmt */ - def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + def isCached(tableName: String): Boolean = CacheManager.isCached(this, tableName) /** * Caches the specified table in-memory. * @group cachemgmt */ - def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + def cacheTable(tableName: String): Unit = CacheManager.cacheTable(this, tableName) /** * Removes the specified table from the in-memory cache. * @group cachemgmt */ - def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + def uncacheTable(tableName: String): Unit = CacheManager.uncacheTable(this, tableName) /** * Removes all cached tables from the in-memory cache. */ - def clearCache(): Unit = cacheManager.clearCache() + def clearCache(): Unit = CacheManager.clearCache() // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i @@ -905,7 +902,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group basic */ def dropTempTable(tableName: String): Unit = { - cacheManager.tryUncacheQuery(table(tableName)) + CacheManager.tryUncacheQuery(table(tableName)) catalog.unregisterTable(Seq(tableName)) } @@ -1066,7 +1063,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] class QueryExecution(val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = analyzer(logical) - lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val withCachedData: LogicalPlan = CacheManager.useCachedData(analyzed) lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index c9cd0e6e93829..2dd924d84f8f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.sources -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{CacheManager, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand} +import org.apache.spark.sql.execution.RunnableCommand private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -32,7 +32,7 @@ private[sql] case class InsertIntoDataSource( relation.insert(DataFrame(sqlContext, query), overwrite) // Invalidate the cache. 
- sqlContext.cacheManager.invalidateCache(logicalRelation) + CacheManager.invalidateCache(logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c240f2be955ca..0146a0c2a0c1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -56,17 +56,17 @@ class CachedTableSuite extends QueryTest { } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == CacheManager.lookupCachedData(testData)) + testData.unpersist(true) + assert(None == CacheManager.lookupCachedData(testData)) + testData.unpersist(false) + assert(None == CacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) - testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) - testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None != CacheManager.lookupCachedData(testData)) + testData.unpersist(true) + assert(None == CacheManager.lookupCachedData(testData)) + testData.unpersist(false) + assert(None == CacheManager.lookupCachedData(testData)) } test("cache table as select") { @@ -287,13 +287,13 @@ class CachedTableSuite extends QueryTest { cacheTable("t1") cacheTable("t2") clearCache() - assert(cacheManager.isEmpty) + assert(CacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") cacheTable("t1") cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(CacheManager.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dd0948ad824be..8ca171e993e26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + CacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -94,7 +94,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted hash join operator selection") { - cacheManager.clearCache() + CacheManager.clearCache() sql("CACHE TABLE testData") Seq( @@ -385,7 +385,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() + CacheManager.clearCache() sql("CACHE TABLE testData") val tmp = conf.autoBroadcastJoinThreshold diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6e07df18b0e15..5093c7a8d6086 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -39,10 +39,11 @@ object HiveThriftServer2 extends Logging { /** * :: DeveloperApi :: * Starts a new thrift server with the given context. + * TODO probably a SparkContext, and HiveConf as parameter would be better */ @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { - val server = new HiveThriftServer2(sqlContext) + val server = new HiveThriftServer2(sqlContext.sparkContext) server.init(sqlContext.hiveconf) server.start() sqlContext.sparkContext.addSparkListener(new HiveThriftServer2Listener(server)) @@ -66,7 +67,7 @@ object HiveThriftServer2 extends Logging { ) try { - val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + val server = new HiveThriftServer2(SparkSQLEnv.sparkContext) server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") @@ -89,12 +90,12 @@ object HiveThriftServer2 extends Logging { } -private[hive] class HiveThriftServer2(hiveContext: HiveContext) +private[hive] class HiveThriftServer2(sc: SparkContext) extends HiveServer2 with ReflectedCompositeService { override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(sc) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294a..4f0e7505dc0fa 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,6 +21,8 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import org.apache.spark.SparkContext + import scala.collection.JavaConversions._ import org.apache.commons.logging.Log @@ -36,14 +38,14 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.util.Utils -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) +private[hive] class SparkSQLCLIService(sc: SparkContext) extends CLIService with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(sc) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null @@ -66,7 +68,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) getInfoType match { case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") - case 
GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version) + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(sc.version) case _ => super.getInfo(sessionHandle, getInfoType) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 89e9ede7261c9..25667f5cc7112 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -23,17 +23,22 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.session.SessionManager +import org.apache.spark.SparkContext import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager import org.apache.hive.service.cli.SessionHandle -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) +private[hive] class SparkSQLSessionManager(private val sc: SparkContext) extends SessionManager with ReflectedCompositeService { - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + private val hiveContexts = new ThreadLocal[HiveContext]() { + override protected def initialValue = new HiveContext(sc) + } + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContexts) override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) @@ -52,5 +57,6 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) override def closeSession(sessionHandle: SessionHandle) { super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle + hiveContexts.remove() } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9c0bf02391e0e..8bfbe53d971a9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
*/ -private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) +private[thriftserver] class SparkSQLOperationManager(hiveContexts: ThreadLocal[HiveContext]) extends OperationManager with Logging { val handleToOperation = ReflectionUtils @@ -45,7 +45,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) async: Boolean): ExecuteStatementOperation = synchronized { val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( - hiveContext, sessionToActivePool) + hiveContexts.get(), sessionToActivePool) handleToOperation.put(operation.getHandle, operation) operation } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index b52a51d11e4ad..84d3dc77e2b0e 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -70,10 +70,10 @@ class HiveThriftServer2Suite extends FunSuite with Logging { port } - def withJdbcStatement( + def withMultipleConnectionJdbcStatement( serverStartTimeout: FiniteDuration = 1.minute, httpMode: Boolean = false)( - f: Statement => Unit) { + fs: Seq[Statement => Unit]) { val port = randomListeningPort startThriftServer(port, serverStartTimeout, httpMode) { @@ -85,18 +85,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } val user = System.getProperty("user.name") - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() + val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") } + val statements = connections.map(_.createStatement()) try { - f(statement) + statements.zip(fs).map { case (s, f) => f(s) } } finally { - statement.close() - connection.close() + statements.map(_.close()) + connections.map(_.close()) } } } + def withJdbcStatement( + serverStartTimeout: FiniteDuration = 1.minute, + httpMode: Boolean = false)( + f: Statement => Unit) { + withMultipleConnectionJdbcStatement(serverStartTimeout, httpMode)(Seq(f)) + } + def withCLIServiceClient( serverStartTimeout: FiniteDuration = 1.minute)( f: ThriftCLIServiceClient => Unit) { @@ -384,4 +391,137 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } } } + + test("test multiple session") { + import org.apache.spark.sql.SQLConf + + var defaultV1: String = null + var defaultV2: String = null + + withMultipleConnectionJdbcStatement() (Seq( + // create table + { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + + queries.foreach(statement.execute) + + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + }, + + // first session, we get the default value of the session status + { statement => + val rs1 = 
statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + defaultV1 = rs1.getString(1) + assert(defaultV1 != "200") + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + defaultV2 = rs2.getString(1) + assert(defaultV1 != "true") + rs2.close() + }, + + // second session, we update the session status + { statement => + val queries = Seq( + s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291", + "SET hive.cli.print.header=true" + ) + + queries.map(statement.execute) + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + assert("spark.sql.shuffle.partitions=291" === rs1.getString(1)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert("hive.cli.print.header=true" === rs2.getString(1)) + rs2.close() + }, + + // third session, we get the latest session status, supposed to be the + // default value + { statement => + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + rs1.next() + assert(defaultV1 === rs1.getString(1)) + rs1.close() + + val rs2 = statement.executeQuery("SET hive.cli.print.header") + rs2.next() + assert(defaultV2 === rs2.getString(1)) + rs2.close() + }, + + // accessing the cached data in another session + { statement => + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + statement.executeQuery("UNCACHE TABLE test_table") + + // TODO need to figure out how to determine if the data loaded from cache + val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf3 = new collection.mutable.ArrayBuffer[Int]() + while (rs3.next()) { + buf3 += rs3.getInt(1) + } + rs3.close() + + assert(buf1 === buf3) + }, + + // accessing the uncached table + { statement => + // TODO need to figure out how to determine if the data loaded from cache + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") + val buf1 = new collection.mutable.ArrayBuffer[Int]() + while (rs1.next()) { + buf1 += rs1.getInt(1) + } + rs1.close() + + val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf2 = new collection.mutable.ArrayBuffer[Int]() + while (rs2.next()) { + buf2 += rs2.getInt(1) + } + rs2.close() + + assert(buf1 === buf2) + })) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e205e67c0fdd..3ce9fb0c0d6c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -249,14 +249,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog + override protected[sql] lazy val catalog: HiveMetastoreCatalog = { + HiveMetastore.initializeOrGet(this) + } // Note that HiveUDFs will be overridden by functions registered in this context. 
@transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - def caseSensitive = false - } + override protected[sql] lazy val functionRegistry = HiveCaseInsensitiveFunctionRegistry /* An analyzer that uses the Hive metastore. */ @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f7ad2efc9544e..1c7322f97f280 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -759,3 +759,19 @@ object HiveMetastoreTypes { case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) } } + +private[hive] object HiveMetastore { + private[this] var catalog: HiveMetastoreCatalog = _ + + /** + * Reuses the catalog instance, creating it only if it does not exist yet. + * The catalog cannot be changed once it has been created. + */ + def initializeOrGet(context: HiveContext): HiveMetastoreCatalog = synchronized { + if (catalog == null) { + catalog = new HiveMetastoreCatalog(context) with OverrideCatalog + } + + catalog + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 91af35f0965c0..908eb83c44dd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.CacheManager import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive._ @@ -239,7 +240,7 @@ case class InsertIntoHiveTable( } // Invalidate the cache.
- sqlContext.cacheManager.invalidateCache(table) + CacheManager.invalidateCache(table) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index c88d0e6b79491..865a8594bf081 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} +import org.apache.spark.sql.{SaveMode, CacheManager, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand @@ -59,7 +59,7 @@ case class DropTable( val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { - hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) + CacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { // This table's metadata is not in case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 34c21c11761ae..13ca29b1932a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.analysis.MultiAlias +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, MultiAlias} import org.apache.spark.sql.catalyst.errors.TreeNodeException /* Implicit conversions */ @@ -73,6 +73,12 @@ private[hive] abstract class HiveFunctionRegistry } } +private[hive] object HiveCaseInsensitiveFunctionRegistry + extends HiveFunctionRegistry + with OverrideFunctionRegistry { + def caseSensitive = false +} + private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type EvaluatedType = Any diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a2d99f1f4b28d..81077b9058223 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{CacheManager, SQLConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan 
import org.apache.spark.sql.catalyst.util._ @@ -397,7 +397,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - cacheManager.clearCache() + CacheManager.clearCache() loadedTables.clear() catalog.cachedDataSourceTables.invalidateAll() catalog.client.getAllTables("default").foreach { t =>
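
The patch boils down to two moves: state that must be visible to every JDBC connection (the table catalog, the function registry, and the cached-plan registry in `CacheManager`) becomes a shared, synchronized singleton, while state that must stay per-connection (a `HiveContext` and its `SQLConf`) is created per session by `SparkSQLSessionManager`. The sketch below is a minimal, self-contained model of that split, not Spark code: `SharedCatalog`, `SessionContext`, and `MultiSessionDemo` are illustrative names, and a `ConcurrentHashMap`/`TrieMap` stands in for the `HashMap with SynchronizedMap` mixin used in `Catalog.scala`.

```scala
// Minimal sketch of the shared-catalog / per-session-conf split this patch applies.
// None of these types exist in Spark; they only model the pattern.
import java.util.concurrent.ConcurrentHashMap
import scala.collection.concurrent.TrieMap

// Shared across every session, playing the role of the singleton
// SimpleCaseSensitiveCatalog / CacheManager objects introduced by the patch.
object SharedCatalog {
  private val tables = new ConcurrentHashMap[String, String]() // table name -> "plan"

  def registerTable(name: String, plan: String): Unit = tables.put(name, plan)
  def lookupTable(name: String): Option[String] = Option(tables.get(name))
}

// Per-session state, playing the role of the per-connection HiveContext created in
// SparkSQLSessionManager: each session owns its conf but delegates to the shared catalog.
class SessionContext {
  private val conf = TrieMap[String, String]("spark.sql.shuffle.partitions" -> "200")

  def setConf(key: String, value: String): Unit = conf(key) = value
  def getConf(key: String): Option[String] = conf.get(key)
  def table(name: String): Option[String] = SharedCatalog.lookupTable(name)
}

object MultiSessionDemo extends App {
  val session1 = new SessionContext
  val session2 = new SessionContext

  // A SET issued through one session is not visible in another...
  session1.setConf("spark.sql.shuffle.partitions", "291")
  assert(session2.getConf("spark.sql.shuffle.partitions").contains("200"))

  // ...but a table registered through one session is visible to all of them.
  SharedCatalog.registerTable("test_table", "SELECT key FROM test_map")
  assert(session1.table("test_table") == session2.table("test_table"))
  println("per-session conf isolated; shared catalog visible to every session")
}
```

Under this split, the behaviour exercised by the new "test multiple session" suite falls out naturally: `SET spark.sql.shuffle.partitions=291` in one connection does not leak into the others, while a table cached through one connection can be queried from all of them.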