43 changes: 41 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -63,8 +63,10 @@ class SQLContext(@transient val sparkContext: SparkContext)

def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)

// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
/**
* @return Spark SQL configuration
*/
protected[sql] def conf = tlSession.get().conf

/**
* Set Spark SQL configuration properties.
@@ -103,9 +105,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

// TODO how to handle the temp table per user session?
Contributor: This is a good question. Ideally we may want session isolation for temporary tables. However, we can leave this for another PR if you think it makes this PR too complicated, especially since HiveMetastoreCatalog handles both persisted and temporary tables.

Contributor (author): Yeah, we can keep that as a separate PR. But in Spark SQL the temp tables are managed by the Catalog, so we would probably also need to refactor the Catalog code a bit.

@transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)

// TODO how to handle the temp function per user session?
Contributor: Same here. But this one should be simpler; we don't handle persisted UDFs in Spark SQL for now.

Contributor (author): Yeah, same as with the Catalog: we also need to think about how to handle temp functions for the current session in FunctionRegistry.
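To make the discussion concrete, here is a hypothetical sketch (not part of this PR) of how the catalog and function registry could become session-scoped, following the same SQLSession pattern this PR introduces for conf; the type and constructor names simply mirror the ones already visible in this diff.

```scala
// Hypothetical follow-up sketch, not in this PR: per-session catalog and function
// registry, held in SQLSession just like conf, so each session only sees its own
// temporary tables and functions.
protected[sql] class SQLSession {
  protected[sql] lazy val conf: SQLConf = new SQLConf
  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)
  protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
}

// SQLContext would then delegate to the current thread's session instead of
// holding context-wide lazy vals:
protected[sql] def catalog: Catalog = tlSession.get().catalog
protected[sql] def functionRegistry: FunctionRegistry = tlSession.get().functionRegistry
```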

@transient
protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)

@@ -138,6 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext)

protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan)

@transient
protected[sql] val tlSession = new ThreadLocal[SQLSession]() {
override def initialValue = defaultSession
}

@transient
protected[sql] val defaultSession = createSession()

sparkContext.getConf.getAll.foreach {
case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
case _ =>
@@ -194,6 +206,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* }}}
*
* @group basic
* TODO move to SQLSession?
*/
@transient
val udf: UDFRegistration = new UDFRegistration(this)
@@ -1059,6 +1072,32 @@ class SQLContext(@transient val sparkContext: SparkContext)
)
}


protected[sql] def openSession(): SQLSession = {
detachSession()
val session = createSession()
tlSession.set(session)

session
}

protected[sql] def currentSession(): SQLSession = {
tlSession.get()
}

protected[sql] def createSession(): SQLSession = {
new this.SQLSession()
}

protected[sql] def detachSession(): Unit = {
tlSession.remove()
}

protected[sql] class SQLSession {
// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
}
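To show how these hooks are intended to be used, here is a small usage sketch (not code from this PR; the package and object names are made up for illustration): a server binds a fresh session to the handler thread for each connection, applies per-session settings, and detaches the session when the connection closes.

```scala
// Usage sketch only. openSession()/detachSession() are protected[sql], so a real
// caller (e.g. the Thrift server) must live under an org.apache.spark.sql.* package.
package org.apache.spark.sql.example // hypothetical package for illustration

import org.apache.spark.sql.SQLContext

object SessionPerConnection {
  def handleConnection(sqlContext: SQLContext): Unit = {
    sqlContext.openSession() // bind a fresh SQLSession to this handler thread
    try {
      // This setting lands in the thread's session; other connections keep their own values.
      sqlContext.setConf("spark.sql.shuffle.partitions", "291")
      sqlContext.sql("SELECT 1").collect()
    } finally {
      sqlContext.detachSession() // later calls on this thread fall back to the default session
    }
  }
}
```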

/**
* :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -24,16 +24,22 @@ import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/** A SQLContext that can be used for local testing. */
object TestSQLContext
class LocalSQLContext
extends SQLContext(
new SparkContext(
"local[2]",
"TestSQLContext",
new SparkConf().set("spark.sql.testkey", "true"))) {

/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
override protected[sql] def createSession(): SQLSession = {
new this.SQLSession()
}

protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
/** Fewer partitions to speed up testing. */
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
}
}

/**
@@ -45,3 +51,6 @@ object TestSQLContext
}

}

object TestSQLContext extends LocalSQLContext

This file was deleted.

@@ -195,6 +195,146 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
}
}
}

test("test multiple session") {
Contributor: The indentation is off in this test case.

import org.apache.spark.sql.SQLConf
var defaultV1: String = null
var defaultV2: String = null

withMultipleConnectionJdbcStatement(
// create table
{ statement =>

val queries = Seq(
"DROP TABLE IF EXISTS test_map",
"CREATE TABLE test_map(key INT, value STRING)",
s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map",
"CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC")

queries.foreach(statement.execute)

val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
val buf1 = new collection.mutable.ArrayBuffer[Int]()
while (rs1.next()) {
buf1 += rs1.getInt(1)
}
rs1.close()

val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
val buf2 = new collection.mutable.ArrayBuffer[Int]()
while (rs2.next()) {
buf2 += rs2.getInt(1)
}
rs2.close()

assert(buf1 === buf2)
},

// first session, we get the default value of the session status
{ statement =>

val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
rs1.next()
defaultV1 = rs1.getString(1)
assert(defaultV1 != "200")
Contributor: It would be nice to add a comment indicating that the expected value here is "<undefined>". I was quite confused at first, since 200 is the default value of "spark.sql.shuffle.partitions" :)

rs1.close()

val rs2 = statement.executeQuery("SET hive.cli.print.header")
rs2.next()

defaultV2 = rs2.getString(1)
assert(defaultV1 != "true")
Contributor: Should this check defaultV2?

rs2.close()
},

// second session, we update the session status
{ statement =>

val queries = Seq(
s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291",
"SET hive.cli.print.header=true"
)

queries.map(statement.execute)
val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
rs1.next()
assert("spark.sql.shuffle.partitions=291" === rs1.getString(1))
rs1.close()

val rs2 = statement.executeQuery("SET hive.cli.print.header")
rs2.next()
assert("hive.cli.print.header=true" === rs2.getString(1))
rs2.close()
},

// third session, we get the latest session status, supposed to be the
// default value
{ statement =>

val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}")
rs1.next()
assert(defaultV1 === rs1.getString(1))
rs1.close()

val rs2 = statement.executeQuery("SET hive.cli.print.header")
rs2.next()
assert(defaultV2 === rs2.getString(1))
rs2.close()
},

// accessing the cached data in another session
{ statement =>

val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
val buf1 = new collection.mutable.ArrayBuffer[Int]()
while (rs1.next()) {
buf1 += rs1.getInt(1)
}
rs1.close()

val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
val buf2 = new collection.mutable.ArrayBuffer[Int]()
while (rs2.next()) {
buf2 += rs2.getInt(1)
}
rs2.close()

assert(buf1 === buf2)
statement.executeQuery("UNCACHE TABLE test_table")

// TODO need to figure out how to determine if the data loaded from cache
Contributor: We could check the result of EXPLAIN EXTENDED SELECT ... for InMemoryColumnarTableScan.
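A sketch of that check in the style of this suite (hypothetical; it assumes the plan text comes back one line per row and looks for the InMemoryColumnarTableScan operator suggested above):

```scala
// Hypothetical check, not in this PR: while test_table is cached, its plan should
// mention the in-memory scan operator; after UNCACHE TABLE it should not.
val rs = statement.executeQuery("EXPLAIN EXTENDED SELECT key FROM test_table")
val plan = new StringBuilder
while (rs.next()) {
  plan.append(rs.getString(1)).append('\n')
}
rs.close()
assert(plan.toString.contains("InMemoryColumnarTableScan"))
```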

val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
val buf3 = new collection.mutable.ArrayBuffer[Int]()
while (rs3.next()) {
buf3 += rs3.getInt(1)
}
rs3.close()

assert(buf1 === buf3)
},

// accessing the uncached table
{ statement =>

// TODO need to figure out how to determine if the data loaded from cache
val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC")
val buf1 = new collection.mutable.ArrayBuffer[Int]()
while (rs1.next()) {
buf1 += rs1.getInt(1)
}
rs1.close()

val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC")
val buf2 = new collection.mutable.ArrayBuffer[Int]()
while (rs2.next()) {
buf2 += rs2.getInt(1)
}
rs2.close()

assert(buf1 === buf2)
}
)
}
}

class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
@@ -245,15 +385,22 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
s"jdbc:hive2://localhost:$serverPort/"
}

protected def withJdbcStatement(f: Statement => Unit): Unit = {
val connection = DriverManager.getConnection(jdbcUri, user, "")
val statement = connection.createStatement()

try f(statement) finally {
statement.close()
connection.close()
def withMultipleConnectionJdbcStatement(fs: (Statement => Unit)*) {
val user = System.getProperty("user.name")
val connections = fs.map { _ => DriverManager.getConnection(jdbcUri, user, "") }
val statements = connections.map(_.createStatement())

try {
statements.zip(fs).map { case (s, f) => f(s) }
} finally {
statements.map(_.close())
connections.map(_.close())
}
}

def withJdbcStatement(f: Statement => Unit) {
withMultipleConnectionJdbcStatement(f)
}
}

abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {