diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala
index 7b29b40668def..9cba81943dac8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala
@@ -22,6 +22,7 @@ import scala.util.Random
 
 import org.apache.spark.{ExecutorAllocationClient, SparkConf}
 import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}
 import org.apache.spark.streaming.util.RecurringTimer
 import org.apache.spark.util.{Clock, Utils}
 
@@ -49,7 +50,7 @@ private[streaming] class ExecutorAllocationManager(
     receiverTracker: ReceiverTracker,
     conf: SparkConf,
     batchDurationMs: Long,
-    clock: Clock) extends StreamingListener with Logging {
+    clock: Clock) extends SparkListener with StreamingListener with Logging {
 
   import ExecutorAllocationManager._
 
@@ -61,13 +62,17 @@ private[streaming] class ExecutorAllocationManager(
   private val minNumExecutors = conf.getInt(
     MIN_EXECUTORS_KEY,
     math.max(1, receiverTracker.numReceivers))
-  private val maxNumExecutors = conf.getInt(MAX_EXECUTORS_KEY, Integer.MAX_VALUE)
+  @volatile private var maxNumExecutors = conf.getInt(MAX_EXECUTORS_KEY, Integer.MAX_VALUE)
   private val timer = new RecurringTimer(clock, scalingIntervalSecs * 1000,
     _ => manageAllocation(), "streaming-executor-allocation-manager")
 
   @volatile private var batchProcTimeSum = 0L
   @volatile private var batchProcTimeCount = 0
 
+  private val concurrentJobs = conf.getInt("spark.streaming.concurrentJobs", 1)
+  private val cpuPerTask = conf.getInt("spark.task.cpus", 1)
+  private val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt).getOrElse(1)
+
   validateSettings()
 
   def start(): Unit = {
@@ -183,6 +188,18 @@ private[streaming] class ExecutorAllocationManager(
       batchCompleted.batchInfo.processingDelay.foreach(addBatchProcTime)
     }
   }
+
+  // Raise the executor cap when a submitted stage needs more executors than currently allowed.
+  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+    val numTasks = stageSubmitted.stageInfo.numTasks
+    val coresNeeded = numTasks * concurrentJobs * cpuPerTask
+    // Round up so a stage whose cores don't divide evenly still gets enough executors.
+    val curNumExecutors = (coresNeeded + coresPerExecutor - 1) / coresPerExecutor
+    if (curNumExecutors > maxNumExecutors) {
+      maxNumExecutors = curNumExecutors
+      logInfo(s"Changed maxNumExecutors to [$maxNumExecutors]")
+    }
+  }
 }
 
 private[streaming] object ExecutorAllocationManager extends Logging {
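
For reference, the new cap logic boils down to a ceiling division of the cores a stage needs by the cores each executor provides. A minimal standalone sketch of that arithmetic (the `ExecutorEstimate` object, its `requiredExecutors` helper, and the sample numbers are illustrative only, not part of the patch):

```scala
object ExecutorEstimate {
  /** Executors needed to run all of a stage's tasks at once, rounding up. */
  def requiredExecutors(
      numTasks: Int,
      concurrentJobs: Int,
      cpusPerTask: Int,
      coresPerExecutor: Int): Int = {
    val coresNeeded = numTasks * concurrentJobs * cpusPerTask
    // Ceiling division: (a + b - 1) / b
    (coresNeeded + coresPerExecutor - 1) / coresPerExecutor
  }

  def main(args: Array[String]): Unit = {
    // 100 tasks, 1 concurrent job, 1 CPU per task, 4-core executors -> 25
    println(requiredExecutors(100, 1, 1, 4))
    // 5 tasks on 4-core executors -> 2; plain integer division would give 1
    println(requiredExecutors(5, 1, 1, 4))
  }
}
```

The rounding up matters for the second case: with floor division, 5 tasks on 4-core executors would be estimated at a single executor, leaving the fifth task queued behind the first four.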