Skip to content

Commit 04ae0fa

Browse files
committed
also check SparkSession
1 parent d3bef38 commit 04ae0fa

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
2424
import scala.reflect.runtime.universe.TypeTag
2525
import scala.util.control.NonFatal
2626

27-
import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
27+
import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
2828
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
2929
import org.apache.spark.api.java.JavaRDD
3030
import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
898898
* @since 2.0.0
899899
*/
900900
def getOrCreate(): SparkSession = synchronized {
901+
assertOnDriver()
901902
// Get the session from current thread's active session.
902903
var session = activeThreadSession.get()
903904
if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
10221023
*
10231024
* @since 2.2.0
10241025
*/
1025-
def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
1026+
def getActiveSession: Option[SparkSession] = {
1027+
assertOnDriver()
1028+
Option(activeThreadSession.get)
1029+
}
10261030

10271031
/**
10281032
* Returns the default SparkSession that is returned by the builder.
10291033
*
10301034
* @since 2.2.0
10311035
*/
1032-
def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
1036+
def getDefaultSession: Option[SparkSession] = {
1037+
assertOnDriver()
1038+
Option(defaultSession.get)
1039+
}
10331040

10341041
/**
10351042
* Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
10621069
}
10631070
}
10641071

1072+
private def assertOnDriver(): Unit = {
1073+
if (Utils.isTesting && TaskContext.get != null) {
1074+
// we're accessing it during task execution, fail.
1075+
throw new IllegalStateException(
1076+
"SparkSession should only be created and accessed on the driver.")
1077+
}
1078+
}
1079+
10651080
/**
10661081
* Helper method to create an instance of `SessionState` based on `className` from conf.
10671082
* The result is either `SessionState` or a Hive based `SessionState`.

0 commit comments

Comments
 (0)