19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1511,16 +1511,31 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
-   * The variable will be sent to each cluster only once.
+   * The variable will be sent to each executor only once.
*
* @param value value to broadcast to the Spark nodes
* @return `Broadcast` object, a read-only variable cached on each machine
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
broadcastInternal(value, serializedOnly = false)
}

/**
* Internal version of broadcast - broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each executor only once.
*
* @param value value to broadcast to the Spark nodes
* @param serializedOnly if true, do not cache the unserialized value on the driver
* @return `Broadcast` object, a read-only variable cached on each machine
*/
private[spark] def broadcastInternal[T: ClassTag](
value: T,
serializedOnly: Boolean): Broadcast[T] = {
assertNotStopped()
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
-    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal, serializedOnly)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
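For orientation, a minimal usage sketch of the two entry points (illustrative only: broadcastInternal is private[spark], so only Spark-internal code can call it, and the values below are made up):

// Public API: the deserialized value stays cached on the driver
// (a block-manager copy plus a SoftReference).
val publicBc = sc.broadcast(Seq(1, 2, 3))

// Internal API: with serializedOnly = true the driver keeps only the
// serialized blocks and a WeakReference, so the deserialized value can be
// garbage-collected once nothing on the driver still holds it.
val internalBc = sc.broadcastInternal(Seq(1, 2, 3), serializedOnly = true)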
core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -36,8 +36,14 @@ private[spark] trait BroadcastFactory {
* @param value value to broadcast
* @param isLocal whether we are in local mode (single JVM process)
* @param id unique id representing this broadcast variable
* @param serializedOnly if true, do not cache the unserialized value on the driver
* @return `Broadcast` object, a read-only variable cached on each machine
*/
-  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+  def newBroadcast[T: ClassTag](
+      value: T,
+      isLocal: Boolean,
+      id: Long,
+      serializedOnly: Boolean = false): Broadcast[T]

def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -60,7 +60,10 @@ private[spark] class BroadcastManager(
.asInstanceOf[java.util.Map[Any, Any]]
)

-  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
+  def newBroadcast[T: ClassTag](
+      value_ : T,
+      isLocal: Boolean,
+      serializedOnly: Boolean = false): Broadcast[T] = {
val bid = nextBroadcastId.getAndIncrement()
value_ match {
case pb: PythonBroadcast =>
@@ -72,7 +75,7 @@

case _ => // do nothing
}
-    broadcastFactory.newBroadcast[T](value_, isLocal, bid)
+    broadcastFactory.newBroadcast[T](value_, isLocal, bid, serializedOnly)
}

def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,7 +18,7 @@
package org.apache.spark.broadcast

import java.io._
-import java.lang.ref.SoftReference
+import java.lang.ref.{Reference, SoftReference, WeakReference}
import java.nio.ByteBuffer
import java.util.zip.Adler32

@@ -54,8 +54,9 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
*
* @param obj object to broadcast
* @param id A unique identifier for the broadcast variable.
* @param serializedOnly if true, do not cache the unserialized value on the driver
*/
-private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

/**
@@ -64,15 +65,17 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
*
* On the driver, if the value is required, it is read lazily from the block manager. We hold
* a soft reference so that it can be garbage collected if required, as we can always reconstruct
-   * in the future.
+   * in the future. For internal broadcast variables where `serializedOnly = true`, we hold a
+   * WeakReference to allow the value to be reclaimed more aggressively.
*/
-  @transient private var _value: SoftReference[T] = _
+  @transient private var _value: Reference[T] = _

/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _

/** Is the execution in local mode. */
@transient private var isLocalMaster: Boolean = _

/** Whether to generate checksum for blocks or not. */
private var checksumEnabled: Boolean = false
@@ -86,6 +89,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
// Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024
checksumEnabled = conf.get(config.BROADCAST_CHECKSUM)
isLocalMaster = Utils.isLocalMaster(conf)
}
setConf(SparkEnv.get.conf)

@@ -103,7 +107,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
memoized
} else {
val newlyRead = readBroadcastBlock()
-      _value = new SoftReference[T](newlyRead)
+      _value = if (serializedOnly) {
+        new WeakReference[T](newlyRead)
+      } else {
+        new SoftReference[T](newlyRead)
+      }
newlyRead
}
}
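The switch between reference types is the heart of the change. A minimal, self-contained sketch of the JVM semantics being relied on (illustrative only; System.gc() is a hint, so this shows tendencies rather than guarantees):

import java.lang.ref.{SoftReference, WeakReference}

var payload: Array[Byte] = new Array[Byte](1 << 20)
val soft = new SoftReference(payload)
val weak = new WeakReference(payload)
payload = null // drop the only strong reference
System.gc()
// After a collection, weak.get() is typically null, because weak referents
// are reclaimed as soon as they become weakly reachable; soft.get() usually
// still returns the array, because soft referents are only cleared under
// memory pressure. Hence a WeakReference lets the driver release internal
// broadcast values aggressively, while a SoftReference keeps user values
// around as a cache.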
@@ -129,11 +137,23 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
*/
private def writeBlocks(value: T): Int = {
import StorageLevel._
-    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
-    // do not create a duplicate copy of the broadcast variable's value.
     val blockManager = SparkEnv.get.blockManager
-    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
-      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
-    }
+    if (serializedOnly && !isLocalMaster) {
+      // SPARK-39983: When creating a broadcast variable internal to Spark (such as a broadcasted
+      // hashed relation), don't store the broadcasted value in the driver's block manager:
+      // we do not expect internal broadcast variables' values to be read on the driver, so
+      // skipping the store reduces driver memory pressure because we don't add a long-lived
+      // reference to the broadcasted object. However, this optimization cannot be applied for
+      // local mode (since tasks might run on the driver). To guard against performance
+      // regressions if an internal broadcast is accessed on the driver, we store a weak
+      // reference to the broadcasted value:
+      _value = new WeakReference[T](value)
+    } else {
+      // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+      // do not create a duplicate copy of the broadcast variable's value.
+      if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
+        throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+      }
+    }
try {
val blocks =
@@ -258,11 +278,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
try {
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
-          // Store the merged copy in BlockManager so other tasks on this executor don't
-          // need to re-fetch it.
-          val storageLevel = StorageLevel.MEMORY_AND_DISK
-          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
-          }
+
+          if (!serializedOnly || isLocalMaster || Utils.isInRunningSparkTask) {
+            // Store the merged copy in BlockManager so other tasks on this executor don't
+            // need to re-fetch it.
+            val storageLevel = StorageLevel.MEMORY_AND_DISK
+            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+            }
+          }

if (obj != null) {
Expand Down Expand Up @@ -297,6 +320,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
}

// Whether the unserialized value is cached. Exposed for testing.
private[spark] def hasCachedValue: Boolean = {
TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
setConf(SparkEnv.get.conf)
val blockManager = SparkEnv.get.blockManager
blockManager.getLocalValues(broadcastId) match {
case Some(blockResult) if (blockResult.data.hasNext) =>
val x = blockResult.data.next().asInstanceOf[T]
releaseBlockManagerLock(broadcastId)
x != null
case _ => false
}
}
}
}
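Taken together, writeBlocks and readBroadcastBlock reduce the caching decision to one predicate. A hypothetical condensation for readers (the helper name shouldCacheDeserialized is ours, not part of the patch):

// Whether the deserialized value gets stored in the local block manager:
// always for regular broadcasts; for serializedOnly broadcasts only in
// local mode or when called from a task running on an executor.
def shouldCacheDeserialized(
    serializedOnly: Boolean,
    isLocalMaster: Boolean,
    inRunningTask: Boolean): Boolean =
  !serializedOnly || isLocalMaster || inRunningTask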


core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -30,8 +30,12 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {

override def initialize(isDriver: Boolean, conf: SparkConf): Unit = { }

-  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
-    new TorrentBroadcast[T](value_, id)
+  override def newBroadcast[T: ClassTag](
+      value_ : T,
+      isLocal: Boolean,
+      id: Long,
+      serializedOnly: Boolean = false): Broadcast[T] = {
+    new TorrentBroadcast[T](value_, id, serializedOnly)
}

override def stop(): Unit = { }
core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -187,6 +187,25 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite {
assert(instances.size === 1)
}

test("SPARK-39983 - Broadcasted value not cached on driver") {
    // Use a distributed cluster, since in local mode the broadcast value is actually cached.
val conf = new SparkConf()
.setMaster("local-cluster[2,1,1024]")
.setAppName("test")
sc = new SparkContext(conf)

sc.broadcastInternal(value = 1234, serializedOnly = false) match {
case tb: TorrentBroadcast[Int] =>
assert(tb.hasCachedValue)
assert(1234 === tb.value)
}
sc.broadcastInternal(value = 1234, serializedOnly = true) match {
case tb: TorrentBroadcast[Int] =>
assert(!tb.hasCachedValue)
assert(1234 === tb.value)
}
}

/**
* Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
*
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -166,8 +166,8 @@ case class BroadcastExchangeExec(
val beforeBroadcast = System.nanoTime()
longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

-        // Broadcast the relation
-        val broadcasted = sparkContext.broadcast(relation)
+        // SPARK-39983 - Broadcast the relation without caching the unserialized object.
+        val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
longMetric("broadcastTime") += NANOSECONDS.toMillis(
System.nanoTime() - beforeBroadcast)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.execution

import java.util.concurrent.{CountDownLatch, TimeUnit}

-import org.apache.spark.SparkException
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.broadcast.TorrentBroadcast
import org.apache.spark.scheduler._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.HashedRelation
@@ -94,3 +96,28 @@ class BroadcastExchangeSuite extends SparkPlanTest
}
}
}

// Additional tests run in 'local-cluster' mode.
class BroadcastExchangeExecSparkSuite
extends SparkFunSuite with LocalSparkContext with AdaptiveSparkPlanHelper {

test("SPARK-39983 - Broadcasted relation is not cached on the driver") {
    // Use a distributed cluster, since in local mode the broadcast value is actually cached.
val conf = new SparkConf()
.setMaster("local-cluster[2,1,1024]")
.setAppName("test")
sc = new SparkContext(conf)
val spark = new SparkSession(sc)

val df = spark.range(1).toDF()
val joinDF = df.join(broadcast(df), "id")
val broadcastExchangeExec = collect(
joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }
assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")

// The broadcasted relation should not be cached on the driver.
val broadcasted =
broadcastExchangeExec(0).relationFuture.get().asInstanceOf[TorrentBroadcast[Any]]
assert(!broadcasted.hasCachedValue)
}
}