diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e41088f7c8f69..c9d2c8e62029b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -338,6 +338,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     override protected def initialValue(): Properties = new Properties()
   }
 
+  // Thread-local properties, like `localProperties` above, except *not* inherited by child threads
+  protected[spark] val uninheritableLocalProperties = new ThreadLocal[Properties] {
+    override protected def initialValue(): Properties = new Properties()
+  }
+
   /* ------------------------------------------------------------------------------------- *
    | Initialization. This code initializes the context in a manner that is exception-safe. |
    | All internal fields holding state are initialized here, and any error prompts the     |
@@ -595,32 +600,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     }
   }
 
-  private[spark] def getLocalProperties: Properties = localProperties.get()
+  private[spark] def getLocalProperties: (Properties, Properties) =
+    (localProperties.get(), uninheritableLocalProperties.get())
 
-  private[spark] def setLocalProperties(props: Properties) {
+  private[spark] def setLocalProperties(bothProps: (Properties, Properties)) {
+    val props = bothProps._1
+    val uninheritableProps = bothProps._2
     localProperties.set(props)
+    uninheritableLocalProperties.set(uninheritableProps)
   }
 
+  private def setProperty(props: Properties, key: String, value: String): Unit =
+    if (value == null) {
+      props.remove(key)
+    } else {
+      props.setProperty(key, value)
+    }
+
   /**
    * Set a local property that affects jobs submitted from this thread, such as the Spark fair
    * scheduler pool. User-defined properties may also be set here. These properties are propagated
    * through to worker tasks and can be accessed there via
    * [[org.apache.spark.TaskContext#getLocalProperty]].
+   *
+   * These properties are inherited by child threads spawned from this thread. This may
+   * have unexpected consequences when working with thread pools: the standard Java thread
+   * pool implementations let worker threads spawn other worker threads, so local
+   * properties can propagate unpredictably.
    */
-  def setLocalProperty(key: String, value: String) {
-    if (value == null) {
-      localProperties.get.remove(key)
-    } else {
-      localProperties.get.setProperty(key, value)
-    }
-  }
+  def setLocalProperty(key: String, value: String): Unit =
+    setProperty(localProperties.get, key, value)
+
+  /**
+   * Set a local property that affects jobs submitted from this thread, such as the Spark fair
+   * scheduler pool. User-defined properties may also be set here. These properties are propagated
+   * through to worker tasks and can be accessed there via
+   * [[org.apache.spark.TaskContext#getLocalProperty]].
+   *
+   * Properties set through this method will *not* be inherited by spawned threads.
+   */
+  def setUninheritableLocalProperty(key: String, value: String): Unit =
+    setProperty(uninheritableLocalProperties.get, key, value)
+
   /**
    * Get a local property set in this thread, or null if it is missing. See
    * [[org.apache.spark.SparkContext.setLocalProperty]].
    */
-  def getLocalProperty(key: String): String =
-    Option(localProperties.get).map(_.getProperty(key)).orNull
+  def getLocalProperty(key: String): String = {
+    lazy val inherited = Option(localProperties.get).map(_.getProperty(key)).orNull
+    Option(uninheritableLocalProperties.get.getProperty(key)).getOrElse(inherited)
+  }
 
   /** Set a human readable description of the current job. */
   def setJobDescription(value: String) {
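The behavioral split introduced above comes straight from the JDK: `localProperties` is an `InheritableThreadLocal`, so a thread spawned while a value is set starts from the parent's value, while the new plain `ThreadLocal` starts every thread from `initialValue()`. A minimal standalone sketch of the difference (plain Scala, no Spark required; all names here are illustrative):

import java.util.Properties

object ThreadLocalInheritanceDemo extends App {
  // Plain ThreadLocal: every thread, parent or child, starts from initialValue().
  val uninheritable = new ThreadLocal[Properties] {
    override protected def initialValue(): Properties = new Properties()
  }
  // InheritableThreadLocal: a thread spawned while the parent holds a value
  // starts from that value (by default, the very same object reference).
  val inheritable = new InheritableThreadLocal[Properties] {
    override protected def initialValue(): Properties = new Properties()
  }

  inheritable.get().setProperty("pool", "production")
  uninheritable.get().setProperty("executionId", "42")

  val child = new Thread(new Runnable {
    override def run(): Unit = {
      assert(inheritable.get().getProperty("pool") == "production")  // inherited
      assert(uninheritable.get().getProperty("executionId") == null) // not inherited
    }
  })
  child.start()
  child.join()
}

The inheritance snapshot is taken when a thread is constructed, which is exactly why pooled worker threads, created on demand while some earlier task's properties were set, can retain those properties for their entire lifetime. (In the Spark versions this patch targets, `localProperties` additionally overrides `childValue` to clone the parent's `Properties` per SPARK-10563, so parent and child do not share one live object.)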
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index dfd91ae338e89..fb6323413e3ea 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -712,8 +712,13 @@ class JavaSparkContext(val sc: SparkContext)
   }
 
   /**
-   * Set a local property that affects jobs submitted from this thread, such as the
-   * Spark fair scheduler pool.
+   * Set a local property that affects jobs submitted from this thread and its child
+   * threads, such as the Spark fair scheduler pool.
+   *
+   * These properties are inherited by child threads spawned from this thread. This may
+   * have unexpected consequences when working with thread pools: the standard Java thread
+   * pool implementations let worker threads spawn other worker threads, so local
+   * properties can propagate unpredictably.
    */
   def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 841fd02ae8bb6..e6e30f77e4b09 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -319,4 +319,32 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext {
       assert(sc.getConf.getInt("spark.executor.instances", 0) === 6)
     }
   }
+
+
+  test("localProperties are inherited by spawned threads.") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.setLocalProperty("testProperty", "testValue")
+    var result = "unset"
+    val thread = new Thread() { override def run() = { result = sc.getLocalProperty("testProperty") } }
+    thread.start()
+    thread.join()
+    sc.stop()
+    assert(result == "testValue")
+  }
+
+  test("localProperties do not cross-talk between threads.") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    var result = "unset"
+    val thread1 = new Thread() {
+      override def run() = { sc.setLocalProperty("testProperty", "testValue") } }
+    // testProperty was set on a sibling thread, not a parent, so it should be null here
+    val thread2 = new Thread() {
+      override def run() = { result = sc.getLocalProperty("testProperty") } }
+    thread1.start()
+    thread1.join()
+    thread2.start()
+    thread2.join()
+    sc.stop()
+    assert(result == null)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 0a11b16d0ed35..567c10b0d7d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -43,7 +43,7 @@ private[sql] object SQLExecution {
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     if (oldExecutionId == null) {
       val executionId = SQLExecution.nextExecutionId
-      sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+      sc.setUninheritableLocalProperty(EXECUTION_ID_KEY, executionId.toString)
       val r = try {
         val callSite = Utils.getCallSite()
         sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
@@ -56,7 +56,7 @@ private[sql] object SQLExecution {
             executionId, System.currentTimeMillis()))
         }
       } finally {
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
+        sc.setUninheritableLocalProperty(EXECUTION_ID_KEY, null)
       }
       r
     } else {
@@ -86,10 +86,10 @@ private[sql] object SQLExecution {
   def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     try {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+      sc.setUninheritableLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
       body
     } finally {
-      sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      sc.setUninheritableLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
    }
  }
}
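The SQLExecution change above is the heart of the fix: `spark.sql.execution.id` must be scoped to the thread that runs the query body, and letting it leak into threads spawned meanwhile leaves long-lived pool workers tagged with a finished query's id. A hedged repro sketch of that failure mode under the old inheritable behavior (the object, pool size, and literal id are illustrative, not part of the patch):

import java.util.concurrent.Executors
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}

import org.apache.spark.SparkContext

object StaleExecutionIdRepro {
  def run(sc: SparkContext): Unit = {
    implicit val ec: ExecutionContext =
      ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))

    // Simulate what withNewExecutionId used to do: set the id on the calling thread...
    sc.setLocalProperty("spark.sql.execution.id", "1")
    // ...then touch the pool. Its single worker is created lazily right now and
    // inherits a snapshot of the caller's local properties, id included.
    Await.result(Future(()), 10.seconds)
    // withNewExecutionId's `finally` clears the id, but only on the calling thread.
    sc.setLocalProperty("spark.sql.execution.id", null)

    // The long-lived worker still reports the id of a query that has finished.
    val stale = Await.result(
      Future(sc.getLocalProperty("spark.sql.execution.id")), 10.seconds)
    assert(stale == "1")
  }
}

With the patch, SQLExecution routes EXECUTION_ID_KEY through setUninheritableLocalProperty instead, so the pooled worker in this sketch would read null.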
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 928739a416f0f..0d86675764012 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -202,7 +202,8 @@ class StreamingContext private[streaming] (
 
   // Copy of thread-local properties from SparkContext. These properties will be set in all tasks
   // submitted by this StreamingContext after start.
-  private[streaming] val savedProperties = new AtomicReference[Properties](new Properties)
+  private[streaming] val savedProperties =
+    new AtomicReference[(Properties, Properties)]((new Properties, new Properties))
 
   private[streaming] def getStartSite(): CallSite = startSite.get()
 
@@ -580,7 +581,7 @@ class StreamingContext private[streaming] (
         sparkContext.clearJobGroup()
         sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
         savedProperties.set(SerializationUtils.clone(
-          sparkContext.localProperties.get()).asInstanceOf[Properties])
+          sparkContext.getLocalProperties).asInstanceOf[(Properties, Properties)])
         scheduler.start()
       }
       state = StreamingContextState.ACTIVE
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index ac18f73ea86aa..9e5181183ebb9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -220,7 +220,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
       val oldProps = ssc.sparkContext.getLocalProperties
       try {
         ssc.sparkContext.setLocalProperties(
-          SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties])
+          SerializationUtils.clone(
+            ssc.savedProperties.get()).asInstanceOf[(Properties, Properties)])
         val formattedTime = UIUtils.formatBatchTime(
           job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false)
         val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}"
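One gap worth noting: the new tests only pin down the pre-existing inheritable behavior. A natural companion test for the API this patch adds, sketched in the same style as the SparkContextSuite additions above (hypothetical, not part of the diff), would assert that uninheritable properties stay invisible to spawned threads:

  test("uninheritable localProperties are not inherited by spawned threads.") {
    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
    sc.setUninheritableLocalProperty("testProperty", "testValue")
    // Visible on the setting thread, via the uninheritable-first lookup...
    assert(sc.getLocalProperty("testProperty") == "testValue")
    var result = "unset"
    val thread = new Thread() {
      override def run() = { result = sc.getLocalProperty("testProperty") }
    }
    thread.start()
    thread.join()
    sc.stop()
    // ...but the plain ThreadLocal starts the child from an empty Properties.
    assert(result == null)
  }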