Commit 03e90f6

[SPARK-24250][SQL] support accessing SQLConf inside tasks
Re-submits #21299, which broke the build. A few new commits are added to fix the SQLConf problem in `JsonInferSchema.infer`, and to prevent access to `SQLConf` from the DAGScheduler event loop thread.

## What changes were proposed in this pull request?

Previously, in #20136, we decided to forbid tasks from accessing `SQLConf`, because it doesn't work and always gives you the default conf value. In #21190 we fixed the check and all the places that violated it.

Currently the pattern for accessing configs at the executor side is: read the configs at the driver side, then reference the variables holding the config values in the RDD closure, so that they are serialized to the executor side. Something like
```
val someConf = conf.getXXX
child.execute().mapPartitions {
  if (someConf == ...) ...
  ...
}
```

However, this pattern is hard to apply if the config needs to be propagated through a long call stack. An example is `DataType.sameType`; see how many changes were made in #21190. Code generation is even worse: I tried it locally, and a ton of files would need to change to propagate configs to the code generators.

This PR proposes to allow tasks to access `SQLConf`. The idea is to save all the SQL configs to job properties when a SQL execution is triggered, and to rebuild the `SQLConf` from those job properties at the executor side.

## How was this patch tested?

A new test suite.

Author: Wenchen Fan <[email protected]>

Closes #21376 from cloud-fan/config.
1 parent a6e883f commit 03e90f6
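
To make the contrast with the old pattern concrete, here is a minimal, hypothetical sketch (not part of this patch) of what the change enables: a task calls `SQLConf.get` directly instead of closing over a value read on the driver. The `spark` session, the dataset, and the use of `caseSensitiveAnalysis` are illustrative assumptions only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object SQLConfInTaskSketch {
  def run(spark: SparkSession): Array[String] = {
    import spark.implicits._

    // No driver-side capture of the config value: the task rebuilds a read-only
    // SQLConf from the job local properties propagated by SQLExecution.
    Seq("a", "B", "c").toDS().mapPartitions { iter =>
      val caseSensitive = SQLConf.get.caseSensitiveAnalysis
      iter.map(s => if (caseSensitive) s else s.toLowerCase)
    }.collect()
  }
}
```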

File tree

12 files changed: +239 -43 lines changed


core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 2 additions & 0 deletions

@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(
 
   private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
+  // TODO: shall we publish it and define it in `TaskContext`?
+  private[spark] def getLocalProperties(): Properties = localProperties
 }

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion

@@ -206,7 +206,7 @@ class DAGScheduler(
   private val messageScheduler =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
 
-  private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+  private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
   /**

core/src/main/scala/org/apache/spark/util/EventLoop.scala

Lines changed: 2 additions & 1 deletion

@@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging {
 
   private val stopped = new AtomicBoolean(false)
 
-  private val eventThread = new Thread(name) {
+  // Exposed for testing.
+  private[spark] val eventThread = new Thread(name) {
     setDaemon(true)
 
     override def run(): Unit = {

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.{Map => JMap}
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
+
+/**
+ * A readonly SQLConf that will be created by tasks running at the executor side. It reads the
+ * configs from the local properties which are propagated from driver to executors.
+ */
+class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
+
+  @transient override val settings: JMap[String, String] = {
+    context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+  }
+
+  @transient override protected val reader: ConfigReader = {
+    new ConfigReader(new TaskContextConfigProvider(context))
+  }
+
+  override protected def setConfWithCheck(key: String, value: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(key: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(entry: ConfigEntry[_]): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def clear(): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def clone(): SQLConf = {
+    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+  }
+
+  override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
+    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+  }
+}
+
+class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
+  override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
+}
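
A hypothetical executor-side sketch (not from the patch) of how this class is meant to surface: once the driver has propagated the session's SQL configs as job local properties, `SQLConf.get` inside a running task is expected to return a `ReadOnlySQLConf`, reads resolve through `TaskContextConfigProvider`, and any mutation fails. The helper name and the config key used here are assumptions for illustration.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.internal.SQLConf

object ReadOnlySQLConfSketch {
  def readSessionTimeZoneInTask(): String = {
    require(TaskContext.get() != null, "must be called from inside a running task")
    val conf = SQLConf.get  // expected to be a ReadOnlySQLConf on the executor side
    // Reads resolve through TaskContextConfigProvider -> TaskContext.getLocalProperty.
    val tz = conf.getConfString("spark.sql.session.timeZone", "UTC")
    // conf.setConfString("spark.sql.session.timeZone", "GMT") would throw
    // UnsupportedOperationException, since ReadOnlySQLConf overrides setConfWithCheck.
    tz
  }
}
```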

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 22 additions & 11 deletions

@@ -27,7 +27,7 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
@@ -95,7 +95,9 @@ object SQLConf {
 
   /**
   * Returns the active config object within the current scope. If there is an active SparkSession,
-   * the proper SQLConf associated with the thread's session is used.
+   * the proper SQLConf associated with the thread's active session is used. If it's called from
+   * tasks in the executor side, a SQLConf will be created from job local properties, which are set
+   * and propagated from the driver side.
   *
   * The way this works is a little bit convoluted, due to the fact that config was added initially
   * only for physical plans (and as a result not in sql/catalyst module).
@@ -107,7 +109,22 @@ object SQLConf {
   * run tests in parallel. At the time this feature was implemented, this was a no-op since we
   * run unit tests (that does not involve SparkSession) in serial order.
   */
-  def get: SQLConf = confGetter.get()()
+  def get: SQLConf = {
+    if (TaskContext.get != null) {
+      new ReadOnlySQLConf(TaskContext.get())
+    } else {
+      if (Utils.isTesting && SparkContext.getActive.isDefined) {
+        // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter`
+        // will return `fallbackConf` which is unexpected. Here we prevent it from happening.
+        val schedulerEventLoopThread =
+          SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread
+        if (schedulerEventLoopThread.getId == Thread.currentThread().getId) {
+          throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.")
+        }
+      }
+      confGetter.get()()
+    }
+  }
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1292,17 +1309,11 @@
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
-  if (Utils.isTesting && SparkEnv.get != null) {
-    // assert that we're only accessing it on the driver.
-    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
-      "SQLConf should only be created and accessed on the driver.")
-  }
-
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
    new java.util.HashMap[String, String]())
 
-  @transient private val reader = new ConfigReader(settings)
+  @transient protected val reader = new ConfigReader(settings)
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
@@ -1765,7 +1776,7 @@ class SQLConf extends Serializable with Logging {
    settings.containsKey(key)
  }
 
-  private def setConfWithCheck(key: String, value: String): Unit = {
+  protected def setConfWithCheck(key: String, value: String): Unit = {
    settings.put(key, value)
  }
 
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 18 additions & 3 deletions

@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
+      assertOnDriver()
      // Get the session from current thread's active session.
      var session = activeThreadSession.get()
      if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
   *
   * @since 2.2.0
   */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+  def getActiveSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(activeThreadSession.get)
+  }
 
  /**
   * Returns the default SparkSession that is returned by the builder.
   *
   * @since 2.2.0
   */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+  def getDefaultSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(defaultSession.get)
+  }
 
  /**
   * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
    }
  }
 
+  private def assertOnDriver(): Unit = {
+    if (Utils.isTesting && TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException(
+        "SparkSession should only be created and accessed on the driver.")
+    }
+  }
+
  /**
   * Helper method to create an instance of `SessionState` based on `className` from conf.
   * The result is either `SessionState` or a Hive based `SessionState`.

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 42 additions & 12 deletions

@@ -68,16 +68,18 @@ object SQLExecution {
      // sparkContext.getCallSite() would first try to pick up any call site that was previously
      // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
      // streaming queries would give us call site like "run at <unknown>:0"
-      val callSite = sparkSession.sparkContext.getCallSite()
+      val callSite = sc.getCallSite()
 
-      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-      try {
-        body
-      } finally {
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-          executionId, System.currentTimeMillis()))
+      withSQLConfPropagated(sparkSession) {
+        sc.listenerBus.post(SparkListenerSQLExecutionStart(
+          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+        try {
+          body
+        } finally {
+          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
+            executionId, System.currentTimeMillis()))
+        }
      }
    } finally {
      executionIdToQueryExecution.remove(executionId)
@@ -90,13 +92,41 @@ object SQLExecution {
   * thread from the original one, this method can be used to connect the Spark jobs in this action
   * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
   */
-  def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+  def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+    val sc = sparkSession.sparkContext
    val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    withSQLConfPropagated(sparkSession) {
+      try {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+        body
+      } finally {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      }
+    }
+  }
+
+  /**
+   * Wrap an action with specified SQL configs. These configs will be propagated to the executor
+   * side via job local properties.
+   */
+  def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    // Set all the specified SQL configs to local properties, so that they can be available at
+    // the executor side.
+    val allConfigs = sparkSession.sessionState.conf.getAllConfs
+    val originalLocalProps = allConfigs.collect {
+      case (key, value) if key.startsWith("spark") =>
+        val originalValue = sc.getLocalProperty(key)
+        sc.setLocalProperty(key, value)
+        (key, originalValue)
+    }
+
    try {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
      body
    } finally {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      for ((key, value) <- originalLocalProps) {
+        sc.setLocalProperty(key, value)
+      }
    }
  }
 }
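
For context, a hedged usage sketch (not from the patch) of `withSQLConfPropagated`: wrapping a plain RDD action so the session's SQL configs reach the executors as job local properties, mirroring how the JSON data source wraps schema inference below. The session, the RDD, and the read of `numShufflePartitions` are assumptions for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SQLConf

object ConfPropagationSketch {
  def sumWithConf(spark: SparkSession): Double = {
    val rdd = spark.sparkContext.parallelize(1 to 100)
    // Configs set on the session are copied to job local properties for the
    // duration of the wrapped action, then restored afterwards.
    SQLExecution.withSQLConfPropagated(spark) {
      rdd.map { i =>
        // Tasks launched inside the wrapped action can read the propagated configs.
        val parts = SQLConf.get.numShufflePartitions
        i.toDouble / parts
      }.sum()
    }
  }
}
```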

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion

@@ -643,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
    Future {
      // This will run in another thread. Set the execution id so that we can connect these jobs
      // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
        val beforeCollect = System.nanoTime()
        // Note that we use .executeCollect() because we don't want to convert data to Scala types
        val rows: Array[InternalRow] = child.executeCollect()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 8 additions & 8 deletions

@@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource {
      CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
    }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
 
-    JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    SQLExecution.withSQLConfPropagated(json.sparkSession) {
+      JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    }
  }
 
  private def createBaseDataset(
      sparkSession: SparkSession,
      inputPaths: Seq[FileStatus],
      parsedOptions: JSONOptions): Dataset[String] = {
-    val paths = inputPaths.map(_.getPath.toString)
-    val textOptions = Map.empty[String, String] ++
-      parsedOptions.encoding.map("encoding" -> _) ++
-      parsedOptions.lineSeparator.map("lineSep" -> _)
-
    sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
-        paths = paths,
+        paths = inputPaths.map(_.getPath.toString),
        className = classOf[TextFileFormat].getName,
        options = parsedOptions.parameters
      ).resolveRelation(checkFilesExist = false))
@@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
      .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
      .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
-    JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    SQLExecution.withSQLConfPropagated(sparkSession) {
+      JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    }
  }
 
  private def createBaseRdd(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala

Lines changed: 10 additions & 5 deletions

@@ -45,8 +45,9 @@ private[sql] object JsonInferSchema {
    val parseMode = configOptions.parseMode
    val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
 
-    // perform schema inference on each row and merge afterwards
-    val rootType = json.mapPartitions { iter =>
+    // In each RDD partition, perform schema inference on each row and merge afterwards.
+    val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+    val mergedTypesFromPartitions = json.mapPartitions { iter =>
      val factory = new JsonFactory()
      configOptions.setJacksonOptions(factory)
      iter.flatMap { row =>
@@ -66,9 +67,13 @@ private[sql] object JsonInferSchema {
            s"Parse Mode: ${FailFastMode.name}.", e)
        }
      }
-      }
-    }.fold(StructType(Nil))(
-      compatibleRootType(columnNameOfCorruptRecord, parseMode))
+      }.reduceOption(typeMerger).toIterator
+    }
+
+    // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because
+    // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have
+    // active SparkSession and `SQLConf.get` may point to the wrong configs.
+    val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
 
    canonicalizeType(rootType) match {
      case Some(st: StructType) => st
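
The new comment in this hunk captures the key trick: merge within each partition on the executors, then fold the per-partition results on the driver through `toLocalIterator`, so the merge function never runs in the scheduler's event loop thread the way `RDD.fold`'s driver-side merge would. A generic, hypothetical sketch of that pattern (the object and method names, and the `ClassTag` bound, are assumptions, not part of the patch):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

object DriverSideFoldSketch {
  // Merge within each partition on the executors, then fold the per-partition
  // results on the driver via toLocalIterator, keeping the merge function out
  // of the DAGScheduler event loop thread.
  def mergeOnDriver[T: ClassTag](rdd: RDD[T], zero: T)(merge: (T, T) => T): T = {
    val perPartition: RDD[T] =
      rdd.mapPartitions(iter => iter.reduceOption(merge).iterator)
    perPartition.toLocalIterator.fold(zero)(merge)
  }
}
```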
