
Commit ba46703

support accessing SQLConf at executor side

1 parent 000e25a · commit ba46703

File tree

9 files changed: +210 -36 lines changed

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 2 additions & 0 deletions

@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(
 
   private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
 
+  // TODO: shall we publish it and define it in `TaskContext`?
+  private[spark] def getLocalProperties(): Properties = localProperties
 }
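The new `getLocalProperties()` accessor exposes a task's local properties wholesale; the rest of the commit uses those properties as the channel for shipping SQL configs from the driver to executors. Below is a minimal sketch of that channel using only the public per-key API (`setLocalProperty`/`getLocalProperty`); the property key and master URL are illustrative, not part of the commit.

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyRoundTrip {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("local-property-demo"))
    // Local properties set on the driver thread travel with every task it submits.
    sc.setLocalProperty("spark.sql.x", "a")
    val seen = sc.parallelize(1 to 4, numSlices = 2).map { _ =>
      // Inside a task, the same property is visible through the TaskContext.
      TaskContext.get().getLocalProperty("spark.sql.x")
    }.collect()
    assert(seen.forall(_ == "a"))
    sc.stop()
  }
}
```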
New file: org.apache.spark.sql.internal.ReadOnlySQLConf

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.{Map => JMap}
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
+
+/**
+ * A readonly SQLConf that will be created by tasks running at the executor side. It reads the
+ * configs from the local properties which are propagated from driver to executors.
+ */
+class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
+
+  @transient override val settings: JMap[String, String] = {
+    context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+  }
+
+  @transient override protected val reader: ConfigReader = {
+    new ConfigReader(new TaskContextConfigProvider(context))
+  }
+
+  override protected def setConfWithCheck(key: String, value: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(key: String): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def unsetConf(entry: ConfigEntry[_]): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def clear(): Unit = {
+    throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+  }
+
+  override def clone(): SQLConf = {
+    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+  }
+
+  override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
+    throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+  }
+}
+
+class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
+  override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
+}
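As the scaladoc above states, tasks get a read-only view of `SQLConf` backed by their local properties, and every mutating method throws. A sketch of what that contract looks like from inside a task, assuming the `SQLConf.get` dispatch added in SQLConf.scala below (the config key and app name are made up):

```scala
import scala.util.Try

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{ReadOnlySQLConf, SQLConf}

object ReadOnlyConfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("read-only-conf").getOrCreate()
    import spark.implicits._

    val checks = spark.range(4).mapPartitions { _ =>
      val conf = SQLConf.get
      // Inside a task, SQLConf.get yields a ReadOnlySQLConf and writes are rejected.
      Iterator(conf.isInstanceOf[ReadOnlySQLConf] && Try(conf.setConfString("k", "v")).isFailure)
    }.collect()
    assert(checks.forall(_ == true))
    spark.stop()
  }
}
```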

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 11 deletions

@@ -27,13 +27,12 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
-import org.apache.spark.util.Utils
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -107,7 +106,13 @@ object SQLConf {
   * run tests in parallel. At the time this feature was implemented, this was a no-op since we
   * run unit tests (that does not involve SparkSession) in serial order.
   */
-  def get: SQLConf = confGetter.get()()
+  def get: SQLConf = {
+    if (TaskContext.get != null) {
+      new ReadOnlySQLConf(TaskContext.get())
+    } else {
+      confGetter.get()()
+    }
+  }
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1292,17 +1297,11 @@
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
-  if (Utils.isTesting && SparkEnv.get != null) {
-    // assert that we're only accessing it on the driver.
-    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
-      "SQLConf should only be created and accessed on the driver.")
-  }
-
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())
 
-  @transient private val reader = new ConfigReader(settings)
+  @transient protected val reader = new ConfigReader(settings)
 
   /** ************************ Spark SQL Params/Hints ******************* */
 
@@ -1765,7 +1764,7 @@
     settings.containsKey(key)
   }
 
-  private def setConfWithCheck(key: String, value: String): Unit = {
+  protected def setConfWithCheck(key: String, value: String): Unit = {
     settings.put(key, value)
   }
 
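The dispatch added to `SQLConf.get` means the same call site now works on the driver (through the registered conf getter) and inside a task (through `ReadOnlySQLConf`), so shared code no longer has to thread a `SQLConf` through explicitly. A hedged sketch of such a call site (the helper object and method are hypothetical):

```scala
import org.apache.spark.sql.internal.SQLConf

object ConfAccess {
  // May run on the driver or inside a task: SQLConf.get returns the active session's
  // conf on the driver and a ReadOnlySQLConf backed by local properties on executors.
  def caseSensitiveAnalysis: Boolean = SQLConf.get.getConf(SQLConf.CASE_SENSITIVE)
}
```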
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 18 additions & 3 deletions

@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
+      assertOnDriver()
      // Get the session from current thread's active session.
      var session = activeThreadSession.get()
      if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@
   *
   * @since 2.2.0
   */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+  def getActiveSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(activeThreadSession.get)
+  }
 
  /**
   * Returns the default SparkSession that is returned by the builder.
   *
   * @since 2.2.0
   */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+  def getDefaultSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(defaultSession.get)
+  }
 
  /**
   * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@
     }
   }
 
+  private def assertOnDriver(): Unit = {
+    if (Utils.isTesting && TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException(
+        "SparkSession should only be created and accessed on the driver.")
+    }
+  }
+
  /**
   * Helper method to create an instance of `SessionState` based on `className` from conf.
   * The result is either `SessionState` or a Hive based `SessionState`.
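`assertOnDriver` is deliberately gated on `Utils.isTesting`, so the new check only fires when the `spark.testing` system property (or `SPARK_TESTING` environment variable) is set; production code paths are unchanged. A sketch of the behavior it guards against, assuming the test flag is set (object and app names are illustrative):

```scala
import scala.util.Try

import org.apache.spark.sql.SparkSession

object DriverOnlySessionCheck {
  def main(args: Array[String]): Unit = {
    // Run with -Dspark.testing=true so Utils.isTesting is true; otherwise the
    // assertOnDriver() check above is skipped and this sketch does not apply.
    val spark = SparkSession.builder().master("local[2]").appName("driver-only").getOrCreate()
    import spark.implicits._

    val jobFailed = Try {
      spark.range(1).map { _ =>
        // Touching the session registry inside a task now throws IllegalStateException,
        // which fails the task and hence the job.
        SparkSession.getDefaultSession.isDefined
      }.collect()
    }.isFailure
    assert(jobFailed)
    spark.stop()
  }
}
```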

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 38 additions & 12 deletions

@@ -68,16 +68,18 @@ object SQLExecution {
       // sparkContext.getCallSite() would first try to pick up any call site that was previously
       // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
       // streaming queries would give us call site like "run at <unknown>:0"
-      val callSite = sparkSession.sparkContext.getCallSite()
+      val callSite = sc.getCallSite()
 
-      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-      try {
-        body
-      } finally {
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-          executionId, System.currentTimeMillis()))
+      withSQLConfPropagated(sparkSession) {
+        sc.listenerBus.post(SparkListenerSQLExecutionStart(
+          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+        try {
+          body
+        } finally {
+          sc.listenerBus.post(SparkListenerSQLExecutionEnd(
+            executionId, System.currentTimeMillis()))
+        }
       }
     } finally {
       executionIdToQueryExecution.remove(executionId)
@@ -90,13 +92,37 @@
   * thread from the original one, this method can be used to connect the Spark jobs in this action
   * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
   */
-  def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+  def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+    val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    withSQLConfPropagated(sparkSession) {
+      try {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+        body
+      } finally {
+        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      }
+    }
+  }
+
+  def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
+    val sc = sparkSession.sparkContext
+    // Set all the specified SQL configs to local properties, so that they can be available at
+    // the executor side.
+    val allConfigs = sparkSession.sessionState.conf.getAllConfs
+    val originalLocalProps = allConfigs.collect {
+      case (key, value) if key.startsWith("spark") =>
+        val originalValue = sc.getLocalProperty(key)
+        sc.setLocalProperty(key, value)
+        (key, originalValue)
+    }
+
     try {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
       body
     } finally {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      for ((key, value) <- originalLocalProps) {
+        sc.setLocalProperty(key, value)
+      }
     }
   }
 }
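`withSQLConfPropagated` copies every `spark`-prefixed entry of the session's SQL conf into the scheduler's local properties for the duration of `body`, then restores whatever values those keys had before; that is what `ReadOnlySQLConf` reads back on the executor side. A usage sketch under those assumptions (the config choice and the RDD job are illustrative):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.internal.SQLConf

object ConfPropagationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("conf-propagation").getOrCreate()
    spark.conf.set(SQLConf.CASE_SENSITIVE.key, "true")

    val seenByTasks = SQLExecution.withSQLConfPropagated(spark) {
      // Jobs submitted inside the block see the session's SQL configs as local properties.
      spark.sparkContext.parallelize(1 to 2, numSlices = 2).map { _ =>
        TaskContext.get().getLocalProperty(SQLConf.CASE_SENSITIVE.key)
      }.collect()
    }
    assert(seenByTasks.forall(_ == "true"))
    spark.stop()
  }
}
```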

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion

@@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
        val beforeCollect = System.nanoTime()
        // Note that we use .executeCollect() because we don't want to convert data to Scala types
        val rows: Array[InternalRow] = child.executeCollect()

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala

Lines changed: 8 additions & 8 deletions

@@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource {
       CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
     }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
 
-    JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    SQLExecution.withSQLConfPropagated(json.sparkSession) {
+      JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+    }
   }
 
   private def createBaseDataset(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Dataset[String] = {
-    val paths = inputPaths.map(_.getPath.toString)
-    val textOptions = Map.empty[String, String] ++
-      parsedOptions.encoding.map("encoding" -> _) ++
-      parsedOptions.lineSeparator.map("lineSep" -> _)
-
     sparkSession.baseRelationToDataFrame(
       DataSource.apply(
         sparkSession,
-        paths = paths,
+        paths = inputPaths.map(_.getPath.toString),
        className = classOf[TextFileFormat].getName,
        options = parsedOptions.parameters
      ).resolveRelation(checkFilesExist = false))
@@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
      .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
      .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
 
-    JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    SQLExecution.withSQLConfPropagated(sparkSession) {
+      JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+    }
   }
 
   private def createBaseRdd(
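JSON schema inference runs as a distributed job, which is why `JsonInferSchema.infer` is now wrapped in `withSQLConfPropagated`: without it, settings such as `spark.sql.caseSensitive` could not influence how fields are merged on the executors. A user-level sketch of the fixed behavior, mirroring the new test at the end of this commit (the temp path is illustrative):

```scala
import java.nio.file.Files

import org.apache.spark.sql.SparkSession

object JsonCaseSensitivitySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("json-case").getOrCreate()
    spark.conf.set("spark.sql.caseSensitive", "true")

    val path = Files.createTempDirectory("json-case-demo").resolve("data").toString
    spark.range(10).selectExpr("id AS ID").write.json(path)
    spark.range(10).write.mode("append").json(path)

    // With the conf propagated to executors, inference keeps both `id` and `ID`.
    assert(spark.read.json(path).columns.toSet == Set("id", "ID"))
    spark.stop()
  }
}
```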

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
     Future {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
+      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
        try {
          val beforeCollect = System.nanoTime()
          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
New file: org.apache.spark.sql.internal.ExecutorSideSQLConfSuite

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.test.SQLTestUtils
+
+class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
+  import testImplicits._
+
+  protected var spark: SparkSession = null
+
+  // Create a new [[SparkSession]] running in local-cluster mode.
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark = SparkSession.builder()
+      .master("local-cluster[2,1,1024]")
+      .appName("testing")
+      .getOrCreate()
+  }
+
+  override def afterAll(): Unit = {
+    spark.stop()
+    spark = null
+  }
+
+  test("ReadonlySQLConf is correctly created at the executor side") {
+    SQLConf.get.setConfString("spark.sql.x", "a")
+    try {
+      val checks = spark.range(10).mapPartitions { it =>
+        val conf = SQLConf.get
+        Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a")
+      }.collect()
+      assert(checks.forall(_ == true))
+    } finally {
+      SQLConf.get.unsetConf("spark.sql.x")
+    }
+  }
+
+  test("case-sensitive config should work for json schema inference") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      withTempPath { path =>
+        val pathString = path.getCanonicalPath
+        spark.range(10).select('id.as("ID")).write.json(pathString)
+        spark.range(10).write.mode("append").json(pathString)
+        assert(spark.read.json(pathString).columns.toSet == Set("id", "ID"))
+      }
+    }
+  }
+}
