From e6baa25db00e9ce83ec4ce50e92e91e3daac17ae Mon Sep 17 00:00:00 2001 From: Alex Balikov Date: Thu, 4 Aug 2022 18:04:36 -0700 Subject: [PATCH 1/5] SPARK-39983 - do not cache unserialized broadcast relations on the driver --- .../scala/org/apache/spark/SparkContext.scala | 19 ++++++- .../spark/broadcast/BroadcastFactory.scala | 8 ++- .../spark/broadcast/BroadcastManager.scala | 7 ++- .../spark/broadcast/TorrentBroadcast.scala | 53 +++++++++++++++---- .../broadcast/TorrentBroadcastFactory.scala | 8 ++- .../spark/broadcast/BroadcastSuite.scala | 19 +++++++ .../exchange/BroadcastExchangeExec.scala | 4 +- .../execution/BroadcastExchangeSuite.scala | 29 +++++++++- 8 files changed, 126 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6cb4f04ac7f74..f101dc8e083f4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1511,16 +1511,31 @@ class SparkContext(config: SparkConf) extends Logging { /** * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. - * The variable will be sent to each cluster only once. + * The variable will be sent to each executor only once. * * @param value value to broadcast to the Spark nodes * @return `Broadcast` object, a read-only variable cached on each machine */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + broadcastInternal(value, serializedOnly = false) + } + + /** + * Internal version of broadcast - broadcast a read-only variable to the cluster, returning a + * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. + * The variable will be sent to each executor only once. + * + * @param value value to broadcast to the Spark nodes + * @param serializedOnly if true, do not cache the unserialized value on the driver + * @return `Broadcast` object, a read-only variable cached on each machine + */ + private[spark] def broadcastInternal[T: ClassTag]( + value: T, + serializedOnly: Boolean): Broadcast[T] = { assertNotStopped() require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass), "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.") - val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + val bc = env.broadcastManager.newBroadcast[T](value, isLocal, serializedOnly) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 9891582501b8b..38d642753ad3a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -36,8 +36,14 @@ private[spark] trait BroadcastFactory { * @param value value to broadcast * @param isLocal whether we are in local mode (single JVM process) * @param id unique id representing this broadcast variable + * @param serializedOnly if true, do not cache the unserialized value on the driver + * @return `Broadcast` object, a read-only variable cached on each machine */ - def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T: ClassTag]( + value: T, + isLocal: Boolean, + id: Long, + serializedOnly: Boolean = false): Broadcast[T] def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index b6f59c36081f5..cd152709a1f37 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -60,7 +60,10 @@ private[spark] class BroadcastManager( .asInstanceOf[java.util.Map[Any, Any]] ) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { + def newBroadcast[T: ClassTag]( + value_ : T, + isLocal: Boolean, + serializedOnly: Boolean = false): Broadcast[T] = { val bid = nextBroadcastId.getAndIncrement() value_ match { case pb: PythonBroadcast => @@ -72,7 +75,7 @@ private[spark] class BroadcastManager( case _ => // do nothing } - broadcastFactory.newBroadcast[T](value_, isLocal, bid) + broadcastFactory.newBroadcast[T](value_, isLocal, bid, serializedOnly) } def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e35a079746a64..9ad220a35e31f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -54,8 +54,9 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * * @param obj object to broadcast * @param id A unique identifier for the broadcast variable. + * @param serializedOnly if true, do not cache the unserialized value on the driver */ -private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) +private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedOnly: Boolean) extends Broadcast[T](id) with Logging with Serializable { /** @@ -72,7 +73,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) @transient private var compressionCodec: Option[CompressionCodec] = _ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ @transient private var blockSize: Int = _ - + /** Is the execution in local mode. */ + @transient private var isLocalMaster: Boolean = _ /** Whether to generate checksum for blocks or not. */ private var checksumEnabled: Boolean = false @@ -86,6 +88,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024 checksumEnabled = conf.get(config.BROADCAST_CHECKSUM) + isLocalMaster = Utils.isLocalMaster(conf) } setConf(SparkEnv.get.conf) @@ -129,11 +132,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) */ private def writeBlocks(value: T): Int = { import StorageLevel._ - // Store a copy of the broadcast variable in the driver so that tasks run on the driver - // do not create a duplicate copy of the broadcast variable's value. val blockManager = SparkEnv.get.blockManager - if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + if (serializedOnly && !isLocalMaster) { + // SPARK-39983: When creating a broadcast variable internal to Spark (such as a broadcasted + // hashed relation), don't store the broadcasted value in the driver's block manager: + // we do not expect internal broadcast variables' values to be read on the driver, so + // skipping the store reduces driver memory pressure because we don't add a long-lived + // reference to the broadcasted object. However, this optimization cannot be applied for + // local mode (since tasks might run on the driver). To guard against performance + // regressions if an internal broadcast is accessed on the driver, we store a soft + // reference to the broadcasted value: + _value = new SoftReference[T](value) + } else { // Store a copy of the broadcast variable in the driver so that tasks run on the driver + // do not create a duplicate copy of the broadcast variable's value. + if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } } try { val blocks = @@ -258,11 +272,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) try { val obj = TorrentBroadcast.unBlockifyObject[T]( blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + + if (!serializedOnly || isLocalMaster || Utils.isInRunningSparkTask) { + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } } if (obj != null) { @@ -297,6 +314,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } } + // Is the unserialized value cached. Exposed for testing. + private[spark] def hasCachedValue: Boolean = { + TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) if (blockResult.data.hasNext) => + val x = blockResult.data.next().asInstanceOf[T] + releaseBlockManagerLock(broadcastId) + x != null + case _ => false + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 6846e1967c4d6..4ff39ba40742c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,8 +30,12 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf): Unit = { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { - new TorrentBroadcast[T](value_, id) + override def newBroadcast[T: ClassTag]( + value_ : T, + isLocal: Boolean, + id: Long, + serializedOnly: Boolean = false): Broadcast[T] = { + new TorrentBroadcast[T](value_, id, serializedOnly) } override def stop(): Unit = { } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 5e8b25f425166..41452076f888f 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -187,6 +187,25 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(instances.size === 1) } + test("SPARK-39983 - Broadcasted value not cached on driver") { + // Use distributed cluster as in local mode the broabcast value is actually cached. + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("test") + sc = new SparkContext(conf) + + sc.broadcastInternal(value = 1234, serializedOnly = false) match { + case tb: TorrentBroadcast[Int] => + assert(tb.hasCachedValue) + assert(1234 === tb.value) + } + sc.broadcastInternal(value = 1234, serializedOnly = true) match { + case tb: TorrentBroadcast[Int] => + assert(!tb.hasCachedValue) + assert(1234 === tb.value) + } + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index accd0a064ea63..77b1b30df98ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -166,8 +166,8 @@ case class BroadcastExchangeExec( val beforeBroadcast = System.nanoTime() longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild) - // Broadcast the relation - val broadcasted = sparkContext.broadcast(relation) + // SPARK-39983 - Broadcast the relation without caching the unserialized value. + val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) longMetric("broadcastTime") += NANOSECONDS.toMillis( System.nanoTime() - beforeBroadcast) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala index 7d6306b65ff47..129f76d7be38f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.execution import java.util.concurrent.{CountDownLatch, TimeUnit} -import org.apache.spark.SparkException +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.broadcast.TorrentBroadcast import org.apache.spark.scheduler._ +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.HashedRelation @@ -94,3 +96,28 @@ class BroadcastExchangeSuite extends SparkPlanTest } } } + +// Additional tests run in 'local-cluster' mode. +class BroadcastExchangeExecSparkSuite + extends SparkFunSuite with LocalSparkContext with AdaptiveSparkPlanHelper { + + test("SPARK-39983 - Broadcasted relation is not cached on the driver") { + // Use distributed cluster as in local mode the broabcast value is actually cached. + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("test") + sc = new SparkContext(conf) + val spark = new SparkSession(sc) + + val df = spark.range(1).toDF() + val joinDF = df.join(broadcast(df), "id") + val broadcastExchangeExec = collect( + joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p } + assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec") + + // The broadcasted relation should not be cached on the driver. + val broadcasted = + broadcastExchangeExec(0).relationFuture.get().asInstanceOf[TorrentBroadcast[Any]] + assert(!broadcasted.hasCachedValue) + } +} From 75ab18ee0e382b8117bf65fc9ef05190d4fdf01a Mon Sep 17 00:00:00 2001 From: Alex Balikov Date: Thu, 4 Aug 2022 18:32:42 -0700 Subject: [PATCH 2/5] comment fixes --- .../scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 3 ++- .../spark/sql/execution/exchange/BroadcastExchangeExec.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 9ad220a35e31f..68391227b8eae 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -143,7 +143,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO // regressions if an internal broadcast is accessed on the driver, we store a soft // reference to the broadcasted value: _value = new SoftReference[T](value) - } else { // Store a copy of the broadcast variable in the driver so that tasks run on the driver + } else { + // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { throw new SparkException(s"Failed to store $broadcastId in BlockManager") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 77b1b30df98ce..548a8628ba44d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -166,7 +166,7 @@ case class BroadcastExchangeExec( val beforeBroadcast = System.nanoTime() longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild) - // SPARK-39983 - Broadcast the relation without caching the unserialized value. + // SPARK-39983 - Broadcast the relation without caching the unserialized object. val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) longMetric("broadcastTime") += NANOSECONDS.toMillis( System.nanoTime() - beforeBroadcast) From cb2bfa003085a3e49453223797b8c04efbac6ff8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 9 Aug 2022 00:20:41 -0700 Subject: [PATCH 3/5] Store WeakReference for serializedOnly broadcast variables --- .../spark/broadcast/TorrentBroadcast.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 68391227b8eae..8f91f673aa9df 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,7 +18,7 @@ package org.apache.spark.broadcast import java.io._ -import java.lang.ref.SoftReference +import java.lang.ref.{Reference, SoftReference, WeakReference} import java.nio.ByteBuffer import java.util.zip.Adler32 @@ -65,9 +65,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO * * On the driver, if the value is required, it is read lazily from the block manager. We hold * a soft reference so that it can be garbage collected if required, as we can always reconstruct - * in the future. + * in the future. For internal broadcast variables where `serializedOnly = true`, we hold a + * WeakReference to allow the value to be reclaimed more aggressively. */ - @transient private var _value: SoftReference[T] = _ + @transient private var _value: Reference[T] = _ /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ @@ -106,7 +107,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO memoized } else { val newlyRead = readBroadcastBlock() - _value = new SoftReference[T](newlyRead) + _value = if (serializedOnly) { + new WeakReference[T](newlyRead) + } else { + new SoftReference[T](newlyRead) + } newlyRead } } @@ -140,9 +145,9 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long, serializedO // skipping the store reduces driver memory pressure because we don't add a long-lived // reference to the broadcasted object. However, this optimization cannot be applied for // local mode (since tasks might run on the driver). To guard against performance - // regressions if an internal broadcast is accessed on the driver, we store a soft + // regressions if an internal broadcast is accessed on the driver, we store a weak // reference to the broadcasted value: - _value = new SoftReference[T](value) + _value = new WeakReference[T](value) } else { // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. From 63db4ff98c28af37b98829cbe8352e8c1cbc2c8e Mon Sep 17 00:00:00 2001 From: Alex Balikov Date: Tue, 9 Aug 2022 21:46:17 -0700 Subject: [PATCH 4/5] Empty-Commit From 6a725a840ee7f56c633e5384a182c56f78aeeba1 Mon Sep 17 00:00:00 2001 From: Alex Balikov Date: Wed, 10 Aug 2022 13:10:21 -0700 Subject: [PATCH 5/5] Empty-Commit