diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 2136a2ea63543..41fa9168d81f9 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -39,8 +39,8 @@ private[hive] object SparkSQLEnv extends Logging {
       sparkContext.addSparkListener(new StatsReportListener())
 
       hiveContext = new HiveContext(sparkContext) {
-        @transient override lazy val sessionState = {
-          val state = SessionState.get()
+        @transient lazy val sessionState = {
+          val state = getSessionState()
           setConf(state.getConf.getAllProperties)
           state
         }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index fad4091d48a89..005b8a0649b83 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -230,13 +230,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
    * in the HiveConf.
    */
   @transient lazy val hiveconf = new HiveConf(classOf[SessionState])
-  @transient protected[hive] lazy val sessionState = {
-    val ss = new SessionState(hiveconf)
-    setConf(hiveconf.getAllProperties)  // Have SQLConf pick up the initial set of HiveConf.
-    SessionState.start(ss)
-    ss.err = new PrintStream(outputBuffer, true, "UTF-8")
-    ss.out = new PrintStream(outputBuffer, true, "UTF-8")
+  /**
+   * If no SessionState is bound to the current thread, start a new one;
+   * SessionState.start puts it into the thread local.
+   * @return the SessionState bound to the current thread
+   */
+  def getSessionState() = {
+    var ss = SessionState.get
+    if (ss == null) {
+      ss = new SessionState(hiveconf)
+      setConf(hiveconf.getAllProperties)  // Have SQLConf pick up the initial set of HiveConf.
+      SessionState.start(ss)
+      ss.err = new PrintStream(outputBuffer, true, "UTF-8")
+      ss.out = new PrintStream(outputBuffer, true, "UTF-8")
+    }
     ss
   }
 
@@ -283,6 +291,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
    */
   protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = {
     try {
+      // Invoke getSessionState to initialize the session state if not already done.
+      val sessionState = getSessionState()
       val cmd_trimmed: String = cmd.trim()
       val tokens: Array[String] = cmd_trimmed.split("\\s+")
       val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 04c48c385966e..239364270d3ed 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -51,7 +51,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
       tableName: String,
       alias: Option[String]): LogicalPlan = synchronized {
     val (databaseName, tblName) = processDatabaseAndTableName(
-      db.getOrElse(hive.sessionState.getCurrentDatabase), tableName)
+      db.getOrElse(hive.getSessionState.getCurrentDatabase), tableName)
     val table = client.getTable(databaseName, tblName)
     val partitions: Seq[Partition] =
       if (table.isPartitioned) {
@@ -112,7 +112,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
 
     case CreateTableAsSelect(db, tableName, child) =>
       val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
-      val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
+      val databaseName = dbName.getOrElse(hive.getSessionState.getCurrentDatabase)
       CreateTableAsSelect(Some(databaseName), tableName, child)
   }
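
For context, the patch replaces HiveContext's single shared `sessionState` lazy val with a get-or-create lookup against Hive's thread-local SessionState registry, so each server thread lazily binds its own state the first time it needs one. Below is a minimal, self-contained sketch of that pattern; `DemoSessionState`, `DemoRegistry`, and `DemoContext` are hypothetical stand-ins for illustration, not Spark or Hive classes.

```scala
// Minimal sketch of the get-or-create thread-local pattern used in the patch.
// DemoSessionState / DemoRegistry / DemoContext are hypothetical stand-ins,
// not Spark or Hive classes.
class DemoSessionState(val props: Map[String, String])

object DemoRegistry {
  private val current = new ThreadLocal[DemoSessionState]

  // Analogue of SessionState.get: returns null if nothing is bound yet.
  def get(): DemoSessionState = current.get()

  // Analogue of SessionState.start(ss): binds ss to the calling thread.
  def start(ss: DemoSessionState): Unit = current.set(ss)
}

class DemoContext {
  // Mirrors the patch's getSessionState(): reuse the state already bound
  // to this thread, or create one and register it before returning.
  def getSessionState(): DemoSessionState = {
    var ss = DemoRegistry.get()
    if (ss == null) {
      ss = new DemoSessionState(Map("demo.scratchdir" -> "/tmp/demo"))
      DemoRegistry.start(ss)
    }
    ss
  }
}
```

Because the lookup happens per call, entry points such as `runHive` and the catalog methods can simply call `getSessionState()` first, and two threads serving different JDBC sessions never share, or race on, the same SessionState.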