1717
1818package org .apache .spark .sql .execution .exchange
1919
20- import java .util .concurrent .TimeoutException
20+ import java .util .UUID
21+ import java .util .concurrent ._
2122
22- import scala .concurrent .{ ExecutionContext , Future }
23+ import scala .concurrent .ExecutionContext
2324import scala .concurrent .duration ._
2425import scala .util .control .NonFatal
2526
@@ -43,6 +44,8 @@ case class BroadcastExchangeExec(
4344 mode : BroadcastMode ,
4445 child : SparkPlan ) extends Exchange {
4546
47+ private val runId : UUID = UUID .randomUUID
48+
4649 override lazy val metrics = Map (
4750 " dataSize" -> SQLMetrics .createSizeMetric(sparkContext, " data size" ),
4851 " collectTime" -> SQLMetrics .createTimingMetric(sparkContext, " time to collect" ),
@@ -67,68 +70,74 @@ case class BroadcastExchangeExec(
6770
6871 @ transient
6972 private lazy val relationFuture : Future [broadcast.Broadcast [Any ]] = {
70- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
73+ // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
7174 val executionId = sparkContext.getLocalProperty(SQLExecution .EXECUTION_ID_KEY )
72- Future {
73- // This will run in another thread. Set the execution id so that we can connect these jobs
74- // with the correct execution.
75- SQLExecution .withExecutionId(sqlContext.sparkSession, executionId) {
76- try {
77- val beforeCollect = System .nanoTime()
78- // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
79- val (numRows, input) = child.executeCollectIterator()
80- if (numRows >= 512000000 ) {
81- throw new SparkException (
82- s " Cannot broadcast the table with 512 million or more rows: $numRows rows " )
83- }
84-
85- val beforeBuild = System .nanoTime()
86- longMetric(" collectTime" ) += NANOSECONDS .toMillis(beforeBuild - beforeCollect)
87-
88- // Construct the relation.
89- val relation = mode.transform(input, Some (numRows))
90-
91- val dataSize = relation match {
92- case map : HashedRelation =>
93- map.estimatedSize
94- case arr : Array [InternalRow ] =>
95- arr.map(_.asInstanceOf [UnsafeRow ].getSizeInBytes.toLong).sum
96- case _ =>
97- throw new SparkException (" [BUG] BroadcastMode.transform returned unexpected type: " +
98- relation.getClass.getName)
75+ val task = new Callable [broadcast.Broadcast [Any ]]() {
76+ override def call (): broadcast.Broadcast [Any ] = {
77+ // This will run in another thread. Set the execution id so that we can connect these jobs
78+ // with the correct execution.
79+ SQLExecution .withExecutionId(sqlContext.sparkSession, executionId) {
80+ try {
81+ sparkContext.setJobGroup(runId.toString, s " broadcast exchange (runId $runId) " ,
82+ interruptOnCancel = true )
83+ val beforeCollect = System .nanoTime()
84+ // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
85+ val (numRows, input) = child.executeCollectIterator()
86+ if (numRows >= 512000000 ) {
87+ throw new SparkException (
88+ s " Cannot broadcast the table with 512 million or more rows: $numRows rows " )
89+ }
90+
91+ val beforeBuild = System .nanoTime()
92+ longMetric(" collectTime" ) += NANOSECONDS .toMillis(beforeBuild - beforeCollect)
93+
94+ // Construct the relation.
95+ val relation = mode.transform(input, Some (numRows))
96+
97+ val dataSize = relation match {
98+ case map : HashedRelation =>
99+ map.estimatedSize
100+ case arr : Array [InternalRow ] =>
101+ arr.map(_.asInstanceOf [UnsafeRow ].getSizeInBytes.toLong).sum
102+ case _ =>
103+ throw new SparkException (" [BUG] BroadcastMode.transform returned unexpected " +
104+ s " type: ${relation.getClass.getName}" )
105+ }
106+
107+ longMetric(" dataSize" ) += dataSize
108+ if (dataSize >= (8L << 30 )) {
109+ throw new SparkException (
110+ s " Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30 } GB " )
111+ }
112+
113+ val beforeBroadcast = System .nanoTime()
114+ longMetric(" buildTime" ) += NANOSECONDS .toMillis(beforeBroadcast - beforeBuild)
115+
116+ // Broadcast the relation
117+ val broadcasted = sparkContext.broadcast(relation)
118+ longMetric(" broadcastTime" ) += NANOSECONDS .toMillis(
119+ System .nanoTime() - beforeBroadcast)
120+
121+ SQLMetrics .postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
122+ broadcasted
123+ } catch {
124+ // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
125+ // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
126+ // will catch this exception and re-throw the wrapped fatal throwable.
127+ case oe : OutOfMemoryError =>
128+ throw new SparkFatalException (
129+ new OutOfMemoryError (" Not enough memory to build and broadcast the table to all " +
130+ " worker nodes. As a workaround, you can either disable broadcast by setting " +
131+ s " ${SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key} to -1 or increase the spark " +
132+ s " driver memory by setting ${SparkLauncher .DRIVER_MEMORY } to a higher value. " )
133+ .initCause(oe.getCause))
134+ case e if ! NonFatal (e) =>
135+ throw new SparkFatalException (e)
99136 }
100-
101- longMetric(" dataSize" ) += dataSize
102- if (dataSize >= (8L << 30 )) {
103- throw new SparkException (
104- s " Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30 } GB " )
105- }
106-
107- val beforeBroadcast = System .nanoTime()
108- longMetric(" buildTime" ) += NANOSECONDS .toMillis(beforeBroadcast - beforeBuild)
109-
110- // Broadcast the relation
111- val broadcasted = sparkContext.broadcast(relation)
112- longMetric(" broadcastTime" ) += NANOSECONDS .toMillis(System .nanoTime() - beforeBroadcast)
113-
114- SQLMetrics .postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
115- broadcasted
116- } catch {
117- // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
118- // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
119- // will catch this exception and re-throw the wrapped fatal throwable.
120- case oe : OutOfMemoryError =>
121- throw new SparkFatalException (
122- new OutOfMemoryError (s " Not enough memory to build and broadcast the table to " +
123- s " all worker nodes. As a workaround, you can either disable broadcast by setting " +
124- s " ${SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key} to -1 or increase the spark driver " +
125- s " memory by setting ${SparkLauncher .DRIVER_MEMORY } to a higher value " )
126- .initCause(oe.getCause))
127- case e if ! NonFatal (e) =>
128- throw new SparkFatalException (e)
129137 }
130138 }
131- }(BroadcastExchangeExec .executionContext)
139+ }
140+ BroadcastExchangeExec .executionContext.submit[broadcast.Broadcast [Any ]](task)
132141 }
133142
134143 override protected def doPrepare (): Unit = {
@@ -143,14 +152,20 @@ case class BroadcastExchangeExec(
143152
144153 override protected [sql] def doExecuteBroadcast [T ](): broadcast.Broadcast [T ] = {
145154 try {
146- ThreadUtils .awaitResult(relationFuture, timeout ).asInstanceOf [broadcast.Broadcast [T ]]
155+ relationFuture.get(timeout.toSeconds, TimeUnit . SECONDS ).asInstanceOf [broadcast.Broadcast [T ]]
147156 } catch {
148157 case ex : TimeoutException =>
149- logError(s " Could not execute broadcast in ${timeout.toSeconds} secs. " , ex)
158+ if (! relationFuture.isDone) {
159+ sparkContext.cancelJobGroup(runId.toString)
160+ relationFuture.cancel(true )
161+ }
150162 throw new SparkException (s " Could not execute broadcast in ${timeout.toSeconds} secs. " +
151163 s " You can increase the timeout for broadcasts via ${SQLConf .BROADCAST_TIMEOUT .key} or " +
152164 s " disable broadcast join by setting ${SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key} to -1 " ,
153165 ex)
166+ case NonFatal (ex) =>
167+ throw new SparkException (" Exception thrown in Future.get: " , ex)
168+
154169 }
155170 }
156171}
0 commit comments