
Commit 000e25a

Revert "[SPARK-24250][SQL] support accessing SQLConf inside tasks"
This reverts commit dd37529.
1 parent dd37529 commit 000e25a

File tree

9 files changed, +36 -210 lines changed


core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 0 additions & 2 deletions
@@ -178,6 +178,4 @@ private[spark] class TaskContextImpl(
 
   private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
-  // TODO: shall we publish it and define it in `TaskContext`?
-  private[spark] def getLocalProperties(): Properties = localProperties
 }
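
For context, the removed getLocalProperties accessor exposed a task's local properties so that executor-side code could read values the driver had attached to the job. A minimal hedged sketch of that propagation path using only the public SparkContext/TaskContext API ("demo.prop" is an illustrative key, not a real Spark configuration):

// Minimal sketch, assuming a local-mode SparkContext; "demo.prop" is illustrative.
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("demo"))
    // A local property set on the driver thread is shipped with every task of
    // jobs submitted from that thread.
    sc.setLocalProperty("demo.prop", "some-value")
    val seen = sc.parallelize(1 to 4, numSlices = 2).map { _ =>
      // Inside a task, the same property is readable from the TaskContext.
      TaskContext.get().getLocalProperty("demo.prop")
    }.collect()
    println(seen.mkString(", "))
    sc.stop()
  }
}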

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala

Lines changed: 0 additions & 66 deletions
This file was deleted.
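
For readers without the original change: ReadOnlySQLConf was added by SPARK-24250 as a SQLConf variant that served reads from a task's local properties and rejected writes. A heavily simplified, hedged sketch of that idea (class and method shapes here are illustrative, not the deleted source):

// Hedged sketch of the deleted class's idea: read conf values from a task's
// local properties and refuse mutation. Not the actual deleted implementation.
import org.apache.spark.TaskContext

class ReadOnlyConfView(context: TaskContext) {
  def getConfString(key: String): String = context.getLocalProperty(key)

  def setConfString(key: String, value: String): Unit =
    throw new UnsupportedOperationException("Cannot mutate a read-only conf inside a task")
}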

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 10 deletions
@@ -27,12 +27,13 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
+import org.apache.spark.util.Utils
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -106,13 +107,7 @@ object SQLConf {
   * run tests in parallel. At the time this feature was implemented, this was a no-op since we
   * run unit tests (that does not involve SparkSession) in serial order.
   */
-  def get: SQLConf = {
-    if (TaskContext.get != null) {
-      new ReadOnlySQLConf(TaskContext.get())
-    } else {
-      confGetter.get()()
-    }
-  }
+  def get: SQLConf = confGetter.get()()
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1297,11 +1292,17 @@
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
+  if (Utils.isTesting && SparkEnv.get != null) {
+    // assert that we're only accessing it on the driver.
+    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
+      "SQLConf should only be created and accessed on the driver.")
+  }
+
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())
 
-  @transient protected val reader = new ConfigReader(settings)
+  @transient private val reader = new ConfigReader(settings)
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
@@ -1764,7 +1765,7 @@ class SQLConf extends Serializable with Logging {
     settings.containsKey(key)
   }
 
-  protected def setConfWithCheck(key: String, value: String): Unit = {
+  private def setConfWithCheck(key: String, value: String): Unit = {
     settings.put(key, value)
   }
 
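
After the revert, SQLConf.get always resolves through confGetter, and in testing mode the new assertion checks that the conf is only created on the driver. A hedged, driver-side usage sketch (assumes a SparkSession already exists; SHUFFLE_PARTITIONS is an existing entry defined in SQLConf):

// Sketch only: reading the active session's SQL conf on the driver.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object DriverSideSqlConf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("sqlconf-demo")
      .getOrCreate()

    // With a session active on this (driver) thread, SQLConf.get resolves to its conf.
    val shufflePartitions = SQLConf.get.getConf(SQLConf.SHUFFLE_PARTITIONS)
    println(s"spark.sql.shuffle.partitions = $shufflePartitions")

    spark.stop()
  }
}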

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 3 additions & 18 deletions
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -898,7 +898,6 @@ object SparkSession extends Logging {
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
-      assertOnDriver()
      // Get the session from current thread's active session.
      var session = activeThreadSession.get()
      if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1023,20 +1022,14 @@
   *
   * @since 2.2.0
   */
-  def getActiveSession: Option[SparkSession] = {
-    assertOnDriver()
-    Option(activeThreadSession.get)
-  }
+  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
 
   /**
   * Returns the default SparkSession that is returned by the builder.
   *
   * @since 2.2.0
   */
-  def getDefaultSession: Option[SparkSession] = {
-    assertOnDriver()
-    Option(defaultSession.get)
-  }
+  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
 
   /**
   * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1069,14 +1062,6 @@
     }
   }
 
-  private def assertOnDriver(): Unit = {
-    if (Utils.isTesting && TaskContext.get != null) {
-      // we're accessing it during task execution, fail.
-      throw new IllegalStateException(
-        "SparkSession should only be created and accessed on the driver.")
-    }
-  }
-
   /**
   * Helper method to create an instance of `SessionState` based on `className` from conf.
   * The result is either `SessionState` or a Hive based `SessionState`.
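
The removed assertOnDriver guard only fired under Utils.isTesting; with the revert, getActiveSession and getDefaultSession no longer check where they are called. A hedged sketch of the pattern the guard was meant to flag, i.e. reaching for the session from inside a task (on a real cluster no session is registered on executors; in local mode the outcome can differ because everything shares one JVM):

// Sketch only: SparkSession is a driver-side object; the removed assertOnDriver()
// existed to flag task-side accesses like this one during tests.
import org.apache.spark.sql.SparkSession

object SessionAccessInTask {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("session-in-task")
      .getOrCreate()

    val defined = spark.sparkContext.parallelize(1 to 2, 2).map { _ =>
      // On a real executor this is normally None, since no session was ever set there.
      SparkSession.getActiveSession.isDefined
    }.collect()

    println(defined.mkString(", "))
    spark.stop()
  }
}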

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 12 additions & 38 deletions
@@ -68,18 +68,16 @@ object SQLExecution {
       // sparkContext.getCallSite() would first try to pick up any call site that was previously
       // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
       // streaming queries would give us call site like "run at <unknown>:0"
-      val callSite = sc.getCallSite()
+      val callSite = sparkSession.sparkContext.getCallSite()
 
-      withSQLConfPropagated(sparkSession) {
-        sc.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-        try {
-          body
-        } finally {
-          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
-        }
+      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+      try {
+        body
+      } finally {
+        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+          executionId, System.currentTimeMillis()))
       }
     } finally {
       executionIdToQueryExecution.remove(executionId)
@@ -92,37 +90,13 @@
   * thread from the original one, this method can be used to connect the Spark jobs in this action
   * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
   */
-  def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
-    val sc = sparkSession.sparkContext
+  def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    withSQLConfPropagated(sparkSession) {
-      try {
-        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
-        body
-      } finally {
-        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
-      }
-    }
-  }
-
-  def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
-    val sc = sparkSession.sparkContext
-    // Set all the specified SQL configs to local properties, so that they can be available at
-    // the executor side.
-    val allConfigs = sparkSession.sessionState.conf.getAllConfs
-    val originalLocalProps = allConfigs.collect {
-      case (key, value) if key.startsWith("spark") =>
-        val originalValue = sc.getLocalProperty(key)
-        sc.setLocalProperty(key, value)
-        (key, originalValue)
-    }
-
     try {
+      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
       body
     } finally {
-      for ((key, value) <- originalLocalProps) {
-        sc.setLocalProperty(key, value)
-      }
+      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
 }
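
The reverted withExecutionId is a plain set-and-restore bracket around the execution-id local property. For illustration, the same pattern generalized to an arbitrary key (the helper and key below are illustrative, not Spark API):

// Sketch only: the set/restore-local-property pattern used by withExecutionId,
// generalized to an arbitrary key.
import org.apache.spark.SparkContext

object LocalPropertyBracket {
  def withLocalProperty[T](sc: SparkContext, key: String, value: String)(body: => T): T = {
    val oldValue = sc.getLocalProperty(key) // may be null if the key was never set
    try {
      sc.setLocalProperty(key, value)
      body
    } finally {
      // Restore the previous value so callers outside the bracket see a clean state.
      sc.setLocalProperty(key, oldValue)
    }
  }
}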

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+      SQLExecution.withExecutionId(sparkContext, executionId) {
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert data to Scala types
         val rows: Array[InternalRow] = child.executeCollect()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 8 additions & 8 deletions
@@ -34,7 +34,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -105,19 +104,22 @@ object TextInputJsonDataSource extends JsonDataSource {
       CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
     }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
 
-    SQLExecution.withSQLConfPropagated(json.sparkSession) {
-      JsonInferSchema.infer(rdd, parsedOptions, rowParser)
-    }
+    JsonInferSchema.infer(rdd, parsedOptions, rowParser)
   }
 
   private def createBaseDataset(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    val textOptions = Map.empty[String, String] ++
+      parsedOptions.encoding.map("encoding" -> _) ++
+      parsedOptions.lineSeparator.map("lineSep" -> _)
+
     sparkSession.baseRelationToDataFrame(
       DataSource.apply(
         sparkSession,
-        paths = inputPaths.map(_.getPath.toString),
+        paths = paths,
         className = classOf[TextFileFormat].getName,
         options = parsedOptions.parameters
       ).resolveRelation(checkFilesExist = false))
@@ -163,9 +165,7 @@ object MultiLineJsonDataSource extends JsonDataSource {
       .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
       .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
-    SQLExecution.withSQLConfPropagated(sparkSession) {
-      JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
-    }
+    JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
   }
 
   private def createBaseRdd(
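
The restored textOptions block maps the JSON-specific encoding and lineSeparator settings onto the text source's option names. A hedged sketch of the equivalent user-facing options on DataFrameReader (the input path is a placeholder):

// Sketch only: encoding / lineSep as exposed through the public JSON reader options.
// "/tmp/people.json" is a placeholder path.
import org.apache.spark.sql.SparkSession

object JsonReadOptionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("json-options")
      .getOrCreate()

    val df = spark.read
      .option("encoding", "UTF-8") // charset of the input files
      .option("lineSep", "\n")     // record separator for line-delimited JSON
      .json("/tmp/people.json")

    df.printSchema()
    spark.stop()
  }
}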

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+      SQLExecution.withExecutionId(sparkContext, executionId) {
         try {
           val beforeCollect = System.nanoTime()
           // Use executeCollect/executeCollectIterator to avoid conversion to Scala types

sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala

Lines changed: 0 additions & 66 deletions
This file was deleted.
