diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index d44066f662b07..d42a66c47d2a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -62,10 +63,29 @@ case class BroadcastHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
 
+  private[this] val broadcastFutureInitLock = new Object()
+
+  // Use @volatile so we can read a snapshot of this without locking.
+  @volatile
+  private[this] var broadcastFutureValue: Option[Future[Broadcast[HashedRelation]]] = None
+
   // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
   // for the same query.
-  @transient
-  private lazy val broadcastFuture = {
+  private def broadcastFuture = {
+    broadcastFutureInitLock.synchronized {
+      if (broadcastFutureValue.isEmpty) {
+        broadcastFutureValue = Some(createBroadcastFuture)
+      }
+      broadcastFutureValue.get
+    }
+  }
+
+  /**
+   * Exposes the broadcast future so we can do external accounting of memory usage.
+   */
+  def broadcastFutureOpt: Option[Future[Broadcast[HashedRelation]]] = broadcastFutureValue
+
+  private def createBroadcastFuture: Future[Broadcast[HashedRelation]] = {
     val numBuildRows = buildSide match {
       case BuildLeft => longMetric("numLeftRows")
       case BuildRight => longMetric("numRightRows")
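
For context on the change above: the patch replaces a `@transient lazy val` with a hand-rolled lazy initializer (a @volatile Option guarded by a lock), because a built-in `lazy val` gives no way to observe whether initialization has happened without forcing it. `broadcastFutureOpt` needs exactly that non-forcing read for external memory accounting. Below is a minimal standalone sketch of the same pattern; the name `LazilyInitialized` and the `compute`/`peek` API are hypothetical illustrations, not part of the patch.

    // Sketch of the initialization pattern used in the diff: a @volatile field
    // guarded by a lock, so callers can peek at the cached value without
    // forcing (or blocking on) the computation.
    class LazilyInitialized[T](compute: () => T) {
      private[this] val initLock = new Object()

      // @volatile lets peek read a consistent snapshot without taking the lock.
      @volatile private[this] var value: Option[T] = None

      // Runs compute() at most once; later calls return the cached result.
      def get: T = initLock.synchronized {
        if (value.isEmpty) {
          value = Some(compute())
        }
        value.get
      }

      // Non-forcing read, analogous to broadcastFutureOpt: returns None if
      // get has never been called, without triggering the computation.
      def peek: Option[T] = value
    }

Note that `@transient` is dropped in the patch because a `def` (unlike a `lazy val`) carries no serialized state, so the annotation no longer applies.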