Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ private[hive] object SparkSQLCLIDriver extends Logging {

// Set all properties specified via command line.
val conf: HiveConf = sessionState.getConf
// Use startup ClassLoader replace hiveConf's UDFClassLoader
conf.setClassLoader(Thread.currentThread().getContextClassLoader)
sessionState.cmdProperties.entrySet().asScala.foreach { item =>
val key = item.getKey.toString
val value = item.getValue.toString
Expand All @@ -133,20 +135,7 @@ private[hive] object SparkSQLCLIDriver extends Logging {
// Clean up after we exit
ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() }

val remoteMode = isRemoteMode(sessionState)
// "-h" option has been passed, so connect to Hive thrift server.
if (!remoteMode) {
// Hadoop-20 and above - we need to augment classpath using hiveconf
// components.
// See also: code in ExecDriver.java
var loader = conf.getClassLoader
val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
if (StringUtils.isNotBlank(auxJars)) {
loader = ThriftserverShimUtils.addToClassPath(loader, StringUtils.split(auxJars, ","))
}
conf.setClassLoader(loader)
Thread.currentThread().setContextClassLoader(loader)
} else {
if (isRemoteMode(sessionState)) {
// Hive 1.2 + not supported in CLI
throw new RuntimeException("Remote operations not supported")
}
Expand All @@ -164,6 +153,15 @@ private[hive] object SparkSQLCLIDriver extends Logging {
val cli = new SparkSQLCLIDriver
cli.setHiveVariables(oproc.getHiveVariables)

// In SparkSQL CLI, we may want to use jars augmented by hiveconf
// hive.aux.jars.path, here we add jars augmented by hiveconf to
// Spark's SessionResourceLoader to obtain these jars.
val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
if (StringUtils.isNotBlank(auxJars)) {
val resourceLoader = SparkSQLEnv.sqlContext.sessionState.resourceLoader
StringUtils.split(auxJars, ",").foreach(resourceLoader.addJar(_))
}

// TODO work around for set the log output to console, because the HiveContext
// will set the output into an invalid buffer.
sessionState.in = System.in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,15 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
"SELECT example_max(1);" -> "1"
)
}

test("SPARK-28954 test --jars command") {
val jarFile = new File("../../sql/hive/src/test/resources/SPARK-21101-1.0.jar").getCanonicalPath
runCliWithin(
1.minute,
Seq("--jars", s"$jarFile"))(
s"CREATE TEMPORARY FUNCTION testjar AS" +
s" 'org.apache.spark.sql.hive.execution.UDTFStack';" -> "",
"SELECT testjar(1,'A', 10);" -> "A\t10"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.hive.thriftserver

import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema, Type}
import org.apache.hive.service.cli.thrift.TProtocolVersion._
Expand Down Expand Up @@ -51,12 +50,6 @@ private[thriftserver] object ThriftserverShimUtils {

private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType

private[thriftserver] def addToClassPath(
loader: ClassLoader,
auxJars: Array[String]): ClassLoader = {
Utilities.addToClassPath(loader, auxJars)
}

private[thriftserver] val testedProtocolVersions = Seq(
HIVE_CLI_SERVICE_PROTOCOL_V1,
HIVE_CLI_SERVICE_PROTOCOL_V2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@

package org.apache.spark.sql.hive.thriftserver

import java.security.AccessController

import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.AddToClassPathAction
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.thrift.Type
Expand Down Expand Up @@ -56,13 +52,6 @@ private[thriftserver] object ThriftserverShimUtils {

private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType

private[thriftserver] def addToClassPath(
loader: ClassLoader,
auxJars: Array[String]): ClassLoader = {
val addAction = new AddToClassPathAction(loader, auxJars.toList.asJava)
AccessController.doPrivileged(addAction)
}

private[thriftserver] val testedProtocolVersions = Seq(
HIVE_CLI_SERVICE_PROTOCOL_V1,
HIVE_CLI_SERVICE_PROTOCOL_V2,
Expand Down