Skip to content

Commit e17d8ec

Browse files
[SPARK-39983][CORE][SQL] Do not cache unserialized broadcast relations on the driver
### What changes were proposed in this pull request? This PR addresses the issue raised in https://issues.apache.org/jira/browse/SPARK-39983 - broadcast relations should not be cached on the driver, as they are not needed there and can cause significant memory pressure (in one case the relation was 60MB). The PR adds a new SparkContext.broadcastInternal method with a serializedOnly parameter, allowing the caller to specify that the broadcasted object should be stored only in serialized form. The current behavior is to also cache an unserialized form of the object. The PR changes the broadcast implementation in TorrentBroadcast to honor the serializedOnly flag and not store the unserialized value, unless the execution is in local mode (single process). In that case the broadcast cache is effectively shared between the driver and the executors, and thus the unserialized value needs to be cached to satisfy the executor side of the functionality. ### Why are the changes needed? Broadcast relations can be fairly large (one observed relation was 60MB) and are not needed in unserialized form on the driver. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a new unit test to BroadcastSuite verifying the low-level broadcast functionality with respect to the serializedOnly flag. Added a new unit test to BroadcastExchangeSuite verifying that broadcasted relations are not cached on the driver. Closes #37413 from alex-balikov/SPARK-39983-broadcast-no-cache. Lead-authored-by: Alex Balikov <[email protected]> Co-authored-by: Josh Rosen <[email protected]> Signed-off-by: Josh Rosen <[email protected]>
1 parent 25759a0 commit e17d8ec

File tree

8 files changed

+136
-25
lines changed

8 files changed

+136
-25
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,16 +1511,31 @@ class SparkContext(config: SparkConf) extends Logging {
15111511
/**
15121512
* Broadcast a read-only variable to the cluster, returning a
15131513
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
1514-
* The variable will be sent to each cluster only once.
1514+
* The variable will be sent to each executor only once.
15151515
*
15161516
* @param value value to broadcast to the Spark nodes
15171517
* @return `Broadcast` object, a read-only variable cached on each machine
15181518
*/
15191519
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
1520+
broadcastInternal(value, serializedOnly = false)
1521+
}
1522+
1523+
/**
1524+
* Internal version of broadcast - broadcast a read-only variable to the cluster, returning a
1525+
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
1526+
* The variable will be sent to each executor only once.
1527+
*
1528+
* @param value value to broadcast to the Spark nodes
1529+
* @param serializedOnly if true, do not cache the unserialized value on the driver
1530+
* @return `Broadcast` object, a read-only variable cached on each machine
1531+
*/
1532+
private[spark] def broadcastInternal[T: ClassTag](
1533+
value: T,
1534+
serializedOnly: Boolean): Broadcast[T] = {
15201535
assertNotStopped()
15211536
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
15221537
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
1523-
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
1538+
val bc = env.broadcastManager.newBroadcast[T](value, isLocal, serializedOnly)
15241539
val callSite = getCallSite
15251540
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
15261541
cleaner.foreach(_.registerBroadcastForCleanup(bc))

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,14 @@ private[spark] trait BroadcastFactory {
3636
* @param value value to broadcast
3737
* @param isLocal whether we are in local mode (single JVM process)
3838
* @param id unique id representing this broadcast variable
39+
* @param serializedOnly if true, do not cache the unserialized value on the driver
40+
* @return `Broadcast` object, a read-only variable cached on each machine
3941
*/
40-
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
42+
def newBroadcast[T: ClassTag](
43+
value: T,
44+
isLocal: Boolean,
45+
id: Long,
46+
serializedOnly: Boolean = false): Broadcast[T]
4147

4248
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
4349

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ private[spark] class BroadcastManager(
6060
.asInstanceOf[java.util.Map[Any, Any]]
6161
)
6262

63-
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
63+
def newBroadcast[T: ClassTag](
64+
value_ : T,
65+
isLocal: Boolean,
66+
serializedOnly: Boolean = false): Broadcast[T] = {
6467
val bid = nextBroadcastId.getAndIncrement()
6568
value_ match {
6669
case pb: PythonBroadcast =>
@@ -72,7 +75,7 @@ private[spark] class BroadcastManager(
7275

7376
case _ => // do nothing
7477
}
75-
broadcastFactory.newBroadcast[T](value_, isLocal, bid)
78+
broadcastFactory.newBroadcast[T](value_, isLocal, bid, serializedOnly)
7679
}
7780

7881
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.broadcast
1919

2020
import java.io._
21-
import java.lang.ref.SoftReference
21+
import java.lang.ref.{Reference, SoftReference, WeakReference}
2222
import java.nio.ByteBuffer
2323
import java.util.zip.Adler32
2424

@@ -54,8 +54,9 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
5454
*
5555
* @param obj object to broadcast
5656
* @param id A unique identifier for the broadcast variable.
57+
* @param serializedOnly if true, do not cache the unserialized value on the driver
5758
*/
58-
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
59+
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean)
5960
extends Broadcast[T](id) with Logging with Serializable {
6061

6162
/**
@@ -64,15 +65,17 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
6465
*
6566
* On the driver, if the value is required, it is read lazily from the block manager. We hold
6667
* a soft reference so that it can be garbage collected if required, as we can always reconstruct
67-
* in the future.
68+
* in the future. For internal broadcast variables where `serializedOnly = true`, we hold a
69+
* WeakReference to allow the value to be reclaimed more aggressively.
6870
*/
69-
@transient private var _value: SoftReference[T] = _
71+
@transient private var _value: Reference[T] = _
7072

7173
/** The compression codec to use, or None if compression is disabled */
7274
@transient private var compressionCodec: Option[CompressionCodec] = _
7375
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
7476
@transient private var blockSize: Int = _
75-
77+
/** Is the execution in local mode. */
78+
@transient private var isLocalMaster: Boolean = _
7679

7780
/** Whether to generate checksum for blocks or not. */
7881
private var checksumEnabled: Boolean = false
@@ -86,6 +89,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
8689
// Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
8790
blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024
8891
checksumEnabled = conf.get(config.BROADCAST_CHECKSUM)
92+
isLocalMaster = Utils.isLocalMaster(conf)
8993
}
9094
setConf(SparkEnv.get.conf)
9195

@@ -103,7 +107,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
103107
memoized
104108
} else {
105109
val newlyRead = readBroadcastBlock()
106-
_value = new SoftReference[T](newlyRead)
110+
_value = if (serializedOnly) {
111+
new WeakReference[T](newlyRead)
112+
} else {
113+
new SoftReference[T](newlyRead)
114+
}
107115
newlyRead
108116
}
109117
}
@@ -129,11 +137,23 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
129137
*/
130138
private def writeBlocks(value: T): Int = {
131139
import StorageLevel._
132-
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
133-
// do not create a duplicate copy of the broadcast variable's value.
134140
val blockManager = SparkEnv.get.blockManager
135-
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
136-
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
141+
if (serializedOnly && !isLocalMaster) {
142+
// SPARK-39983: When creating a broadcast variable internal to Spark (such as a broadcasted
143+
// hashed relation), don't store the broadcasted value in the driver's block manager:
144+
// we do not expect internal broadcast variables' values to be read on the driver, so
145+
// skipping the store reduces driver memory pressure because we don't add a long-lived
146+
// reference to the broadcasted object. However, this optimization cannot be applied for
147+
// local mode (since tasks might run on the driver). To guard against performance
148+
// regressions if an internal broadcast is accessed on the driver, we store a weak
149+
// reference to the broadcasted value:
150+
_value = new WeakReference[T](value)
151+
} else {
152+
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
153+
// do not create a duplicate copy of the broadcast variable's value.
154+
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
155+
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
156+
}
137157
}
138158
try {
139159
val blocks =
@@ -258,11 +278,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
258278
try {
259279
val obj = TorrentBroadcast.unBlockifyObject[T](
260280
blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
261-
// Store the merged copy in BlockManager so other tasks on this executor don't
262-
// need to re-fetch it.
263-
val storageLevel = StorageLevel.MEMORY_AND_DISK
264-
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
265-
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
281+
282+
if (!serializedOnly || isLocalMaster || Utils.isInRunningSparkTask) {
283+
// Store the merged copy in BlockManager so other tasks on this executor don't
284+
// need to re-fetch it.
285+
val storageLevel = StorageLevel.MEMORY_AND_DISK
286+
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
287+
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
288+
}
266289
}
267290

268291
if (obj != null) {
@@ -297,6 +320,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
297320
}
298321
}
299322

323+
// Whether the unserialized value is currently available from the local block manager.
// Exposed for testing.
324+
private[spark] def hasCachedValue: Boolean = {
325+
TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) {
326+
setConf(SparkEnv.get.conf)
327+
val blockManager = SparkEnv.get.blockManager
328+
blockManager.getLocalValues(broadcastId) match {
329+
case Some(blockResult) if (blockResult.data.hasNext) =>
330+
val x = blockResult.data.next().asInstanceOf[T]
331+
releaseBlockManagerLock(broadcastId)
332+
x != null
333+
case _ => false
334+
}
335+
}
336+
}
300337
}
301338

302339

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
3030

3131
override def initialize(isDriver: Boolean, conf: SparkConf): Unit = { }
3232

33-
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
34-
new TorrentBroadcast[T](value_, id)
33+
override def newBroadcast[T: ClassTag](
34+
value_ : T,
35+
isLocal: Boolean,
36+
id: Long,
37+
serializedOnly: Boolean = false): Broadcast[T] = {
38+
new TorrentBroadcast[T](value_, id, serializedOnly)
3539
}
3640

3741
override def stop(): Unit = { }

core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,25 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
187187
assert(instances.size === 1)
188188
}
189189

190+
test("SPARK-39983 - Broadcasted value not cached on driver") {
191+
// Use a distributed cluster because in local mode the broadcast value is actually cached.
192+
val conf = new SparkConf()
193+
.setMaster("local-cluster[2,1,1024]")
194+
.setAppName("test")
195+
sc = new SparkContext(conf)
196+
197+
sc.broadcastInternal(value = 1234, serializedOnly = false) match {
198+
case tb: TorrentBroadcast[Int] =>
199+
assert(tb.hasCachedValue)
200+
assert(1234 === tb.value)
201+
}
202+
sc.broadcastInternal(value = 1234, serializedOnly = true) match {
203+
case tb: TorrentBroadcast[Int] =>
204+
assert(!tb.hasCachedValue)
205+
assert(1234 === tb.value)
206+
}
207+
}
208+
190209
/**
191210
* Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
192211
*

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ case class BroadcastExchangeExec(
166166
val beforeBroadcast = System.nanoTime()
167167
longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
168168

169-
// Broadcast the relation
170-
val broadcasted = sparkContext.broadcast(relation)
169+
// SPARK-39983 - Broadcast the relation without caching the unserialized object.
170+
val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
171171
longMetric("broadcastTime") += NANOSECONDS.toMillis(
172172
System.nanoTime() - beforeBroadcast)
173173
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)

sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ package org.apache.spark.sql.execution
1919

2020
import java.util.concurrent.{CountDownLatch, TimeUnit}
2121

22-
import org.apache.spark.SparkException
22+
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
23+
import org.apache.spark.broadcast.TorrentBroadcast
2324
import org.apache.spark.scheduler._
25+
import org.apache.spark.sql.SparkSession
2426
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
2527
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
2628
import org.apache.spark.sql.execution.joins.HashedRelation
@@ -94,3 +96,28 @@ class BroadcastExchangeSuite extends SparkPlanTest
9496
}
9597
}
9698
}
99+
100+
// Additional tests run in 'local-cluster' mode.
101+
class BroadcastExchangeExecSparkSuite
102+
extends SparkFunSuite with LocalSparkContext with AdaptiveSparkPlanHelper {
103+
104+
test("SPARK-39983 - Broadcasted relation is not cached on the driver") {
105+
// Use a distributed cluster because in local mode the broadcast value is actually cached.
106+
val conf = new SparkConf()
107+
.setMaster("local-cluster[2,1,1024]")
108+
.setAppName("test")
109+
sc = new SparkContext(conf)
110+
val spark = new SparkSession(sc)
111+
112+
val df = spark.range(1).toDF()
113+
val joinDF = df.join(broadcast(df), "id")
114+
val broadcastExchangeExec = collect(
115+
joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }
116+
assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")
117+
118+
// The broadcasted relation should not be cached on the driver.
119+
val broadcasted =
120+
broadcastExchangeExec(0).relationFuture.get().asInstanceOf[TorrentBroadcast[Any]]
121+
assert(!broadcasted.hasCachedValue)
122+
}
123+
}

0 commit comments

Comments
 (0)