@@ -77,7 +77,8 @@ trait Catalog {
}

class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()
import scala.collection.mutable.SynchronizedMap
val tables = new mutable.HashMap[String, LogicalPlan]() with SynchronizedMap[String, LogicalPlan]

override def registerTable(
tableIdentifier: Seq[String],
@@ -134,9 +135,11 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
* lost when the JVM exits.
*/
trait OverrideCatalog extends Catalog {
import scala.collection.mutable.SynchronizedMap

// TODO: This doesn't work when the database changes...
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
with SynchronizedMap[(Option[String],String), LogicalPlan]

abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
val tableIdent = processTableIdentifier(tableIdentifier)
@@ -235,3 +238,5 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}
}

object SimpleCaseSensitiveCatalog extends SimpleCatalog(true)
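
For reference, the thread-safety fix in this file relies on Scala's SynchronizedMap mixin, which wraps every map operation in a synchronized block. A minimal, self-contained sketch of the pattern (Plan is a hypothetical stand-in for LogicalPlan; note that SynchronizedMap is deprecated in later Scala versions in favour of java.util.concurrent maps):

import scala.collection.mutable
import scala.collection.mutable.SynchronizedMap

// Hypothetical stand-in for LogicalPlan, just to keep the sketch self-contained.
case class Plan(sql: String)

// Mixing in SynchronizedMap makes each get/put/remove synchronized, so concurrent
// registerTable/lookupRelation calls from different threads cannot corrupt the map.
val tables = new mutable.HashMap[String, Plan]() with SynchronizedMap[String, Plan]

tables("t1") = Plan("SELECT * FROM src")   // synchronized put
println(tables.get("t1"))                  // synchronized get => Some(Plan(...))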
@@ -93,3 +93,5 @@ class StringKeyHashMap[T](normalizer: (String) => String) {
def iterator: Iterator[(String, T)] = base.toIterator
}

object SimpleCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(true)
object SimpleInCaseSentiveFunctionRegistry extends SimpleFunctionRegistry(false)
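
As context for the two registry objects above: SimpleFunctionRegistry is built on StringKeyHashMap, which normalizes every key before touching the underlying map, so case-insensitivity is just a matter of passing a lowercasing normalizer. A rough sketch of that idea (assumed shape; only the iterator method appears in this diff):

import scala.collection.mutable

// Assumed shape of StringKeyHashMap: every key goes through `normalizer` first.
class StringKeyHashMapSketch[T](normalizer: String => String) {
  private val base = mutable.HashMap.empty[String, T]
  def get(key: String): Option[T] = base.get(normalizer(key))
  def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
  def iterator: Iterator[(String, T)] = base.toIterator
}

// A case-sensitive registry keeps keys as-is; a case-insensitive one lowercases them.
val caseInsensitive = new StringKeyHashMapSketch[Int](_.toLowerCase)
caseInsensitive.put("SUM", 1)
assert(caseInsensitive.get("sum") == Some(1))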
25 changes: 16 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -34,24 +34,30 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR
* InMemoryRelation. This relation is automatically substituted into query plans that return the
* `sameResult` as the originally cached query.
*
* TODO: Cached data scope (JVM-global vs. per-catalog-instance)
* Internal to Spark SQL.
*/
private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {

private[sql] object CacheManager extends Logging {
@transient
private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
private[this] val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]

@transient
private val cacheLock = new ReentrantReadWriteLock
private[this] val cacheLock = new ReentrantReadWriteLock

/** Returns true if the table is currently cached in-memory. */
def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty
private[sql] def isCached(sqlContext: SQLContext, tableName: String): Boolean = {
lookupCachedData(sqlContext.table(tableName)).nonEmpty
}

/** Caches the specified table in-memory. */
def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName))
private[sql] def cacheTable(sqlContext: SQLContext, tableName: String): Unit = {
cacheQuery(sqlContext.table(tableName), Some(tableName))
}

/** Removes the specified table from the in-memory cache. */
def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName))
private[sql] def uncacheTable(sqlContext: SQLContext, tableName: String): Unit = {
uncacheQuery(sqlContext.table(tableName))
}

/** Acquires a read lock on the cache for the duration of `f`. */
private def readLock[A](f: => A): A = {
@@ -99,8 +105,8 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
CachedData(
planToCache,
InMemoryRelation(
sqlContext.conf.useCompression,
sqlContext.conf.columnBatchSize,
query.sqlContext.conf.useCompression,
query.sqlContext.conf.columnBatchSize,
storageLevel,
query.queryExecution.executedPlan,
tableName))
@@ -162,3 +168,4 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
}
}
}
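
The structural change in this file is worth spelling out: CacheManager goes from a per-SQLContext class to a process-wide object, which is why isCached/cacheTable/uncacheTable now take the SQLContext as an explicit parameter. A stripped-down sketch of the resulting shape (method bodies simplified; the real object guards cachedData with the read/write lock shown above):

import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable.ArrayBuffer

// Process-wide cache registry: one instance per JVM, shared by every SQLContext.
object GlobalCacheSketch {
  private val cachedData = ArrayBuffer.empty[String]   // stand-in for CachedData entries
  private val cacheLock  = new ReentrantReadWriteLock

  private def writeLock[A](f: => A): A = {
    val lock = cacheLock.writeLock(); lock.lock()
    try f finally lock.unlock()
  }
  private def readLock[A](f: => A): A = {
    val lock = cacheLock.readLock(); lock.lock()
    try f finally lock.unlock()
  }

  // Callers now pass in whatever context-specific state they need (here just a name).
  def cacheTable(tableName: String): Unit  = writeLock { cachedData += tableName }
  def isCached(tableName: String): Boolean = readLock { cachedData.contains(tableName) }
  def clearCache(): Unit                   = writeLock { cachedData.clear() }
}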

6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -819,23 +819,23 @@ class DataFrame protected[sql](
* @group basic
*/
override def persist(): this.type = {
sqlContext.cacheManager.cacheQuery(this)
CacheManager.cacheQuery(this)
this
}

/**
* @group basic
*/
override def persist(newLevel: StorageLevel): this.type = {
sqlContext.cacheManager.cacheQuery(this, None, newLevel)
CacheManager.cacheQuery(this, None, newLevel)
this
}

/**
* @group basic
*/
override def unpersist(blocking: Boolean): this.type = {
sqlContext.cacheManager.tryUncacheQuery(this, blocking)
CacheManager.tryUncacheQuery(this, blocking)
this
}

19 changes: 8 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -104,10 +104,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
protected[sql] lazy val catalog: Catalog = SimpleCaseSensitiveCatalog

@transient
protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
protected[sql] lazy val functionRegistry: FunctionRegistry = SimpleCaseSentiveFunctionRegistry

@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -144,9 +144,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
case _ =>
}

@transient
protected[sql] val cacheManager = new CacheManager(this)

/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -203,24 +200,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
*/
def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
def isCached(tableName: String): Boolean = CacheManager.isCached(this, tableName)

/**
* Caches the specified table in-memory.
* @group cachemgmt
*/
def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
def cacheTable(tableName: String): Unit = CacheManager.cacheTable(this, tableName)

/**
* Removes the specified table from the in-memory cache.
* @group cachemgmt
*/
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
def uncacheTable(tableName: String): Unit = CacheManager.uncacheTable(this, tableName)

/**
* Removes all cached tables from the in-memory cache.
*/
def clearCache(): Unit = cacheManager.clearCache()
def clearCache(): Unit = CacheManager.clearCache()

// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
@@ -905,7 +902,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group basic
*/
def dropTempTable(tableName: String): Unit = {
cacheManager.tryUncacheQuery(table(tableName))
CacheManager.tryUncacheQuery(table(tableName))
catalog.unregisterTable(Seq(tableName))
}

@@ -1066,7 +1063,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] class QueryExecution(val logical: LogicalPlan) {

lazy val analyzed: LogicalPlan = analyzer(logical)
lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
lazy val withCachedData: LogicalPlan = CacheManager.useCachedData(analyzed)
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)

// TODO: Don't just pick the first one...
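
Taken together, the singleton catalog, function registry, and CacheManager mean that every SQLContext in a JVM now shares temp tables and cache state. A hedged usage sketch of that behaviour, assuming the Spark 1.3-era DataFrame API used elsewhere in this PR:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc   = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shared-state-demo"))
val ctx1 = new SQLContext(sc)
val ctx2 = new SQLContext(sc)
import ctx1.implicits._

// Register a temp table through one context...
sc.parallelize(1 to 10).map(i => (i, s"name$i")).toDF("id", "name").registerTempTable("people")

// ...and read and cache it through the other: catalog and cache are now JVM-wide.
ctx2.cacheTable("people")
ctx1.table("people").show()
assert(ctx1.isCached("people"))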
@@ -16,10 +16,10 @@
*/
package org.apache.spark.sql.sources

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.{CacheManager, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand}
import org.apache.spark.sql.execution.RunnableCommand

private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -32,7 +32,7 @@ private[sql] case class InsertIntoDataSource(
relation.insert(DataFrame(sqlContext, query), overwrite)

// Invalidate the cache.
sqlContext.cacheManager.invalidateCache(logicalRelation)
CacheManager.invalidateCache(logicalRelation)

Seq.empty[Row]
}
24 changes: 12 additions & 12 deletions sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -56,17 +56,17 @@ class CachedTableSuite extends QueryTest {
}

test("unpersist an uncached table will not raise exception") {
assert(None == cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
assert(None == CacheManager.lookupCachedData(testData))
testData.unpersist(true)
assert(None == CacheManager.lookupCachedData(testData))
testData.unpersist(false)
assert(None == CacheManager.lookupCachedData(testData))
testData.persist()
assert(None != cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
assert(None != CacheManager.lookupCachedData(testData))
testData.unpersist(true)
assert(None == CacheManager.lookupCachedData(testData))
testData.unpersist(false)
assert(None == CacheManager.lookupCachedData(testData))
}

test("cache table as select") {
@@ -287,13 +287,13 @@ class CachedTableSuite extends QueryTest {
cacheTable("t1")
cacheTable("t2")
clearCache()
assert(cacheManager.isEmpty)
assert(CacheManager.isEmpty)

sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
cacheTable("t1")
cacheTable("t2")
sql("Clear CACHE")
assert(cacheManager.isEmpty)
assert(CacheManager.isEmpty)
}
}
6 changes: 3 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -60,7 +60,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}

test("join operator selection") {
cacheManager.clearCache()
CacheManager.clearCache()

Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
@@ -94,7 +94,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}

test("broadcasted hash join operator selection") {
cacheManager.clearCache()
CacheManager.clearCache()
sql("CACHE TABLE testData")

Seq(
@@ -385,7 +385,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}

test("broadcasted left semi join operator selection") {
cacheManager.clearCache()
CacheManager.clearCache()
sql("CACHE TABLE testData")
val tmp = conf.autoBroadcastJoinThreshold

@@ -23,7 +23,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}

import org.apache.spark.Logging
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
@@ -39,10 +39,11 @@ object HiveThriftServer2 extends Logging {
/**
* :: DeveloperApi ::
* Starts a new thrift server with the given context.
* TODO: taking a SparkContext and a HiveConf as parameters would probably be better
*/
@DeveloperApi
def startWithContext(sqlContext: HiveContext): Unit = {
val server = new HiveThriftServer2(sqlContext)
val server = new HiveThriftServer2(sqlContext.sparkContext)
server.init(sqlContext.hiveconf)
server.start()
sqlContext.sparkContext.addSparkListener(new HiveThriftServer2Listener(server))
@@ -66,7 +67,7 @@ object HiveThriftServer2 extends Logging {
)

try {
val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
val server = new HiveThriftServer2(SparkSQLEnv.sparkContext)
server.init(SparkSQLEnv.hiveContext.hiveconf)
server.start()
logInfo("HiveThriftServer2 started")
@@ -89,12 +90,12 @@

}

private[hive] class HiveThriftServer2(hiveContext: HiveContext)
private[hive] class HiveThriftServer2(sc: SparkContext)
extends HiveServer2
with ReflectedCompositeService {

override def init(hiveConf: HiveConf) {
val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
val sparkSqlCliService = new SparkSQLCLIService(sc)
setSuperField(this, "cliService", sparkSqlCliService)
addService(sparkSqlCliService)

@@ -21,6 +21,8 @@ import java.io.IOException
import java.util.{List => JList}
import javax.security.auth.login.LoginException

import org.apache.spark.SparkContext

import scala.collection.JavaConversions._

import org.apache.commons.logging.Log
@@ -36,14 +38,14 @@ import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.util.Utils

private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
private[hive] class SparkSQLCLIService(sc: SparkContext)
extends CLIService
with ReflectedCompositeService {

override def init(hiveConf: HiveConf) {
setSuperField(this, "hiveConf", hiveConf)

val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
val sparkSqlSessionManager = new SparkSQLSessionManager(sc)
setSuperField(this, "sessionManager", sparkSqlSessionManager)
addService(sparkSqlSessionManager)
var sparkServiceUGI: UserGroupInformation = null
@@ -66,7 +68,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
getInfoType match {
case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version)
case GetInfoType.CLI_DBMS_VER => new GetInfoValue(sc.version)
case _ => super.getInfo(sessionHandle, getInfoType)
}
}
@@ -23,17 +23,22 @@ import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.session.SessionManager
import org.apache.spark.SparkContext

import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
import org.apache.hive.service.cli.SessionHandle

private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
private[hive] class SparkSQLSessionManager(private val sc: SparkContext)
extends SessionManager
with ReflectedCompositeService {

private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
private val hiveContexts = new ThreadLocal[HiveContext]() {
override protected def initialValue = new HiveContext(sc)
}

private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContexts)

override def init(hiveConf: HiveConf) {
setSuperField(this, "hiveConf", hiveConf)
@@ -52,5 +57,6 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
override def closeSession(sessionHandle: SessionHandle) {
super.closeSession(sessionHandle)
sparkSqlOperationManager.sessionToActivePool -= sessionHandle
hiveContexts.remove()
}
}
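
The ThreadLocal-with-initialValue pattern introduced above is the key idea in the thrift-server changes: each server thread lazily gets its own HiveContext, and remove() releases it when the session closes. A minimal generic sketch (Ctx is a hypothetical stand-in for HiveContext):

// Hypothetical stand-in for a per-session context such as HiveContext.
class Ctx {
  def close(): Unit = ()   // placeholder for real cleanup
}

class PerThreadContexts {
  private val contexts = new ThreadLocal[Ctx]() {
    // Called at most once per thread, on the first get() from that thread.
    override protected def initialValue(): Ctx = new Ctx
  }

  def current(): Ctx = contexts.get()   // lazily created for the calling thread

  def release(): Unit = {               // call when the session ends
    current().close()
    contexts.remove()                   // next get() on this thread creates a fresh Ctx
  }
}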