diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index af6def52d07d1..dc8d97594c7a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -60,6 +60,13 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((key @ SetCommand.VariableName(name), Some(value))) => + val runFunc = (sparkSession: SparkSession) => { + sparkSession.conf.set(name, value) + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sparkSession: SparkSession) => { @@ -117,6 +124,10 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } +object SetCommand { + val VariableName = """hivevar:([^=]+)""".r +} + /** * This command is for resetting SQLConf to the default values. Command that runs * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala index 50725a09c42b3..791a9cf813b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/VariableSubstitution.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.internal -import java.util.regex.Pattern - import org.apache.spark.internal.config._ -import org.apache.spark.sql.AnalysisException /** * A helper class that enables substitution using syntax like @@ -37,6 +34,7 @@ class VariableSubstitution(conf: SQLConf) { private val reader = new ConfigReader(provider) .bind("spark", provider) .bind("sparkconf", provider) + .bind("hivevar", provider) .bind("hiveconf", provider) /** @@ -49,5 +47,4 @@ class VariableSubstitution(conf: SQLConf) { input } } - } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 5dafec1c3021b..0c79b6f4211ff 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -38,7 +38,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} import org.apache.spark.util.ShutdownHookManager /** @@ -291,6 +291,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { throw new RuntimeException("Remote operations not supported") } + override def setHiveVariables(hiveVariables: java.util.Map[String, String]): Unit = { + hiveVariables.asScala.foreach(kv => SparkSQLEnv.sqlContext.conf.setConfString(kv._1, kv._2)) + } + override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveVariableSubstitutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveVariableSubstitutionSuite.scala new file mode 100644 index 0000000000000..84d3946ca5c6f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveVariableSubstitutionSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveVariableSubstitutionSuite extends QueryTest with TestHiveSingleton { + test("SET hivevar with prefix") { + spark.sql("SET hivevar:county=gram") + assert(spark.conf.getOption("county") === Some("gram")) + } + + test("SET hivevar with dotted name") { + spark.sql("SET hivevar:eloquent.mosquito.alphabet=zip") + assert(spark.conf.getOption("eloquent.mosquito.alphabet") === Some("zip")) + } + + test("hivevar substitution") { + spark.conf.set("pond", "bus") + checkAnswer(spark.sql("SELECT '${hivevar:pond}'"), Row("bus") :: Nil) + } + + test("variable substitution without a prefix") { + spark.sql("SET hivevar:flask=plaid") + checkAnswer(spark.sql("SELECT '${flask}'"), Row("plaid") :: Nil) + } + + test("variable substitution precedence") { + spark.conf.set("turn.aloof", "questionable") + spark.sql("SET hivevar:turn.aloof=dime") + // hivevar clobbers the conf setting + checkAnswer(spark.sql("SELECT '${turn.aloof}'"), Row("dime") :: Nil) + } +}