Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.internal.SQLConf

/**
Expand All @@ -29,4 +30,32 @@ trait SQLConfHelper {
* See [[SQLConf.get]] for more information.
*/
def conf: SQLConf = SQLConf.get

/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
* configurations.
*/
protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change it to type parameter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it can return things, looks reasonable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, the PR description descirbed the goal, @beliefer

To make it easy to use such case: val x = withSQLConf {}, this pr also changes its return type.

val conf = SQLConf.get
val (keys, values) = pairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.getConfString(key))
} else {
None
}
}
keys.lazyZip(values).foreach { (k, v) =>
if (SQLConf.isStaticConfigKey(k)) {
throw new AnalysisException(s"Cannot modify the value of a static config: $k")
}
conf.setConfString(k, v)
}
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.setConfString(key, value)
case (key, None) => conf.unsetConf(key)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,13 @@ import scala.util.control.NonFatal

import org.scalatest.Assertions.fail

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.getZoneId
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

trait SQLHelper {

/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
* configurations.
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val conf = SQLConf.get
val (keys, values) = pairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.getConfString(key))
} else {
None
}
}
keys.lazyZip(values).foreach { (k, v) =>
if (SQLConf.isStaticConfigKey(k)) {
throw new AnalysisException(s"Cannot modify the value of a static config: $k")
}
conf.setConfString(k, v)
}
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.setConfString(key, value)
case (key, None) => conf.unsetConf(key)
}
}
}
trait SQLHelper extends SQLConfHelper {

/**
* Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
}
}

override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
pairs.foreach { case (k, v) =>
SQLConf.get.setConfString(k, v)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}

protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
protected override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte
}

// Make sure we set the config values to TestHive.conf.
override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit =
override def withSQLConf[T](pairs: (String, String)*)(f: => T): T =
SQLConf.withExistingConf(TestHive.conf)(super.withSQLConf(pairs: _*)(f))

test("Test the default fileformat for Hive-serde tables") {
Expand Down