@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
2424import scala .reflect .runtime .universe .TypeTag
2525import scala .util .control .NonFatal
2626
27- import org .apache .spark .{SPARK_VERSION , SparkConf , SparkContext }
27+ import org .apache .spark .{SPARK_VERSION , SparkConf , SparkContext , TaskContext }
2828import org .apache .spark .annotation .{DeveloperApi , Experimental , InterfaceStability }
2929import org .apache .spark .api .java .JavaRDD
3030import org .apache .spark .internal .Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
898898 * @since 2.0.0
899899 */
900900 def getOrCreate (): SparkSession = synchronized {
901+ assertOnDriver()
901902 // Get the session from current thread's active session.
902903 var session = activeThreadSession.get()
903904 if ((session ne null ) && ! session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
10221023 *
10231024 * @since 2.2.0
10241025 */
1025- def getActiveSession : Option [SparkSession ] = Option (activeThreadSession.get)
1026+ def getActiveSession : Option [SparkSession ] = {
1027+ assertOnDriver()
1028+ Option (activeThreadSession.get)
1029+ }
10261030
10271031 /**
10281032 * Returns the default SparkSession that is returned by the builder.
10291033 *
10301034 * @since 2.2.0
10311035 */
1032- def getDefaultSession : Option [SparkSession ] = Option (defaultSession.get)
1036+ def getDefaultSession : Option [SparkSession ] = {
1037+ assertOnDriver()
1038+ Option (defaultSession.get)
1039+ }
10331040
10341041 /**
10351042 * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
10621069 }
10631070 }
10641071
1072+ private def assertOnDriver (): Unit = {
1073+ if (Utils .isTesting && TaskContext .get != null ) {
1074+ // we're accessing it during task execution, fail.
1075+ throw new IllegalStateException (
1076+ " SparkSession should only be created and accessed on the driver." )
1077+ }
1078+ }
1079+
10651080 /**
10661081 * Helper method to create an instance of `SessionState` based on `className` from conf.
10671082 * The result is either `SessionState` or a Hive based `SessionState`.
0 commit comments