2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -210,6 +210,8 @@ class SQLContext private[sql](

protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false)

protected[sql] def doPriCheck(logicalPlan: LogicalPlan): Unit = {}

protected[sql] def executeSql(sql: String):
org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql))

@@ -35,6 +35,11 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {

lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)

lazy val authorized: LogicalPlan = {
sqlContext.doPriCheck(analyzed)
analyzed
}
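
A minimal, self-contained sketch of the pattern used by the new `authorized` value (illustrative names only, not Spark APIs): the privilege check is a side effect attached to a lazy val, so it runs at most once per QueryExecution and only when the authorized plan is actually requested.

```scala
object LazyAuthSketch {
  class MiniQueryExecution(analyzed: String, check: String => Unit) {
    // Mirrors `authorized` above: run the check once, then expose the analyzed plan.
    lazy val authorized: String = {
      check(analyzed)
      analyzed
    }
  }

  def main(args: Array[String]): Unit = {
    val qe = new MiniQueryExecution("Project [key] over default.src",
      plan => println(s"checking privileges for: $plan"))
    qe.authorized // the check runs here
    qe.authorized // cached; the check does not run a second time
  }
}
```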

lazy val withCachedData: LogicalPlan = {
assertAnalyzed()
sqlContext.cacheManager.useCachedData(analyzed)
@@ -74,9 +74,10 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
protected val TABLE = Keyword("TABLE")
protected val TABLES = Keyword("TABLES")
protected val UNCACHE = Keyword("UNCACHE")
protected val ROLE = Keyword("ROLE")

override protected lazy val start: Parser[LogicalPlan] =
cache | uncache | set | show | desc | others
cache | uncache | setRole | set | show | desc | others

private lazy val cache: Parser[LogicalPlan] =
CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
@@ -91,6 +92,11 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
| CLEAR ~ CACHE ^^^ ClearCacheCommand
)

private lazy val setRole: Parser[LogicalPlan] =
SET ~ ROLE ~ ident ^^ {
case set ~ role ~ roleName => fallback.parsePlan(List(set, role, roleName).mkString(" "))
}

private lazy val set: Parser[LogicalPlan] =
SET ~> restInput ^^ {
case input => SetCommandParser(input)
@@ -120,5 +126,4 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
wholeInput ^^ {
case input => fallback.parsePlan(input)
}

}
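
The new `setRole` rule has to be tried before the generic `set` rule; otherwise `SET ROLE admin` would be consumed by SetCommandParser as a configuration assignment instead of being handed to the fallback (Hive) parser. A stand-alone sketch of that dispatch order using plain scala-parser-combinators, with no Spark classes (names and grammar are illustrative):

```scala
import scala.util.parsing.combinator.RegexParsers

object SetRoleDispatchSketch extends RegexParsers {
  val ident: Parser[String] = """[A-Za-z_][A-Za-z0-9_]*""".r
  // Tried first: rebuild the statement and delegate it, like `fallback.parsePlan` above.
  val setRole: Parser[String] =
    "SET" ~ "ROLE" ~ ident ^^ { case s ~ r ~ name => s"fallback: $s $r $name" }
  // Tried second: everything else after SET is treated as a config command.
  val set: Parser[String] = "SET" ~> ".*".r ^^ { rest => s"SetCommand: $rest" }
  val start: Parser[String] = setRole | set

  def main(args: Array[String]): Unit = {
    println(parseAll(start, "SET ROLE admin"))    // parsed: fallback: SET ROLE admin
    println(parseAll(start, "SET spark.sql.x=1")) // parsed: SetCommand: spark.sql.x=1
  }
}
```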
@@ -17,18 +17,19 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.{execution, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.Strategy

private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SparkPlanner =>
@@ -17,14 +17,14 @@

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLConf}

class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.localSeqToDataFrameHolder
@@ -74,7 +74,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
val ctx = if (hiveContext.hiveThriftServerSingleSession) {
hiveContext
} else {
hiveContext.newSession()
hiveContext.newSession(username)
}
ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx
@@ -100,12 +100,16 @@ class HiveContext private[hive](
* and Hive client (both of execution and metadata) with existing HiveContext.
*/
override def newSession(): HiveContext = {
newSession(null)
}

def newSession(userName: String = null): HiveContext = {
new HiveContext(
sc = sc,
cacheManager = cacheManager,
listener = listener,
execHive = executionHive.newSession(),
metaHive = metadataHive.newSession(),
execHive = executionHive.newSession(userName),
metaHive = metadataHive.newSession(userName),
isRootContext = false)
}

@@ -550,6 +554,14 @@ class HiveContext private[hive](
new SparkSQLParser(new ExtendedHiveQlParser(this))
}

override protected[sql] def doPriCheck(logicalPlan: LogicalPlan): Unit = {
log.info("checking privileges with the Hive authorizer")
val threadClassLoader = Thread.currentThread.getContextClassLoader
Thread.currentThread.setContextClassLoader(metadataHive.getClass.getClassLoader)
try {
metadataHive.checkPrivileges(logicalPlan)
} finally {
// always restore the caller's context class loader, even if the check throws
Thread.currentThread.setContextClassLoader(threadClassLoader)
}
}
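
For orientation, a hedged usage sketch rather than part of this change: it assumes SQL-standard authorization is configured on the Hive side (hive.security.authorization.enabled plus a V2 authorizer factory in hive-site.xml), that the execution path forces QueryExecution.authorized, and that restricted_table is a table the session user may not read. Under those assumptions the query is rejected during authorization, before any Spark job runs.

```scala
try {
  hiveContext.sql("SELECT * FROM restricted_table").collect()
} catch {
  // The concrete exception type depends on the configured authorizer; Hive's
  // SQL-standard plugin signals denial with HiveAccessControlException.
  case e: Exception => println(s"query denied: ${e.getMessage}")
}
```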

@transient
private val hivePlanner = new SparkPlanner(this) with HiveStrategies {
val hiveContext = self
@@ -128,6 +128,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging

"TOK_GRANT",
"TOK_GRANT_ROLE",
"TOK_REVOKE_ROLE",

"TOK_IMPORT",

@@ -23,6 +23,7 @@ import javax.annotation.Nullable

import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

private[hive] case class HiveDatabase(name: String, location: String)

@@ -86,6 +87,7 @@ private[hive] case class HiveTable(
* shared classes.
*/
private[hive] trait ClientInterface {
def checkPrivileges(logicalPlan: LogicalPlan): Unit

/** Returns the Hive Version of this client. */
def version: HiveVersion
@@ -182,7 +184,7 @@ private[hive] trait ClientInterface {
def addJar(path: String): Unit

/** Return a ClientInterface as new session, that will share the class loader and Hive client */
def newSession(): ClientInterface
def newSession(userName: String = null): ClientInterface

/** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */
def withHiveState[A](f: => A): A
@@ -18,7 +18,7 @@
package org.apache.spark.sql.hive.client

import java.io.{File, PrintStream}
import java.util.{Map => JMap}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._
import scala.language.reflectiveCalls
@@ -29,13 +29,18 @@ import org.apache.hadoop.hive.metastore.{TableType => HTableType}
import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.ql.{metadata, Driver}
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.plan.HiveOperation
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.security.authorization.plugin._
import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject
.HivePrivilegeObjectType
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader}
import org.apache.hadoop.security.UserGroupInformation

import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.sql.hive.MetastoreRelation
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.util.{CircularBuffer, Utils}

@@ -57,6 +62,7 @@ import org.apache.spark.util.{CircularBuffer, Utils}
* this ClientWrapper.
*/
private[hive] class ClientWrapper(
val userName: String = null,
override val version: HiveVersion,
config: Map[String, String],
initClassLoader: ClassLoader,
@@ -118,13 +124,19 @@ private[hive] class ClientWrapper(
}
initialConf.set(k, v)
}
val state = new SessionState(initialConf)

val state = version match {
case hive.v12 => new SessionState(initialConf)
case _ => new SessionState(initialConf, userName)
}

if (clientLoader.cachedHive != null) {
Hive.set(clientLoader.cachedHive.asInstanceOf[Hive])
}
SessionState.start(state)
state.out = new PrintStream(outputBuffer, true, "UTF-8")
state.err = new PrintStream(outputBuffer, true, "UTF-8")
state.setIsHiveServerQuery(true)
state
} finally {
Thread.currentThread().setContextClassLoader(original)
@@ -135,6 +147,62 @@ private[hive] class ClientWrapper(
/** Returns the configuration for the current session. */
def conf: HiveConf = SessionState.get().getConf

def checkPrivileges(logicalPlan: LogicalPlan): Unit = {
val authorizer = SessionState.get().getAuthorizerV2
val hiveOp = HiveOperationType.valueOf(getHiveOperation(logicalPlan).name())
val (inputsHObjs, outputsHObjs) = getInputOutputHObjs(logicalPlan)

val hiveAuthzContext = getHiveAuthzContext(logicalPlan, logicalPlan.toString)
authorizer.checkPrivileges(hiveOp, inputsHObjs, outputsHObjs, hiveAuthzContext)
}

def getHiveOperation(logicalPlan: LogicalPlan): HiveOperation = {
logicalPlan match {
case Project(_, _) => HiveOperation.QUERY
case _ => HiveOperation.ALTERINDEX_PROPS // TODO add more types here
}
}

def getHiveAuthzContext(logicalPlan: LogicalPlan, command: String): HiveAuthzContext = {
val authzContextBuilder = new HiveAuthzContext.Builder()
authzContextBuilder.setUserIpAddress(SessionState.get().getUserIpAddress)
authzContextBuilder.setCommandString(command)
authzContextBuilder.build()
}

def getInputOutputHObjs(logicalPlan: LogicalPlan): (JList[HivePrivilegeObject],
JList[HivePrivilegeObject]) = {
val inputObjs = new JArrayList[HivePrivilegeObject]
val outputObjs = new JArrayList[HivePrivilegeObject]
getInputOutputHObjsHelper(inputObjs, outputObjs, null, logicalPlan)
(inputObjs, outputObjs)
}

def getInputOutputHObjsHelper(
inputObjs: JList[HivePrivilegeObject],
outputObjs: JList[HivePrivilegeObject],
hivePrivilegeObjectType: HivePrivilegeObjectType,
logicalPlan: LogicalPlan): Unit = {
logicalPlan match {
case Project(projectionList, child) => buildHivePrivilegeObject(
HivePrivilegeObjectType.TABLE_OR_VIEW, projectionList, inputObjs, child)
case _ => Nil
}
}

private def buildHivePrivilegeObject(
hivePrivilegeObjectType: HivePrivilegeObjectType,
projectionList: Seq[Expression],
hivePriObjs: JList[HivePrivilegeObject], logicalPlan: LogicalPlan): Unit = {
logicalPlan match {
case Filter(_, child) => buildHivePrivilegeObject(hivePrivilegeObjectType, projectionList,
hivePriObjs, child)
case MetastoreRelation(dbName, tblName, _) =>
hivePriObjs.add(new HivePrivilegeObject(hivePrivilegeObjectType, dbName, tblName, null))
case _ => Nil
}
}
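
The helpers above walk the analyzed plan top-down: Project and Filter nodes are peeled off until a MetastoreRelation is reached, and that relation is recorded as a readable HivePrivilegeObject(TABLE_OR_VIEW, db, table). A self-contained sketch of the recursion with simplified, hypothetical node types (not Spark's classes):

```scala
object PlanWalkSketch {
  // Stand-ins for Spark's Project/Filter and Hive's MetastoreRelation.
  sealed trait Plan
  case class Relation(db: String, table: String) extends Plan
  case class Filter(child: Plan) extends Plan
  case class Project(child: Plan) extends Plan

  // Peel wrapper nodes until a relation is reached, recording it as an input table,
  // analogous to adding a HivePrivilegeObject for the relation's database and table.
  def inputTables(plan: Plan): List[(String, String)] = plan match {
    case Project(child)      => inputTables(child)
    case Filter(child)       => inputTables(child)
    case Relation(db, table) => List((db, table))
  }

  def main(args: Array[String]): Unit = {
    val plan = Project(Filter(Relation("default", "src")))
    println(inputTables(plan)) // List((default,src))
  }
}
```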

override def getConf(key: String, defaultValue: String): String = {
conf.get(key, defaultValue)
}
@@ -403,13 +471,15 @@ private[hive] class ClientWrapper(
*/
protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState {
logDebug(s"Running hiveql '$cmd'")
if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") }
if (cmd.toLowerCase.startsWith("set") && !cmd.toLowerCase.startsWith("set role ")) {
logDebug(s"Changing config: $cmd")
}
try {
val cmd_trimmed: String = cmd.trim()
val tokens: Array[String] = cmd_trimmed.split("\\s+")
// The remainder of the command.
val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
val proc = shim.getCommandProcessor(tokens(0), conf)
val proc = shim.getCommandProcessor(tokens, conf)
proc match {
case driver: Driver =>
val response: CommandProcessorResponse = driver.run(cmd)
@@ -512,8 +582,8 @@ private[hive] class ClientWrapper(
runSqlHive(s"ADD JAR $path")
}

def newSession(): ClientWrapper = {
clientLoader.createClient().asInstanceOf[ClientWrapper]
def newSession(userName: String = null): ClientWrapper = {
clientLoader.createClient(userName).asInstanceOf[ClientWrapper]
}

def reset(): Unit = withHiveState {
@@ -68,7 +68,7 @@ private[client] sealed abstract class Shim {

def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor

def getDriverResults(driver: Driver): Seq[String]

@@ -214,8 +214,8 @@ private[client] class Shim_v0_12 extends Shim with Logging {
getAllPartitions(hive, table)
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
override def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token(0), conf).asInstanceOf[CommandProcessor]

override def getDriverResults(driver: Driver): Seq[String] = {
val res = new JArrayList[String]()
@@ -358,8 +358,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
partitions.asScala.toSeq
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
override def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]

override def getDriverResults(driver: Driver): Seq[String] = {
val res = new JArrayList[Object]()
@@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader(
}

/** The isolated client interface to Hive. */
private[hive] def createClient(): ClientInterface = {
private[hive] def createClient(userName: String = null): ClientInterface = {
if (!isolationOn) {
return new ClientWrapper(version, config, baseClassLoader, this)
return new ClientWrapper(userName, version, config, baseClassLoader, this)
}
// Pre-reflective instantiation setup.
logDebug("Initializing the logger to avoid disaster...")
@@ -246,7 +246,7 @@
classLoader
.loadClass(classOf[ClientWrapper].getName)
.getConstructors.head
.newInstance(version, config, classLoader, this)
.newInstance(userName, version, config, classLoader, this)
.asInstanceOf[ClientInterface]
} catch {
case e: InvocationTargetException =>