Commit d2bbdbe

cancel timeout broadcast execution

1 parent 8b0bdaa · commit d2bbdbe

2 files changed: +158 -62 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
77 additions, 62 deletions

@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import java.util.concurrent.TimeoutException
+import java.util.UUID
+import java.util.concurrent._
 
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration._
 import scala.util.control.NonFatal
 

@@ -43,6 +44,8 @@ case class BroadcastExchangeExec(
     mode: BroadcastMode,
     child: SparkPlan) extends Exchange {
 
+  private val runId: UUID = UUID.randomUUID
+
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
@@ -67,68 +70,74 @@ case class BroadcastExchangeExec(
 
   @transient
   private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
-        try {
-          val beforeCollect = System.nanoTime()
-          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-          val (numRows, input) = child.executeCollectIterator()
-          if (numRows >= 512000000) {
-            throw new SparkException(
-              s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
-          }
-
-          val beforeBuild = System.nanoTime()
-          longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
-
-          // Construct the relation.
-          val relation = mode.transform(input, Some(numRows))
-
-          val dataSize = relation match {
-            case map: HashedRelation =>
-              map.estimatedSize
-            case arr: Array[InternalRow] =>
-              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-            case _ =>
-              throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
-                relation.getClass.getName)
-          }
-
-          longMetric("dataSize") += dataSize
-          if (dataSize >= (8L << 30)) {
-            throw new SparkException(
-              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-          }
-
-          val beforeBroadcast = System.nanoTime()
-          longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
-
-          // Broadcast the relation
-          val broadcasted = sparkContext.broadcast(relation)
-          longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
-
-          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-          broadcasted
-        } catch {
-          // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
-          // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
-          // will catch this exception and re-throw the wrapped fatal throwable.
-          case oe: OutOfMemoryError =>
-            throw new SparkFatalException(
-              new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
-                s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
-                s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
-                s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
-              .initCause(oe.getCause))
-          case e if !NonFatal(e) =>
-            throw new SparkFatalException(e)
-        }
-      }
-    }(BroadcastExchangeExec.executionContext)
+    val task = new Callable[broadcast.Broadcast[Any]]() {
+      override def call(): broadcast.Broadcast[Any] = {
+        // This will run in another thread. Set the execution id so that we can connect these jobs
+        // with the correct execution.
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          try {
+            sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+              interruptOnCancel = true)
+            val beforeCollect = System.nanoTime()
+            // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+            val (numRows, input) = child.executeCollectIterator()
+            if (numRows >= 512000000) {
+              throw new SparkException(
+                s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
+            }
+
+            val beforeBuild = System.nanoTime()
+            longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
+
+            // Construct the relation.
+            val relation = mode.transform(input, Some(numRows))
+
+            val dataSize = relation match {
+              case map: HashedRelation =>
+                map.estimatedSize
+              case arr: Array[InternalRow] =>
+                arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+              case _ =>
+                throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
+                  s"type: ${relation.getClass.getName}")
+            }
+
+            longMetric("dataSize") += dataSize
+            if (dataSize >= (8L << 30)) {
+              throw new SparkException(
+                s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+            }
+
+            val beforeBroadcast = System.nanoTime()
+            longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
+
+            // Broadcast the relation
+            val broadcasted = sparkContext.broadcast(relation)
+            longMetric("broadcastTime") += NANOSECONDS.toMillis(
+              System.nanoTime() - beforeBroadcast)
+
+            SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+            broadcasted
+          } catch {
+            // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+            // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+            // will catch this exception and re-throw the wrapped fatal throwable.
+            case oe: OutOfMemoryError =>
+              throw new SparkFatalException(
+                new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
+                  "worker nodes. As a workaround, you can either disable broadcast by setting " +
+                  s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
+                  s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
+                .initCause(oe.getCause))
+            case e if !NonFatal(e) =>
+              throw new SparkFatalException(e)
+          }
+        }
+      }
+    }
+    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
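
Note: the build is now wrapped in a java.util.concurrent.Callable and submitted to the executor instead of being run as a scala.concurrent.Future, because the java.util.concurrent.Future returned by submit can be cancelled with cancel(true), interrupting the thread that is collecting and broadcasting; a plain Scala Future offers no such handle. This presumably requires BroadcastExchangeExec.executionContext to be an ExecutionContextExecutorService (for example one built with ExecutionContext.fromExecutorService), which is both a Scala ExecutionContext and a Java ExecutorService; that companion-object change is not visible in this hunk. A self-contained sketch of the pattern, with made-up names:

import java.util.concurrent.{Callable, Executors, TimeUnit, TimeoutException}

import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService}

// Illustration only; names are hypothetical and not part of this commit.
object CancellableTaskSketch {
  // Both a Scala ExecutionContext and a Java ExecutorService, so it accepts a
  // Callable and hands back a cancellable java.util.concurrent.Future.
  private val executionContext: ExecutionContextExecutorService =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  def main(args: Array[String]): Unit = {
    val task = new Callable[String]() {
      override def call(): String = {
        Thread.sleep(10000) // stand-in for collecting and building the relation
        "done"
      }
    }
    val future = executionContext.submit(task)
    try {
      println(future.get(1, TimeUnit.SECONDS))
    } catch {
      case _: TimeoutException =>
        future.cancel(true) // interrupts the worker thread, which a Scala Future cannot do
        println("timed out and cancelled")
    }
    executionContext.shutdown()
  }
}
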
@@ -143,14 +152,20 @@ case class BroadcastExchangeExec(
 
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
-      ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+      relationFuture.get(timeout.toSeconds, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
     } catch {
       case ex: TimeoutException =>
-        logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
         throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " +
           s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
           s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
           ex)
+      case NonFatal(ex) =>
+        throw new SparkException("Exception thrown in Future.get: ", ex)
+
     }
   }
 }
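
Note: the exception text above names the two user-facing knobs. A hedged usage sketch of both, assuming a local SparkSession (example code only, not part of this commit):

import org.apache.spark.sql.SparkSession

// Illustration of the configs named in the timeout message (hypothetical example code).
object BroadcastTimeoutConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("broadcast-conf").getOrCreate()

    // Give broadcast construction more time, in seconds (SQLConf.BROADCAST_TIMEOUT).
    spark.conf.set("spark.sql.broadcastTimeout", "600")

    // Or disable broadcast joins so the planner picks another join strategy
    // (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD).
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    val left = spark.range(0, 1000000).toDF("id")
    val right = spark.range(0, 100).toDF("id")
    // With the threshold at -1 this is planned without a BroadcastExchangeExec.
    println(left.join(right, "id").count())

    spark.stop()
  }
}
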
sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala (new file)
81 additions, 0 deletions

@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.spark.SparkException
+import org.apache.spark.scheduler._
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.joins.HashedRelation
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class BroadcastExchangeSuite extends SparkPlanTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("BroadcastExchange should cancel the job group if timeout") {
+    val startLatch = new CountDownLatch(1)
+    val endLatch = new CountDownLatch(1)
+    var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent]
+    spark.sparkContext.addSparkListener(new SparkListener {
+      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+        jobEvents +:= jobEnd
+        endLatch.countDown()
+      }
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        jobEvents +:= jobStart
+        startLatch.countDown()
+      }
+    })
+
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "0") {
+      val df = spark.range(100).join(spark.range(15).as[Long].map { x =>
+        Thread.sleep(5000)
+        x
+      }).where("id = value")
+
+      // get the exchange physical plan
+      val hashExchange = df.queryExecution.executedPlan
+        .collect { case p: BroadcastExchangeExec => p }.head
+
+      // materialize the future and wait for the job being scheduled
+      hashExchange.prepare()
+      startLatch.await(5, TimeUnit.SECONDS)
+
+      // check timeout exception is captured by just executing the exchange
+      val hashEx = intercept[SparkException] {
+        hashExchange.executeBroadcast[HashedRelation]()
+      }
+      assert(hashEx.getMessage.contains("Could not execute broadcast"))
+
+      // wait for cancel is posted and then check the results.
+      endLatch.await(5, TimeUnit.SECONDS)
+      assert(jobCancelled())
+    }
+
+    def jobCancelled(): Boolean = {
+      val events = jobEvents.toArray
+      val hasStart = events(1).isInstanceOf[SparkListenerJobStart]
+      val hasCancelled = events(0).asInstanceOf[SparkListenerJobEnd].jobResult
+        .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group")
+      events.length == 2 && hasStart && hasCancelled
+    }
+  }
+}
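
Note: the suite drives the timeout through the physical plan directly (prepare() then executeBroadcast()). Roughly the same behaviour can be reproduced from user code by broadcasting a slow-to-build side under a zero broadcast timeout; a hedged end-to-end sketch with made-up names and toy data:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Hypothetical end-to-end repro, not part of this commit.
object BroadcastTimeoutRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("broadcast-timeout").getOrCreate()
    import spark.implicits._

    // Zero timeout: doExecuteBroadcast gives up immediately and cancels the job group.
    spark.conf.set("spark.sql.broadcastTimeout", "0")

    val slowSide = spark.range(15).as[Long].map { x => Thread.sleep(5000); x }.toDF("value")
    val df = spark.range(100).toDF("id").join(broadcast(slowSide)).where("id = value")

    try df.collect()
    catch {
      // Expected: "Could not execute broadcast in 0 secs. ..."
      case e: Exception => println(e.getMessage)
    }
    spark.stop()
  }
}
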
