Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
override protected def initialValue(): Properties = new Properties()
}

// Thread Local variable that can be used by users to pass information down the stack
protected[spark] val uninheritableLocalProperties = new ThreadLocal[Properties] {
override protected def initialValue(): Properties = new Properties()
}

/* ------------------------------------------------------------------------------------- *
| Initialization. This code initializes the context in a manner that is exception-safe. |
| All internal fields holding state are initialized here, and any error prompts the |
Expand Down Expand Up @@ -595,32 +600,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
}

private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def getLocalProperties: (Properties, Properties) =
(localProperties.get(), uninheritableLocalProperties.get())

private[spark] def setLocalProperties(props: Properties) {
private[spark] def setLocalProperties(bothProps: (Properties, Properties)) {
val props = bothProps._1
val uninheritableProps = bothProps._2
localProperties.set(props)
uninheritableLocalProperties.set(uninheritableProps)
}

private def setProperty(props: Properties, key: String, value: String) =
if (value == null) {
props.remove(key)
} else {
props.setProperty(key, value)
}

/**
* Set a local property that affects jobs submitted from this thread, such as the Spark fair
* scheduler pool. User-defined properties may also be set here. These properties are propagated
* through to worker tasks and can be accessed there via
* [[org.apache.spark.TaskContext#getLocalProperty]].
*
* These properties are inherited by child threads spawned from this thread. This
* may have unexpected consequences when working with thread pools. The standard java
* implementation of thread pools have worker threads spawn other worker threads.
* As a result, local properties may propagate unpredictably.
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
localProperties.get.remove(key)
} else {
localProperties.get.setProperty(key, value)
}
}
def setLocalProperty(key: String, value: String): Unit =
setProperty(localProperties.get, key, value)

/**
* Set a local property that affects jobs submitted from this thread, such as the Spark fair
* scheduler pool. User-defined properties may also be set here. These properties are propagated
* through to worker tasks and can be accessed there via
* [[org.apache.spark.TaskContext#getLocalProperty]].
*
* Properties set through this method will *not* be inherited by spawned threads.
*/
def setUninheritableLocalProperty(key: String, value: String): Unit =
setProperty(uninheritableLocalProperties.get, key, value)


/**
* Get a local property set in this thread, or null if it is missing. See
* [[org.apache.spark.SparkContext.setLocalProperty]].
*/
def getLocalProperty(key: String): String =
Option(localProperties.get).map(_.getProperty(key)).orNull
def getLocalProperty(key: String): String = {
lazy val inheritableProperty = Option(localProperties.get).map(_.getProperty(key)).orNull
Option(uninheritableLocalProperties.get).map(_.getProperty(key)).getOrElse(inheritableProperty)
}

/** Set a human readable description of the current job. */
def setJobDescription(value: String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,13 @@ class JavaSparkContext(val sc: SparkContext)
}

/**
* Set a local property that affects jobs submitted from this thread, such as the
* Spark fair scheduler pool.
* Set a local property that affects jobs submitted from this thread, and all child
* threads, such as the Spark fair scheduler pool.
*
* These properties are inherited by child threads spawned from this thread. This
* may have unexpected consequences when working with thread pools. The standard java
* implementation of thread pools have worker threads spawn other worker threads.
* As a result, local properties may propagate unpredictably.
*/
def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value)

Expand Down
28 changes: 28 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkContextSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,32 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
assert(sc.getConf.getInt("spark.executor.instances", 0) === 6)
}
}


test("localProperties are inherited by spawned threads.") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
sc.setLocalProperty("testProperty", "testValue")
var result = "unset";
val thread = new Thread() { override def run() = {result = sc.getLocalProperty("testProperty")}}
thread.start()
thread.join()
sc.stop()
assert(result == "testValue")
}

test("localProperties do not cross-talk between threads.") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
var result = "unset";
val thread1 = new Thread() {
override def run() = {sc.setLocalProperty("testProperty", "testValue")}}
// testProperty should be unset and thus return null
val thread2 = new Thread() {
override def run() = {result = sc.getLocalProperty("testProperty")}}
thread1.start()
thread1.join()
thread2.start()
thread2.join()
sc.stop()
assert(result == null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private[sql] object SQLExecution {
val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
if (oldExecutionId == null) {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
sc.setUninheritableLocalProperty(EXECUTION_ID_KEY, executionId.toString)
val r = try {
val callSite = Utils.getCallSite()
sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
Expand All @@ -56,7 +56,7 @@ private[sql] object SQLExecution {
executionId, System.currentTimeMillis()))
}
} finally {
sc.setLocalProperty(EXECUTION_ID_KEY, null)
sc.setUninheritableLocalProperty(EXECUTION_ID_KEY, null)
}
r
} else {
Expand Down Expand Up @@ -86,10 +86,10 @@ private[sql] object SQLExecution {
def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
try {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
sc.setUninheritableLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
body
} finally {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
sc.setUninheritableLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ class StreamingContext private[streaming] (

// Copy of thread-local properties from SparkContext. These properties will be set in all tasks
// submitted by this StreamingContext after start.
private[streaming] val savedProperties = new AtomicReference[Properties](new Properties)
private[streaming] val savedProperties =
new AtomicReference[(Properties, Properties)]((new Properties, new Properties))

private[streaming] def getStartSite(): CallSite = startSite.get()

Expand Down Expand Up @@ -580,7 +581,7 @@ class StreamingContext private[streaming] (
sparkContext.clearJobGroup()
sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
savedProperties.set(SerializationUtils.clone(
sparkContext.localProperties.get()).asInstanceOf[Properties])
sparkContext.localProperties.get()).asInstanceOf[(Properties, Properties)])
scheduler.start()
}
state = StreamingContextState.ACTIVE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
val oldProps = ssc.sparkContext.getLocalProperties
try {
ssc.sparkContext.setLocalProperties(
SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties])
SerializationUtils.clone(
ssc.savedProperties.get()).asInstanceOf[(Properties, Properties)])
val formattedTime = UIUtils.formatBatchTime(
job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false)
val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}"
Expand Down