diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f577f015973..301d94800fc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2019,7 +2019,10 @@ class SQLConf extends Serializable with Logging {
 
   def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD)
 
-  def broadcastTimeout: Long = getConf(BROADCAST_TIMEOUT)
+  def broadcastTimeout: Long = {
+    val timeoutValue = getConf(BROADCAST_TIMEOUT)
+    if (timeoutValue < 0) Long.MaxValue else timeoutValue
+  }
 
   def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index aa0dd1d62840..8017188eb165 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import java.util.concurrent.TimeoutException
+import java.util.UUID
+import java.util.concurrent._
 
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext
+import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
 import org.apache.spark.{broadcast, SparkException}
@@ -43,6 +44,8 @@ case class BroadcastExchangeExec(
     mode: BroadcastMode,
     child: SparkPlan) extends Exchange {
 
+  private val runId: UUID = UUID.randomUUID
+
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
@@ -56,79 +59,79 @@ case class BroadcastExchangeExec(
   }
 
   @transient
-  private val timeout: Duration = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
+  private val timeout: Long = SQLConf.get.broadcastTimeout
 
   @transient
   private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
-        try {
-          val beforeCollect = System.nanoTime()
-          // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-          val (numRows, input) = child.executeCollectIterator()
-          if (numRows >= 512000000) {
-            throw new SparkException(
-              s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
-          }
-
-          val beforeBuild = System.nanoTime()
-          longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
-
-          // Construct the relation.
-          val relation = mode.transform(input, Some(numRows))
-
-          val dataSize = relation match {
-            case map: HashedRelation =>
-              map.estimatedSize
-            case arr: Array[InternalRow] =>
-              arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-            case _ =>
-              throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " +
-                relation.getClass.getName)
-          }
-
-          longMetric("dataSize") += dataSize
-          if (dataSize >= (8L << 30)) {
-            throw new SparkException(
-              s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+    val task = new Callable[broadcast.Broadcast[Any]]() {
+      override def call(): broadcast.Broadcast[Any] = {
+        // This will run in another thread. Set the execution id so that we can connect these jobs
+        // with the correct execution.
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          try {
+            // Setup a job group here so later it may get cancelled by groupId if necessary.
+            sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+              interruptOnCancel = true)
+            val beforeCollect = System.nanoTime()
+            // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
+            val (numRows, input) = child.executeCollectIterator()
+            if (numRows >= 512000000) {
+              throw new SparkException(
+                s"Cannot broadcast the table with 512 million or more rows: $numRows rows")
+            }
+
+            val beforeBuild = System.nanoTime()
+            longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
+
+            // Construct the relation.
+            val relation = mode.transform(input, Some(numRows))
+
+            val dataSize = relation match {
+              case map: HashedRelation =>
+                map.estimatedSize
+              case arr: Array[InternalRow] =>
+                arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+              case _ =>
+                throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " +
+                  s"type: ${relation.getClass.getName}")
+            }
+
+            longMetric("dataSize") += dataSize
+            if (dataSize >= (8L << 30)) {
+              throw new SparkException(
+                s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB")
+            }
+
+            val beforeBroadcast = System.nanoTime()
+            longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
+
+            // Broadcast the relation
+            val broadcasted = sparkContext.broadcast(relation)
+            longMetric("broadcastTime") += NANOSECONDS.toMillis(
+              System.nanoTime() - beforeBroadcast)
+
+            SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+            broadcasted
+          } catch {
+            // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+            // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+            // will catch this exception and re-throw the wrapped fatal throwable.
+            case oe: OutOfMemoryError =>
+              throw new SparkFatalException(
+                new OutOfMemoryError("Not enough memory to build and broadcast the table to all " +
+                  "worker nodes. As a workaround, you can either disable broadcast by setting " +
+                  s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " +
+                  s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.")
+                  .initCause(oe.getCause))
+            case e if !NonFatal(e) =>
+              throw new SparkFatalException(e)
           }
-
-          val beforeBroadcast = System.nanoTime()
-          longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild)
-
-          // Broadcast the relation
-          val broadcasted = sparkContext.broadcast(relation)
-          longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast)
-
-          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-          broadcasted
-        } catch {
-          // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
-          // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
-          // will catch this exception and re-throw the wrapped fatal throwable.
-          case oe: OutOfMemoryError =>
-            throw new SparkFatalException(
-              new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
-                s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
-                s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
-                s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
-              .initCause(oe.getCause))
-          case e if !NonFatal(e) =>
-            throw new SparkFatalException(e)
         }
       }
-    }(BroadcastExchangeExec.executionContext)
+    }
+    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
@@ -143,11 +146,15 @@ case class BroadcastExchangeExec(
 
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
-      ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
     } catch {
       case ex: TimeoutException =>
-        logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
-        throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " +
+        logError(s"Could not execute broadcast in $timeout secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
+        throw new SparkException(s"Could not execute broadcast in $timeout secs. " +
           s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
           s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
           ex)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
new file mode 100644
index 000000000000..4e39df928603
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.spark.SparkException
+import org.apache.spark.scheduler._
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.joins.HashedRelation
+import org.apache.spark.sql.functions.broadcast
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class BroadcastExchangeSuite extends SparkPlanTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("BroadcastExchange should cancel the job group if timeout") {
+    val startLatch = new CountDownLatch(1)
+    val endLatch = new CountDownLatch(1)
+    var jobEvents: Seq[SparkListenerEvent] = Seq.empty[SparkListenerEvent]
+    spark.sparkContext.addSparkListener(new SparkListener {
+      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+        jobEvents :+= jobEnd
+        endLatch.countDown()
+      }
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        jobEvents :+= jobStart
+        startLatch.countDown()
+      }
+    })
+
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "0") {
+      val df = spark.range(100).join(spark.range(15).as[Long].map { x =>
+        Thread.sleep(5000)
+        x
+      }).where("id = value")
+
+      // get the exchange physical plan
+      val hashExchange = df.queryExecution.executedPlan
+        .collect { case p: BroadcastExchangeExec => p }.head
+
+      // materialize the future and wait for the job being scheduled
+      hashExchange.prepare()
+      startLatch.await(5, TimeUnit.SECONDS)
+
+      // check timeout exception is captured by just executing the exchange
+      val hashEx = intercept[SparkException] {
+        hashExchange.executeBroadcast[HashedRelation]()
+      }
+      assert(hashEx.getMessage.contains("Could not execute broadcast"))
+
+      // wait for cancel is posted and then check the results.
+      endLatch.await(5, TimeUnit.SECONDS)
+      assert(jobCancelled())
+    }
+
+    def jobCancelled(): Boolean = {
+      val events = jobEvents.toArray
+      val hasStart = events(0).isInstanceOf[SparkListenerJobStart]
+      val hasCancelled = events(1).asInstanceOf[SparkListenerJobEnd].jobResult
+        .asInstanceOf[JobFailed].exception.getMessage.contains("cancelled job group")
+      events.length == 2 && hasStart && hasCancelled
+    }
+  }
+
+  test("set broadcastTimeout to -1") {
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> "-1") {
+      val df = spark.range(1).toDF()
+      val joinDF = df.join(broadcast(df), "id")
+      val broadcastExchangeExec = joinDF.queryExecution.executedPlan
+        .collect { case p: BroadcastExchangeExec => p }
+      assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec")
+      assert(joinDF.collect().length == 1)
+    }
+  }
+}