diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 38d7319b1f0ef..5078ef43adf68 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -83,6 +83,9 @@ class SparkContext(config: SparkConf) extends Logging {
   // The call site where this SparkContext was constructed.
   private val creationSite: CallSite = Utils.getCallSite()
 
+  // In order to prevent SparkContext from being created in executors.
+  SparkContext.assertOnDriver()
+
   // In order to prevent multiple SparkContexts from being active at the same time, mark this
   // context as having started construction.
   // NOTE: this must be placed at the beginning of the SparkContext constructor.
@@ -2554,6 +2557,19 @@ object SparkContext extends Logging {
     }
   }
 
+  /**
+   * Called to ensure that SparkContext is created or accessed only on the Driver.
+   *
+   * Throws an exception if a SparkContext is about to be created in executors.
+   */
+  private def assertOnDriver(): Unit = {
+    if (TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException(
+        "SparkContext should only be created and accessed on the driver.")
+    }
+  }
+
   /**
    * This function may be used to get or instantiate a SparkContext and register it as a
    * singleton object. Because we can only have one active SparkContext per JVM,
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 30237fd576830..d111bb33ce8ff 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -934,6 +934,18 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       }
     }
   }
+
+  test("SPARK-32160: Disallow to create SparkContext in executors") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
+
+    val error = intercept[SparkException] {
+      sc.range(0, 1).foreach { _ =>
+        new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+      }
+    }.getMessage()
+
+    assert(error.contains("SparkContext should only be created and accessed on the driver."))
+  }
 }
 
 object SparkContextSuite {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 32d69edb171db..6d58e1d14484c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -38,6 +38,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.resource.information import ResourceInformation
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
+from pyspark.taskcontext import TaskContext
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
 from pyspark.profiler import ProfilerCollector, BasicProfiler
@@ -118,6 +119,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ...
         ValueError:...
         """
+        # In order to prevent SparkContext from being created in executors.
+        SparkContext._assert_on_driver()
+
         self._callsite = first_spark_call() or CallSite(None, None, None)
         if gateway is not None and gateway.gateway_parameters.auth_token is None:
             raise ValueError(
@@ -1145,6 +1149,16 @@ def resources(self):
             resources[name] = ResourceInformation(name, addrs)
         return resources
 
+    @staticmethod
+    def _assert_on_driver():
+        """
+        Called to ensure that SparkContext is created only on the Driver.
+
+        Throws an exception if a SparkContext is about to be created in executors.
+        """
+        if TaskContext.get() is not None:
+            raise Exception("SparkContext should only be created and accessed on the driver.")
+
 
 def _test():
     import atexit
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 5833bf9f96fb3..168299e385e78 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -267,6 +267,14 @@ def test_resources(self):
         resources = sc.resources
         self.assertEqual(len(resources), 0)
 
+    def test_disallow_to_create_spark_context_in_executors(self):
+        # SPARK-32160: SparkContext should not be created in executors.
+        with SparkContext("local-cluster[3, 1, 1024]") as sc:
+            with self.assertRaises(Exception) as context:
+                sc.range(2).foreach(lambda _: SparkContext())
+            self.assertIn("SparkContext should only be created and accessed on the driver.",
+                          str(context.exception))
+
 
 class ContextTestsWithResources(unittest.TestCase):
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index ea1a9f12cd24b..9278eeeefe608 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -1087,7 +1087,7 @@ object SparkSession extends Logging {
   }
 
   private def assertOnDriver(): Unit = {
-    if (Utils.isTesting && TaskContext.get != null) {
+    if (TaskContext.get != null) {
       // we're accessing it during task execution, fail.
       throw new IllegalStateException(
         "SparkSession should only be created and accessed on the driver.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
index b29de9c4adbaa..98aba3ba25f17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -27,32 +27,29 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 
 class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext {
   private val random = new java.util.Random()
-  private var taskContext: TaskContext = _
-
-  override def afterAll(): Unit = try {
-    TaskContext.unset()
-  } finally {
-    super.afterAll()
-  }
 
   private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int)
                                (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
     sc = new SparkContext("local", "test", new SparkConf(false))
 
-    taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
     TaskContext.setTaskContext(taskContext)
 
-    val array = new ExternalAppendOnlyUnsafeRowArray(
-      taskContext.taskMemoryManager(),
-      SparkEnv.get.blockManager,
-      SparkEnv.get.serializerManager,
-      taskContext,
-      1024,
-      SparkEnv.get.memoryManager.pageSizeBytes,
-      inMemoryThreshold,
-      spillThreshold)
-
-    try f(array) finally {
-      array.clear()
+    try {
+      val array = new ExternalAppendOnlyUnsafeRowArray(
+        taskContext.taskMemoryManager(),
+        SparkEnv.get.blockManager,
+        SparkEnv.get.serializerManager,
+        taskContext,
+        1024,
+        SparkEnv.get.memoryManager.pageSizeBytes,
+        inMemoryThreshold,
+        spillThreshold)
+
+      try f(array) finally {
+        array.clear()
+      }
+    } finally {
+      TaskContext.unset()
     }
   }
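For reference, a minimal sketch (not part of the patch) of what the new guard means for user code, mirroring the Scala test above; the object name DriverOnlyExample and the app names are made up for illustration:

// Illustrative sketch only: with assertOnDriver() in place, constructing a
// SparkContext inside a task now fails on the executor, and the driver surfaces
// the failure as a SparkException whose message includes the assertion text.
import org.apache.spark.{SparkConf, SparkContext, SparkException}

object DriverOnlyExample {  // hypothetical example object, not part of Spark
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("driver-only-example").setMaster("local-cluster[2, 1, 1024]"))
    try {
      sc.range(0, 1).foreach { _ =>
        // Runs on an executor, where TaskContext.get is non-null, so this now throws.
        new SparkContext(new SparkConf().setAppName("nested").setMaster("local"))
      }
    } catch {
      case e: SparkException =>
        assert(e.getMessage.contains(
          "SparkContext should only be created and accessed on the driver."))
    } finally {
      sc.stop()
    }
  }
}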