From 9178b4468f44167f44d3a7cb18c34fcd010e3542 Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Thu, 3 Dec 2015 19:53:30 -0500
Subject: [PATCH 1/5] [SPARK-12145][SQL] Command 'Set Role [ADMIN|NONE|ALL]' doesn't work in SQL based authorization

---
 .../spark/sql/catalyst/SqlParserSuite.scala     | 178 ++++++++++++++++++
 .../spark/sql/execution/SparkSQLParser.scala    |   9 +-
 .../thriftserver/SparkSQLSessionManager.scala   |   2 +-
 .../apache/spark/sql/hive/HiveContext.scala     |   8 +-
 .../sql/hive/client/ClientInterface.scala       |   2 +-
 .../spark/sql/hive/client/ClientWrapper.scala   |  14 +-
 .../spark/sql/hive/client/HiveShim.scala        |  10 +-
 .../hive/client/IsolatedClientLoader.scala      |   6 +-
 8 files changed, 210 insertions(+), 19 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
new file mode 100644
index 000000000000..2134fa1dcaa8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command}
+import org.apache.spark.unsafe.types.CalendarInterval
+
+private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
+  override def output: Seq[Attribute] = Seq.empty
+  override def children: Seq[LogicalPlan] = Seq.empty
+}
+
+private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
+  protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")
+
+  override protected lazy val start: Parser[LogicalPlan] = set
+
+  private lazy val set: Parser[LogicalPlan] =
+    EXECUTE ~> ident ^^ {
+      case fileName => TestCommand(fileName)
+    }
+}
+
+private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
+  protected val EXECUTE = Keyword("EXECUTE")
+
+  override protected lazy val start: Parser[LogicalPlan] = set
+
+  private lazy val set: Parser[LogicalPlan] =
+    EXECUTE ~> ident ^^ {
+      case fileName => TestCommand(fileName)
+    }
+}
+
+private[sql] class SetRoleTestParser extends AbstractSparkSQLParser{
+  protected val SET = Keyword("SET")
+  protected val ROLE = Keyword("ROLE")
+
+  override protected lazy val start: Parser[LogicalPlan] = setR | set
+
+  private lazy val setR: Parser[LogicalPlan] =
+    SET ~ ROLE ~ ident ^^ {
+      case set ~ role ~ roleName => TestCommand(List(set, role, roleName).mkString(" "))
+    }
+
+  private lazy val set: Parser[LogicalPlan] =
+    SET ~> restInput ^^ {
+      case input => TestCommand(input)
+    }
+}
+
+class SqlParserSuite extends PlanTest {
+
+  test("test long keyword") {
+    val parser = new SuperLongKeywordTestParser
+    assert(TestCommand("NotRealCommand") ===
+      parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
+  }
+
+  test("test case insensitive") {
+    val parser = new CaseInsensitiveTestParser
+    assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
+    assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
+    assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
+  }
+
+  test("test set role command") {
+    val parser = new SetRoleTestParser
+    assert(TestCommand(" A.B.C = ADMIN") === parser.parse("SET A.B.C = ADMIN"))
+    assert(TestCommand(" A.B.C") === parser.parse("SET A.B.C"))
+    assert(TestCommand(" ROLE.A.B = ADMIN") === parser.parse("SET ROLE.A.B = ADMIN"))
+    assert(TestCommand(" ROLEA.A.B = ADMIN") === parser.parse("SET ROLEA.A.B = ADMIN"))
+    assert(TestCommand(" role.A.B = ADMIN") === parser.parse("SET role.A.B = ADMIN"))
+    assert(TestCommand("set role ADMIN") === parser.parse("SET ROLE ADMIN"))
+    assert(TestCommand("set role ADMIN") === parser.parse("SET role ADMIN"))
+  }
+
+  test("test NOT operator with comparison operations") {
+    val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE")
+    val expected = Project(
+      UnresolvedAlias(
+        Not(
+          GreaterThan(Literal(true), Literal(true)))
+      ) :: Nil,
+      OneRowRelation)
+    comparePlans(parsed, expected)
+  }
+
+  test("support hive interval literal") {
+    def checkInterval(sql: String, result: CalendarInterval): Unit = {
+      val parsed = SqlParser.parse(sql)
+      val expected = Project(
+        UnresolvedAlias(
+          Literal(result)
+        ) :: Nil,
+        OneRowRelation)
+      comparePlans(parsed, expected)
+    }
+
+    def checkYearMonth(lit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' YEAR TO MONTH",
+        CalendarInterval.fromYearMonthString(lit))
+    }
+
+    def checkDayTime(lit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' DAY TO SECOND",
+        CalendarInterval.fromDayTimeString(lit))
+    }
+
+    def checkSingleUnit(lit: String, unit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' $unit",
+        CalendarInterval.fromSingleUnitString(unit, lit))
+    }
+
+    checkYearMonth("123-10")
+    checkYearMonth("496-0")
+    checkYearMonth("-2-3")
+    checkYearMonth("-123-0")
+
+    checkDayTime("99 11:22:33.123456789")
+    checkDayTime("-99 11:22:33.123456789")
+    checkDayTime("10 9:8:7.123456789")
+    checkDayTime("1 0:0:0")
+    checkDayTime("-1 0:0:0")
+    checkDayTime("1 0:0:1")
+
+    for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
+      checkSingleUnit("7", unit)
+      checkSingleUnit("-7", unit)
+      checkSingleUnit("0", unit)
+    }
+
+    checkSingleUnit("13.123456789", "second")
+    checkSingleUnit("-13.123456789", "second")
+  }
+
+  test("support scientific notation") {
+    def assertRight(input: String, output: Double): Unit = {
+      val parsed = SqlParser.parse("SELECT " + input)
+      val expected = Project(
+        UnresolvedAlias(
+          Literal(output)
+        ) :: Nil,
+        OneRowRelation)
+      comparePlans(parsed, expected)
+    }
+
+    assertRight("9.0e1", 90)
+    assertRight(".9e+2", 90)
+    assertRight("0.9e+2", 90)
+    assertRight("900e-1", 90)
+    assertRight("900.0E-1", 90)
+    assertRight("9.e+1", 90)
+
+    intercept[RuntimeException](SqlParser.parse("SELECT .e3"))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
index d2d827156372..b6189e196155 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
@@ -74,9 +74,10 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
   protected val TABLE = Keyword("TABLE")
   protected val TABLES = Keyword("TABLES")
   protected val UNCACHE = Keyword("UNCACHE")
+  protected val ROLE = Keyword("ROLE")
 
   override protected lazy val start: Parser[LogicalPlan] =
-    cache | uncache | set | show | desc | others
+    cache | uncache | setRole | set | show | desc | others
 
   private lazy val cache: Parser[LogicalPlan] =
     CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
@@ -91,6 +92,11 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
     | CLEAR ~ CACHE ^^^ ClearCacheCommand
     )
 
+  private lazy val setRole: Parser[LogicalPlan] =
+    SET ~ ROLE ~ ident ^^ {
+      case set ~ role ~ roleName => fallback(List(set, role, roleName).mkString(" "))
+    }
+
   private lazy val set: Parser[LogicalPlan] =
     SET ~> restInput ^^ {
       case input => SetCommandParser(input)
@@ -120,5 +126,4 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
     wholeInput ^^ {
       case input => fallback.parsePlan(input)
     }
-
 }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
index de4e9c62b57a..586403a7d050 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -74,7 +74,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext:
     val ctx = if (hiveContext.hiveThriftServerSingleSession) {
       hiveContext
     } else {
-      hiveContext.newSession()
+      hiveContext.newSession(username)
     }
     ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion)
     sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index eaca3c9269bb..094aacbd875b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -100,12 +100,16 @@ class HiveContext private[hive](
    * and Hive client (both of execution and metadata) with existing HiveContext.
    */
   override def newSession(): HiveContext = {
+    newSession(null)
+  }
+
+  def newSession(userName: String = null): HiveContext = {
     new HiveContext(
       sc = sc,
       cacheManager = cacheManager,
       listener = listener,
-      execHive = executionHive.newSession(),
-      metaHive = metadataHive.newSession(),
+      execHive = executionHive.newSession(userName),
+      metaHive = metadataHive.newSession(userName),
       isRootContext = false)
   }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 9d9a55edd731..54528fa48c3e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -182,7 +182,7 @@ private[hive] trait ClientInterface {
   def addJar(path: String): Unit
 
   /** Return a ClientInterface as new session, that will share the class loader and Hive client */
-  def newSession(): ClientInterface
+  def newSession(userName: String = null): ClientInterface
 
   /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */
   def withHiveState[A](f: => A): A
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index ce7a305d437a..07056a045b2f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -57,6 +57,7 @@ import org.apache.spark.util.{CircularBuffer, Utils}
  * this ClientWrapper.
  */
 private[hive] class ClientWrapper(
+    val userName: String = null,
     override val version: HiveVersion,
     config: Map[String, String],
     initClassLoader: ClassLoader,
@@ -118,13 +119,14 @@ private[hive] class ClientWrapper(
         }
         initialConf.set(k, v)
       }
-      val state = new SessionState(initialConf)
+      val state = new SessionState(initialConf, userName)
       if (clientLoader.cachedHive != null) {
         Hive.set(clientLoader.cachedHive.asInstanceOf[Hive])
       }
       SessionState.start(state)
       state.out = new PrintStream(outputBuffer, true, "UTF-8")
       state.err = new PrintStream(outputBuffer, true, "UTF-8")
+      state.setIsHiveServerQuery(true)
       state
     } finally {
       Thread.currentThread().setContextClassLoader(original)
@@ -403,13 +405,15 @@ private[hive] class ClientWrapper(
    */
   protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState {
     logDebug(s"Running hiveql '$cmd'")
-    if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") }
+    if (cmd.toLowerCase.startsWith("set") && !cmd.toLowerCase.startsWith("set role ")) {
+      logDebug(s"Changing config: $cmd")
+    }
     try {
       val cmd_trimmed: String = cmd.trim()
       val tokens: Array[String] = cmd_trimmed.split("\\s+")
       // The remainder of the command.
      val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
-      val proc = shim.getCommandProcessor(tokens(0), conf)
+      val proc = shim.getCommandProcessor(tokens, conf)
       proc match {
         case driver: Driver =>
           val response: CommandProcessorResponse = driver.run(cmd)
@@ -512,8 +516,8 @@ private[hive] class ClientWrapper(
     runSqlHive(s"ADD JAR $path")
   }
 
-  def newSession(): ClientWrapper = {
-    clientLoader.createClient().asInstanceOf[ClientWrapper]
+  def newSession(userName: String = null): ClientWrapper = {
+    clientLoader.createClient(userName).asInstanceOf[ClientWrapper]
   }
 
   def reset(): Unit = withHiveState {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index ca636b0265d4..16bc57b289f3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -68,7 +68,7 @@ private[client] sealed abstract class Shim {
 
   def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
 
-  def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
+  def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor
 
   def getDriverResults(driver: Driver): Seq[String]
 
@@ -214,8 +214,8 @@ private[client] class Shim_v0_12 extends Shim with Logging {
     getAllPartitions(hive, table)
   }
 
-  override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
-    getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
+  override def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor =
+    getCommandProcessorMethod.invoke(null, token(0), conf).asInstanceOf[CommandProcessor]
 
   override def getDriverResults(driver: Driver): Seq[String] = {
     val res = new JArrayList[String]()
@@ -358,8 +358,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     partitions.asScala.toSeq
   }
 
-  override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
-    getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
+  override def getCommandProcessor(token: Array[String], conf: HiveConf): CommandProcessor =
+    getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
 
   override def getDriverResults(driver: Driver): Seq[String] = {
     val res = new JArrayList[Object]()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 010051d255fd..2a2c7169abc6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -233,9 +233,9 @@ private[hive] class IsolatedClientLoader(
   }
 
   /** The isolated client interface to Hive. */
-  private[hive] def createClient(): ClientInterface = {
+  private[hive] def createClient(userName: String = null): ClientInterface = {
     if (!isolationOn) {
-      return new ClientWrapper(version, config, baseClassLoader, this)
+      return new ClientWrapper(userName, version, config, baseClassLoader, this)
     }
     // Pre-reflective instantiation setup.
    logDebug("Initializing the logger to avoid disaster...")
@@ -246,7 +246,7 @@ private[hive] class IsolatedClientLoader(
       classLoader
         .loadClass(classOf[ClientWrapper].getName)
         .getConstructors.head
-        .newInstance(version, config, classLoader, this)
+        .newInstance(userName, version, config, classLoader, this)
         .asInstanceOf[ClientInterface]
     } catch {
       case e: InvocationTargetException =>

From 7351f835b564b0d8c6340678f938568d8822b752 Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Mon, 18 Jan 2016 21:50:50 -0500
Subject: [PATCH 2/5] Rebase code

---
 .../scala/org/apache/spark/sql/execution/SparkSQLParser.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
index b6189e196155..75fa85cf8e97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
@@ -94,7 +94,7 @@ class SparkSQLParser(fallback: => ParserInterface) extends AbstractSparkSQLParse
 
   private lazy val setRole: Parser[LogicalPlan] =
     SET ~ ROLE ~ ident ^^ {
-      case set ~ role ~ roleName => fallback(List(set, role, roleName).mkString(" "))
+      case set ~ role ~ roleName => fallback.parsePlan(List(set, role, roleName).mkString(" "))
     }
 
   private lazy val set: Parser[LogicalPlan] =

From 8055065ee636939169fd25c88b8db6fe3fc3a934 Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Tue, 19 Jan 2016 20:48:17 -0500
Subject: [PATCH 3/5] Fix import order issue

---
 .../scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala | 4 ++--
 .../org/apache/spark/sql/execution/SparkStrategies.scala     | 5 +++--
 .../apache/spark/sql/execution/joins/InnerJoinSuite.scala    | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
index 2134fa1dcaa8..e968fff8870f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
@@ -18,9 +18,9 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index df0f73049921..b68aea2705ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
-import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -26,9 +24,12 @@ import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
 import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.Strategy
 
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index ab81b702596a..149f34dbd748 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -24,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
-import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder

From b99768d261ce7f7875603c8bc9c8d18c36f474fa Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Wed, 13 Jan 2016 14:23:37 -0500
Subject: [PATCH 4/5] Add revoke support

---
 .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 +
 .../org/apache/spark/sql/hive/client/ClientWrapper.scala  | 7 ++++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 46246f8191db..e1864c077725 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -128,6 +128,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
 
     "TOK_GRANT",
     "TOK_GRANT_ROLE",
+    "TOK_REVOKE_ROLE",
 
     "TOK_IMPORT",
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 07056a045b2f..c09b50d65375 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -119,7 +119,12 @@ private[hive] class ClientWrapper(
       }
       initialConf.set(k, v)
     }
-      val state = new SessionState(initialConf, userName)
+
+      val state = version match {
+        case hive.v12 => new SessionState(initialConf)
+        case _ => new SessionState(initialConf, userName)
+      }
+
       if (clientLoader.cachedHive != null) {
        Hive.set(clientLoader.cachedHive.asInstanceOf[Hive])
       }

From 4a7295199631fa558627bd7fbd68cd0e6f1edf22 Mon Sep 17 00:00:00 2001
From: Ferdinand Xu
Date: Sun, 31 Jan 2016 22:37:59 -0500
Subject: [PATCH 5/5] Add SQL based authorization support initial part

---
 .../spark/sql/catalyst/SqlParserSuite.scala   | 178 ------------------
 .../org/apache/spark/sql/SQLContext.scala     |   2 +
 .../spark/sql/execution/QueryExecution.scala  |   5 +
 .../apache/spark/sql/hive/HiveContext.scala   |   8 +
 .../sql/hive/client/ClientInterface.scala     |   2 +
 .../spark/sql/hive/client/ClientWrapper.scala |  67 ++++++-
 6 files changed, 81 insertions(+), 181 deletions(-)
 delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
deleted file mode 100644
index e968fff8870f..000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ /dev/null
@@ -1,178 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not}
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.unsafe.types.CalendarInterval
-
-private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
-  override def output: Seq[Attribute] = Seq.empty
-  override def children: Seq[LogicalPlan] = Seq.empty
-}
-
-private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
-  protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")
-
-  override protected lazy val start: Parser[LogicalPlan] = set
-
-  private lazy val set: Parser[LogicalPlan] =
-    EXECUTE ~> ident ^^ {
-      case fileName => TestCommand(fileName)
-    }
-}
-
-private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
-  protected val EXECUTE = Keyword("EXECUTE")
-
-  override protected lazy val start: Parser[LogicalPlan] = set
-
-  private lazy val set: Parser[LogicalPlan] =
-    EXECUTE ~> ident ^^ {
-      case fileName => TestCommand(fileName)
-    }
-}
-
-private[sql] class SetRoleTestParser extends AbstractSparkSQLParser{
-  protected val SET = Keyword("SET")
-  protected val ROLE = Keyword("ROLE")
-
-  override protected lazy val start: Parser[LogicalPlan] = setR | set
-
-  private lazy val setR: Parser[LogicalPlan] =
-    SET ~ ROLE ~ ident ^^ {
-      case set ~ role ~ roleName => TestCommand(List(set, role, roleName).mkString(" "))
-    }
-
-  private lazy val set: Parser[LogicalPlan] =
-    SET ~> restInput ^^ {
-      case input => TestCommand(input)
-    }
-}
-
-class SqlParserSuite extends PlanTest {
-
-  test("test long keyword") {
-    val parser = new SuperLongKeywordTestParser
-    assert(TestCommand("NotRealCommand") ===
-      parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
-  }
-
-  test("test case insensitive") {
-    val parser = new CaseInsensitiveTestParser
-    assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
-    assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
-    assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
-  }
-
-  test("test set role command") {
-    val parser = new SetRoleTestParser
-    assert(TestCommand(" A.B.C = ADMIN") === parser.parse("SET A.B.C = ADMIN"))
-    assert(TestCommand(" A.B.C") === parser.parse("SET A.B.C"))
-    assert(TestCommand(" ROLE.A.B = ADMIN") === parser.parse("SET ROLE.A.B = ADMIN"))
-    assert(TestCommand(" ROLEA.A.B = ADMIN") === parser.parse("SET ROLEA.A.B = ADMIN"))
-    assert(TestCommand(" role.A.B = ADMIN") === parser.parse("SET role.A.B = ADMIN"))
-    assert(TestCommand("set role ADMIN") === parser.parse("SET ROLE ADMIN"))
-    assert(TestCommand("set role ADMIN") === parser.parse("SET role ADMIN"))
-  }
-
-  test("test NOT operator with comparison operations") {
-    val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE")
-    val expected = Project(
-      UnresolvedAlias(
-        Not(
-          GreaterThan(Literal(true), Literal(true)))
-      ) :: Nil,
-      OneRowRelation)
-    comparePlans(parsed, expected)
-  }
-
-  test("support hive interval literal") {
-    def checkInterval(sql: String, result: CalendarInterval): Unit = {
-      val parsed = SqlParser.parse(sql)
-      val expected = Project(
-        UnresolvedAlias(
-          Literal(result)
-        ) :: Nil,
-        OneRowRelation)
-      comparePlans(parsed, expected)
-    }
-
-    def checkYearMonth(lit: String): Unit = {
-      checkInterval(
-        s"SELECT INTERVAL '$lit' YEAR TO MONTH",
-        CalendarInterval.fromYearMonthString(lit))
-    }
-
-    def checkDayTime(lit: String): Unit = {
-      checkInterval(
-        s"SELECT INTERVAL '$lit' DAY TO SECOND",
-        CalendarInterval.fromDayTimeString(lit))
-    }
-
-    def checkSingleUnit(lit: String, unit: String): Unit = {
-      checkInterval(
-        s"SELECT INTERVAL '$lit' $unit",
-        CalendarInterval.fromSingleUnitString(unit, lit))
-    }
-
-    checkYearMonth("123-10")
-    checkYearMonth("496-0")
-    checkYearMonth("-2-3")
-    checkYearMonth("-123-0")
-
-    checkDayTime("99 11:22:33.123456789")
-    checkDayTime("-99 11:22:33.123456789")
-    checkDayTime("10 9:8:7.123456789")
-    checkDayTime("1 0:0:0")
-    checkDayTime("-1 0:0:0")
-    checkDayTime("1 0:0:1")
-
-    for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
-      checkSingleUnit("7", unit)
-      checkSingleUnit("-7", unit)
-      checkSingleUnit("0", unit)
-    }
-
-    checkSingleUnit("13.123456789", "second")
-    checkSingleUnit("-13.123456789", "second")
-  }
-
-  test("support scientific notation") {
-    def assertRight(input: String, output: Double): Unit = {
-      val parsed = SqlParser.parse("SELECT " + input)
-      val expected = Project(
-        UnresolvedAlias(
-          Literal(output)
-        ) :: Nil,
-        OneRowRelation)
-      comparePlans(parsed, expected)
-    }
-
-    assertRight("9.0e1", 90)
-    assertRight(".9e+2", 90)
-    assertRight("0.9e+2", 90)
-    assertRight("900e-1", 90)
-    assertRight("900.0E-1", 90)
-    assertRight("9.e+1", 90)
-
-    intercept[RuntimeException](SqlParser.parse("SELECT .e3"))
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 147e3557b632..fbc85e41b117 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -210,6 +210,8 @@ class SQLContext private[sql](
 
   protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false)
 
+  protected[sql] def doPriCheck(logicalPlan: LogicalPlan): Unit = Nil
+
   protected[sql] def executeSql(sql: String): org.apache.spark.sql.execution.QueryExecution =
     executePlan(parseSql(sql))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 107570f9dbcc..840c1d75ad37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -35,6 +35,11 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
 
+  lazy val authorized: LogicalPlan = {
+    sqlContext.doPriCheck(analyzed)
+    analyzed
+  }
+
   lazy val withCachedData: LogicalPlan = {
     assertAnalyzed()
     sqlContext.cacheManager.useCachedData(analyzed)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 094aacbd875b..9a7fe450178e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -554,6 +554,14 @@ class HiveContext private[hive](
     new SparkSQLParser(new ExtendedHiveQlParser(this))
   }
 
+  override protected[sql] def doPriCheck(logicalPlan: LogicalPlan): Unit = {
+    log.info("check privilege")
+    val threadClassLoader = Thread.currentThread.getContextClassLoader
+    Thread.currentThread.setContextClassLoader(metadataHive.getClass.getClassLoader)
+    val authorizer = metadataHive.checkPrivileges(logicalPlan)
+    Thread.currentThread.setContextClassLoader(threadClassLoader)
+  }
+
   @transient
   private val hivePlanner = new SparkPlanner(this) with HiveStrategies {
     val hiveContext = self
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 54528fa48c3e..90f56dca56f1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -23,6 +23,7 @@ import javax.annotation.Nullable
 
 import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 private[hive] case class HiveDatabase(name: String, location: String)
@@ -86,6 +87,7 @@ private[hive] case class HiveTable(
  * shared classes.
  */
 private[hive] trait ClientInterface {
+  def checkPrivileges(logicalPlan: LogicalPlan): Unit
   /** Returns the Hive Version of this client. */
   def version: HiveVersion
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index c09b50d65375..5acea2638ec9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.hive.client
 
 import java.io.{File, PrintStream}
-import java.util.{Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.language.reflectiveCalls
@@ -29,13 +29,18 @@ import org.apache.hadoop.hive.metastore.{TableType => HTableType}
 import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
 import org.apache.hadoop.hive.ql.{metadata, Driver}
 import org.apache.hadoop.hive.ql.metadata.Hive
+import org.apache.hadoop.hive.ql.plan.HiveOperation
 import org.apache.hadoop.hive.ql.processors._
+import org.apache.hadoop.hive.ql.security.authorization.plugin._
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject
+  .HivePrivilegeObjectType
 import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader}
 import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.sql.hive.MetastoreRelation
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.util.{CircularBuffer, Utils}
 
@@ -121,7 +126,7 @@ private[hive] class ClientWrapper(
       }
 
       val state = version match {
-        case hive.v12 => new SessionState(initialConf) 
+        case hive.v12 => new SessionState(initialConf)
         case _ => new SessionState(initialConf, userName)
       }
 
@@ -142,6 +147,62 @@ private[hive] class ClientWrapper(
   /** Returns the configuration for the current session. */
   def conf: HiveConf = SessionState.get().getConf
 
+  def checkPrivileges(logicalPlan: LogicalPlan): Unit = {
+    val authorizer = SessionState.get().getAuthorizerV2
+    val hiveOp = HiveOperationType.valueOf(getHiveOperation(logicalPlan).name())
+    val (inputsHObjs, outputsHObjs) = getInputOutputHObjs(logicalPlan)
+
+    val hiveAuthzContext = getHiveAuthzContext(logicalPlan, logicalPlan.toString)
+    authorizer.checkPrivileges(hiveOp, inputsHObjs, outputsHObjs, hiveAuthzContext)
+  }
+
+  def getHiveOperation(logicalPlan: LogicalPlan): HiveOperation = {
+    logicalPlan match {
+      case Project(_, _) => HiveOperation.QUERY
+      case _ => HiveOperation.ALTERINDEX_PROPS // TODO add more types here
+    }
+  }
+
+  def getHiveAuthzContext(logicalPlan: LogicalPlan, command: String): HiveAuthzContext = {
+    val authzContextBuilder = new HiveAuthzContext.Builder()
+    authzContextBuilder.setUserIpAddress(SessionState.get().getUserIpAddress)
+    authzContextBuilder.setCommandString(command)
+    authzContextBuilder.build()
+  }
+
+  def getInputOutputHObjs(logicalPlan: LogicalPlan): (JList[HivePrivilegeObject],
+      JList[HivePrivilegeObject]) = {
+    val inputObjs = new JArrayList[HivePrivilegeObject]
+    val outputObjs = new JArrayList[HivePrivilegeObject]
+    getInputOutputHObjsHelper(inputObjs, outputObjs, null, logicalPlan)
+    (inputObjs, outputObjs)
+  }
+
+  def getInputOutputHObjsHelper(
+      inputObjs: JList[HivePrivilegeObject],
+      outputObjs: JList[HivePrivilegeObject],
+      hivePrivilegeObjectType: HivePrivilegeObjectType,
+      logicalPlan: LogicalPlan): Unit = {
+    logicalPlan match {
+      case Project(projectionList, child) => buildHivePrivilegeObject(
+        HivePrivilegeObjectType.TABLE_OR_VIEW, projectionList, inputObjs, child)
+      case _ => Nil
+    }
+  }
+
+  private def buildHivePrivilegeObject(
+      hivePrivilegeObjectType: HivePrivilegeObjectType,
+      projectionList: Seq[Expression],
+      hivePriObjs: JList[HivePrivilegeObject], logicalPlan: LogicalPlan): Unit = {
+    logicalPlan match {
+      case Filter(_, child) => buildHivePrivilegeObject(hivePrivilegeObjectType, projectionList,
+        hivePriObjs, child)
+      case MetastoreRelation(dbName, tblName, _) =>
+        hivePriObjs.add(new HivePrivilegeObject(hivePrivilegeObjectType, dbName, tblName, null))
+      case _ => Nil
+    }
+  }
+
   override def getConf(key: String, defaultValue: String): String = {
     conf.get(key, defaultValue)
   }