diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index ef5648c6dbe47..cb0d33f8aaf1a 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -160,7 +160,7 @@ statement
     | op=(ADD | LIST) identifier .*?                               #manageResource
     | SET ROLE .*?                                                 #failNativeCommand
     | SET .*?                                                      #setConfiguration
-    | RESET                                                        #resetConfiguration
+    | RESET .*?                                                    #resetConfiguration
     | unsupportedHiveNativeCommands .*?                            #failNativeCommand
     ;
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 3c58c6e1b6780..829620d7c033d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -82,11 +82,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    * Example SQL :
    * {{{
    *   RESET;
+   *   RESET key;
    * }}}
    */
   override def visitResetConfiguration(
       ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) {
-    ResetCommand
+    val raw = remainder(ctx.RESET.getSymbol).trim
+    if (raw.nonEmpty) ResetCommand(Some(raw)) else ResetCommand(None)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 5f12830ee621f..ce23f069232ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -150,12 +150,21 @@ object SetCommand {
  * This command is for resetting SQLConf to the default values. Command that runs
  * {{{
  *   reset;
+ *   reset key;
+ *   reset key1 key2 ...;
  * }}}
  */
-case object ResetCommand extends RunnableCommand with Logging {
+case class ResetCommand(key: Option[String]) extends RunnableCommand with Logging {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    sparkSession.sessionState.conf.clear()
+    key match {
+      case None =>
+        sparkSession.sessionState.conf.clear()
+      // "RESET key" clears the specific properties only.
+      case Some(keys) =>
+        keys.split("\\s+")
+          .foreach(confName => if (!confName.isEmpty) sparkSession.conf.unset(confName))
+    }
     Seq.empty[Row]
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index b32fb90e10072..e2f341f439ffa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -301,4 +301,10 @@ class SparkSqlParserSuite extends PlanTest {
       "SELECT a || b || c FROM t",
       Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t"))))
   }
+
+  test("reset") {
+    assertEqual("reset", ResetCommand(None))
+    assertEqual("reset spark.test.property", ResetCommand(Some("spark.test.property")))
+    assertEqual("reset #$a! !a$# \t ", ResetCommand(Some("#$a! !a$#")))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index a283ff971adcd..5067b8a9619ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -114,50 +114,79 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("reset - public conf") {
-    spark.sessionState.conf.clear()
-    val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL)
-    try {
-      assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true)
-      sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=false")
-      assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === false)
-      assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 1)
-      sql(s"reset")
-      assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true)
-      assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0)
-    } finally {
-      sql(s"set ${SQLConf.GROUP_BY_ORDINAL}=$original")
+  Seq("reset", s"reset ${SQLConf.GROUP_BY_ORDINAL.key}").foreach { resetCmd =>
+    test(s"$resetCmd - public conf") {
+      spark.sessionState.conf.clear()
+      val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL)
+      try {
+        assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true)
+        sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=false")
+        assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === false)
+        assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 1)
+        sql(resetCmd)
+        assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true)
+        assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0)
+      } finally {
+        sql(s"set ${SQLConf.GROUP_BY_ORDINAL}=$original")
+      }
     }
   }
 
-  test("reset - internal conf") {
-    spark.sessionState.conf.clear()
-    val original = spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS)
-    try {
-      assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100)
-      sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=10")
-      assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 10)
-      assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 1)
-      sql(s"reset")
-      assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100)
-      assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 0)
-    } finally {
-      sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS}=$original")
+  Seq("reset", s"reset ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}").foreach { resetCmd =>
+    test(s"$resetCmd - internal conf") {
+      spark.sessionState.conf.clear()
+      val original = spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS)
+      try {
+        assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100)
+        sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}=10")
+        assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 10)
+        assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 1)
+        sql(resetCmd)
+        assert(spark.conf.get(SQLConf.OPTIMIZER_MAX_ITERATIONS) === 100)
+        assert(sql(s"set").where(s"key = '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}'").count() == 0)
+      } finally {
+        sql(s"set ${SQLConf.OPTIMIZER_MAX_ITERATIONS}=$original")
+      }
     }
   }
 
-  test("reset - user-defined conf") {
-    spark.sessionState.conf.clear()
-    val userDefinedConf = "x.y.z.reset"
-    try {
-      assert(spark.conf.getOption(userDefinedConf).isEmpty)
-      sql(s"set $userDefinedConf=false")
-      assert(spark.conf.get(userDefinedConf) === "false")
-      assert(sql(s"set").where(s"key = '$userDefinedConf'").count() == 1)
-      sql(s"reset")
-      assert(spark.conf.getOption(userDefinedConf).isEmpty)
-    } finally {
-      spark.conf.unset(userDefinedConf)
+  Seq("reset", s"reset $testKey").foreach { resetCmd =>
+    test(s"$resetCmd - user-defined conf") {
+      spark.sessionState.conf.clear()
+      try {
+        assert(spark.conf.getOption(testKey).isEmpty)
+        sql(s"set $testKey=false")
+        assert(spark.conf.get(testKey) === "false")
+        assert(sql(s"set").where(s"key = '$testKey'").count() == 1)
+        sql(resetCmd)
+        assert(spark.conf.getOption(testKey).isEmpty)
+      } finally {
+        spark.conf.unset(testKey)
+      }
+    }
+  }
+
+  Seq("reset", s"reset ${testKey}1 \t ${testKey}2 \t ").foreach { resetCmd =>
+    test(s"$resetCmd - multiple conf") {
+      spark.sessionState.conf.clear()
+      val key1 = testKey + "1"
+      val key2 = testKey + "2"
+      try {
+        assert(spark.conf.getOption(key1).isEmpty)
+        assert(spark.conf.getOption(key2).isEmpty)
+        sql(s"set $key1=false")
+        sql(s"set $key2=true")
+        assert(spark.conf.get(key1) === "false")
+        assert(spark.conf.get(key2) === "true")
+        assert(sql(s"set").where(s"key = '$key1'").count() == 1)
+        assert(sql(s"set").where(s"key = '$key2'").count() == 1)
+        sql(resetCmd)
+        assert(spark.conf.getOption(key1).isEmpty)
+        assert(spark.conf.getOption(key2).isEmpty)
+      } finally {
+        spark.conf.unset(key1)
+        spark.conf.unset(key2)
+      }
     }
   }