
Commit 763d1bc

update as per review
1 parent 11fffca commit 763d1bc

File tree

1 file changed (+70, -66 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 70 additions & 66 deletions
@@ -75,72 +75,76 @@ case class BroadcastExchangeExec(
   private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
-      try {
-        // Setup a job group here so later it may get cancelled by groupId if necessary.
-        sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
-          interruptOnCancel = true)
-        val beforeCollect = System.nanoTime()
-        // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-        val (numRows, input) = child.executeCollectIterator()
-        if (numRows >= 512000000) {
-          throw new SparkException(
-            s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
-        }
-
-        val beforeBuild = System.nanoTime()
-        longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
-
-        // Construct the relation.
-        val relation = mode.transform(input, Some(numRows))
-
-        val dataSize = relation match {
-          case map: HashedRelation =>
-            map.estimatedSize
-          case arr: Array[InternalRow] =>
-            arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-          case _ =>
-            throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
-              s"type: ${relation.getClass.getName}")
-        }
-
-        longMetric("dataSize") += dataSize
-        if (dataSize >= (8L << 30)) {
-          throw new SparkException(
-            s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
-        }
-
-        val beforeBroadcast = System.nanoTime()
-        longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
-
-        // Broadcast the relation
-        val broadcasted = sparkContext.broadcast(relation)
-        longMetric("broadcastTime") += NANOSECONDS.toMillis(
-          System.nanoTime() - beforeBroadcast)
-        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-        promise.success(broadcasted)
-        broadcasted
-      } catch {
-        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
-        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
-        // will catch this exception and re-throw the wrapped fatal throwable.
-        case oe: OutOfMemoryError =>
-          val ex = new SparkFatalException(
-            new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
-              "worker nodes. As a workaround, you can either disable broadcast by setting " +
-              s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
-              s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
-              .initCause(oe.getCause))
-          promise.failure(ex)
-          throw ex
-        case e if !NonFatal(e) =>
-          val ex = new SparkFatalException(e)
-          promise.failure(ex)
-          throw ex
-        case e: Throwable =>
-          promise.failure(e)
-          throw e
-      }
+      doBroadcast
+    }
+  }
+
+  private def doBroadcast = {
+    try {
+      // Setup a job group here so later it may get cancelled by groupId if necessary.
+      sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+        interruptOnCancel = true)
+      val beforeCollect = System.nanoTime()
+      // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+      val (numRows, input) = child.executeCollectIterator()
+      if (numRows >= 512000000) {
+        throw new SparkException(
+          s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
+      }
+
+      val beforeBuild = System.nanoTime()
+      longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
+
+      // Construct the relation.
+      val relation = mode.transform(input, Some(numRows))
+
+      val dataSize = relation match {
+        case map: HashedRelation =>
+          map.estimatedSize
+        case arr: Array[InternalRow] =>
+          arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        case _ =>
+          throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
+            s"type: ${relation.getClass.getName}")
+      }
+
+      longMetric("dataSize") += dataSize
+      if (dataSize >= (8L << 30)) {
+        throw new SparkException(
+          s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+      }
+
+      val beforeBroadcast = System.nanoTime()
+      longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
+
+      // Broadcast the relation
+      val broadcasted = sparkContext.broadcast(relation)
+      longMetric("broadcastTime") += NANOSECONDS.toMillis(
+        System.nanoTime() - beforeBroadcast)
+      val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+      SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+      promise.success(broadcasted)
+      broadcasted
+    } catch {
+      // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+      // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+      // will catch this exception and re-throw the wrapped fatal throwable.
+      case oe: OutOfMemoryError =>
+        val ex = new SparkFatalException(
+          new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
+            "worker nodes. As a workaround, you can either disable broadcast by setting " +
+            s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
+            s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
+            .initCause(oe.getCause))
+        promise.failure(ex)
+        throw ex
+      case e if !NonFatal(e) =>
+        val ex = new SparkFatalException(e)
+        promise.failure(ex)
+        throw ex
+      case e: Throwable =>
+        promise.failure(e)
+        throw e
     }
   }
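Net effect of the diff: a pure extract-method refactor. The try/catch body that previously lived inline in the `relationFuture` lambda moves, unchanged, into a new private `doBroadcast` method, leaving `withThreadLocalCaptured` to wrap a single call. A minimal sketch of the same pattern in isolation (hypothetical names such as `doWork`; this is not the Spark API):

import scala.concurrent.{ExecutionContext, Future}

object ExtractMethodSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // Before: the whole body lives inside the Future block.
  def before(): Future[Long] = Future {
    val start = System.nanoTime()
    // ... expensive work ...
    System.nanoTime() - start
  }

  // After: the block delegates to a named private method, as in the
  // commit above, keeping the future-producing wrapper thin.
  def after(): Future[Long] = Future { doWork() }

  private def doWork(): Long = {
    val start = System.nanoTime()
    // ... expensive work ...
    System.nanoTime() - start
  }
}

One side benefit of this shape is that the extracted method can be exercised directly in tests, without going through the future machinery.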
