diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index cbd49e070f2eb..26ead57316e18 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,6 +18,7 @@ package org.apache.spark.broadcast import java.io._ +import java.lang.ref.SoftReference import java.nio.ByteBuffer import java.util.zip.Adler32 @@ -61,9 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]], * which builds this value by reading blocks from the driver and/or other executors. * - * On the driver, if the value is required, it is read lazily from the block manager. + * On the driver, if the value is required, it is read lazily from the block manager. We hold + * a soft reference so that it can be garbage collected if required, as we can always reconstruct + * in the future. */ - @transient private lazy val _value: T = readBroadcastBlock() + @transient private var _value: SoftReference[T] = _ /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ @@ -92,8 +95,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** The checksum for all the blocks. */ private var checksums: Array[Int] = _ - override protected def getValue() = { - _value + override protected def getValue() = synchronized { + val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get + if (memoized != null) { + memoized + } else { + val newlyRead = readBroadcastBlock() + _value = new SoftReference[T](newlyRead) + newlyRead + } } private def calcChecksum(block: ByteBuffer): Int = { @@ -205,8 +215,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } private def readBroadcastBlock(): T = Utils.tryOrIOException { - TorrentBroadcast.synchronized { - val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + broadcastCache.synchronized { Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { setConf(SparkEnv.get.conf)