@@ -2019,7 +2019,10 @@ class SQLConf extends Serializable with Logging {

def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)

def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT)
def broadcastTimeout: Long = {
val timeoutValue = getConf(BROADCAST_TIMEOUT)
if (timeoutValue < 0) Long.MaxValue else timeoutValue
}

def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
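
An aside on the semantics of the change above (a sketch, not part of the diff): broadcastTimeout now folds the negative "no timeout" setting into Long.MaxValue instead of making each caller build a Duration.Inf. That is safe for the seconds-based Future.get contract used later in this PR, because java.util.concurrent.TimeUnit conversions saturate at Long.MaxValue rather than overflowing, so a get() with that timeout is effectively unbounded. A minimal sketch, with hypothetical stand-ins for the config value and the relation future:

    import java.util.concurrent.{Callable, Executors, TimeUnit}

    object BroadcastTimeoutSketch {
      def main(args: Array[String]): Unit = {
        val configured = -1L  // stand-in for a negative spark.sql.broadcastTimeout
        // Same normalization as the new getter: negative means "wait forever".
        val timeout = if (configured < 0) Long.MaxValue else configured

        val pool = Executors.newSingleThreadExecutor()
        val future = pool.submit(new Callable[Int] {
          override def call(): Int = { Thread.sleep(100); 42 }
        })
        // TimeUnit.SECONDS.toNanos(Long.MaxValue) saturates, so this never times out.
        println(future.get(timeout, TimeUnit.SECONDS))
        pool.shutdown()
      }
    }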

@@ -17,10 +17,11 @@

package org.apache.spark.sql.execution.exchange

import java.util.concurrent.TimeoutException
import java.util.UUID
import java.util.concurrent._

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, SparkException}
@@ -43,6 +44,8 @@ case class BroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan) extends Exchange {

private val runId: UUID = UUID.randomUUID

override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
@@ -56,79 +59,79 @@
}

@transient
private val timeout: Duration = {
val timeoutValue = sqlContext.conf.broadcastTimeout
if (timeoutValue < 0) {
Duration.Inf
} else {
timeoutValue.seconds
}
}
private val timeout: Long = SQLConf.get.broadcastTimeout

@transient
private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
// broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
try {
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
val (numRows, input) = child.executeCollectIterator()
if (numRows >= 512000000) {
throw new SparkException(
s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
}

val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

// Construct the relation.
val relation = mode.transform(input, Some(numRows))

val dataSize = relation match {
case map: HashedRelation =>
map.estimatedSize
case arr: Array[InternalRow] =>
arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
case _ =>
throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
relation.getClass.getName)
}

longMetric("dataSize") += dataSize
if (dataSize >= (8L << 30)) {
throw new SparkException(
s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
val task = new Callable[broadcast.Broadcast[Any]]() {
override def call(): broadcast.Broadcast[Any] = {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",

[Review thread on the setJobGroup call; a standalone sketch of this job-group cancellation pattern follows this file's diff.]

Contributor: Let's add a comment to explain why we set up a job group here. There is no other public API that can cancel a specific job, AFAIK.

yeshengm (Contributor, Oct 23, 2020): Just curious, why can't we just inherit the job group id of the outside thread, so that when the SQL statement is cancelled, these broadcast sub-jobs are cancelled as a whole?

Contributor: That sounds like a good idea. We should only set the job group if there is none outside.

Member: Wouldn't cancelling the broadcast job then cause the outer main job to be cancelled as well?

Member: "We should only set the job group if there is none outside" - and I guess it would be a partial fix?

Contributor: It's always better to have fewer configs if possible. And I don't think we can override the job group id here if the config is true, as it is used to cancel the broadcast after a timeout.

Contributor: @HyukjinKwon @jiangxb1987 @yeshengm What's your opinion of my idea?

Member: I am discussing multiple job group support, which would fundamentally fix all of these problems. This is actually a general problem that is not specific to the SQL broadcast here.

Contributor: @HyukjinKwon Could you please tell me where you are discussing this? I also want to make a small contribution.

HyukjinKwon (Member, Jul 15, 2021): Sorry, I am discussing it offline first. I will send out an email or JIRA soon for a more open discussion.

interruptOnCancel = true)
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
val (numRows, input) = child.executeCollectIterator()
if (numRows >= 512000000) {
throw new SparkException(
s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
}

val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

// Construct the relation.
val relation = mode.transform(input, Some(numRows))

val dataSize = relation match {
case map: HashedRelation =>
map.estimatedSize
case arr: Array[InternalRow] =>
arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
case _ =>
throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
s"type: ${relation.getClass.getName}")
}

longMetric("dataSize") += dataSize
if (dataSize >= (8L << 30)) {
throw new SparkException(
s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
}

val beforeBroadcast = System.nanoTime()
longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

// Broadcast the relation
val broadcasted = sparkContext.broadcast(relation)
longMetric("broadcastTime") += NANOSECONDS.toMillis(
System.nanoTime() - beforeBroadcast)

SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
// SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
// will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
throw new SparkFatalException(
new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
"worker nodes. As a workaround, you can either disable broadcast by setting " +
s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
.initCause(oe.getCause))
case e if !NonFatal(e) =>
throw new SparkFatalException(e)
}

val beforeBroadcast = System.nanoTime()
longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)

// Broadcast the relation
val broadcasted = sparkContext.broadcast(relation)
longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)

SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
// SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
// will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
throw new SparkFatalException(
new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
.initCause(oe.getCause))
case e if !NonFatal(e) =>
throw new SparkFatalException(e)
}
}
}(BroadcastExchangeExec.executionContext)
}
BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
}

override protected def doPrepare(): Unit = {
@@ -143,11 +146,15 @@ case class BroadcastExchangeExec(

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
try {
ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " +
logError(s"Could not execute broadcast in $timeout secs.", ex)
if (!relationFuture.isDone) {
sparkContext.cancelJobGroup(runId.toString)
relationFuture.cancel(true)
}
throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
ex)
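
The review thread above asks why a dedicated job group is needed. To make the mechanism concrete, here is a minimal, self-contained sketch (an illustration under assumed names, not Spark's own code) of the pattern this PR uses: tag all jobs submitted from a thread via SparkContext.setJobGroup, then cancel the whole group by id from another thread, with interruptOnCancel = true so running tasks receive a thread interrupt:

    import java.util.UUID
    import org.apache.spark.sql.SparkSession

    object JobGroupCancellationSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        val sc = spark.sparkContext
        val runId = UUID.randomUUID().toString

        val worker = new Thread {
          override def run(): Unit = {
            // Jobs submitted from this thread now belong to the group `runId`;
            // interruptOnCancel = true interrupts running task threads on cancel.
            sc.setJobGroup(runId, s"slow job (runId $runId)", interruptOnCancel = true)
            try {
              sc.parallelize(1 to 4, 4).foreach(_ => Thread.sleep(60000))
            } catch {
              case e: Exception => println(s"job failed: ${e.getMessage}")
            }
          }
        }
        worker.start()
        Thread.sleep(2000)        // crude wait for the job to get scheduled
        sc.cancelJobGroup(runId)  // cancels every active job tagged with runId
        worker.join()
        spark.stop()
      }
    }

This mirrors what doExecuteBroadcast now does on TimeoutException: cancelJobGroup(runId.toString) followed by relationFuture.cancel(true).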
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import java.util.concurrent.{CountDownLatch, TimeUnit}

import org.apache.spark.SparkException
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class BroadcastExchangeSuite extends SparkPlanTest with SharedSQLContext {

import testImplicits._

test("BroadcastExchange should cancel the job group if timeout") {
val startLatch = new CountDownLatch(1)
val endLatch = new CountDownLatch(1)
var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent]
spark.sparkContext.addSparkListener(new SparkListener {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
jobEvents :+= jobEnd
endLatch.countDown()
}
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
jobEvents :+= jobStart
startLatch.countDown()
}
})

withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "0") {
val df = spark.range(100).join(spark.range(15).as[Long].map { x =>
Thread.sleep(5000)
x
}).where("id = value")

// get the exchange physical plan
val hashExchange = df.queryExecution.executedPlan
.collect { case p: BroadcastExchangeExec => p }.head

// materialize the future and wait for the job being scheduled
hashExchange.prepare()
startLatch.await(5, TimeUnit.SECONDS)

// check timeout exception is captured by just executing the exchange
val hashEx = intercept[SparkException] {
hashExchange.executeBroadcast[HashedRelation]()
}
assert(hashEx.getMessage.contains("Could not execute broadcast"))

// wait for cancel is posted and then check the results.
endLatch.await(5, TimeUnit.SECONDS)
assert(jobCancelled())
}

def jobCancelled(): Boolean = {
val events = jobEvents.toArray
val hasStart = events(0).isInstanceOf[SparkListenerJobStart]
val hasCancelled = events(1).asInstanceOf[SparkListenerJobEnd].jobResult
.asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group")
events.length == 2 && hasStart && hasCancelled
}
}

test("set broadcastTimeout to -1") {
withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "-1") {
val df = spark.range(1).toDF()
val joinDF = df.join(broadcast(df), "id")
val broadcastExchangeExec = joinDF.queryExecution.executedPlan
.collect { case p: BroadcastExchangeExec => p }
assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")
assert(joinDF.collect().length == 1)
}
}
}