
Commit 84f6e57

Don't throw an exception when spark.executor.cores is not set; fix other UTs
1 parent cc2befb commit 84f6e57

2 files changed: +34 −20 lines

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 17 additions & 15 deletions
@@ -2667,23 +2667,25 @@ object SparkContext extends Logging {
 
     // SPARK-26340: Ensure that executor's core num meets at least one task requirement.
     def checkCpusPerTask(
-        executorCoreNum: Int = sc.conf.get(EXECUTOR_CORES),
-        clusterMode: Boolean = true): Unit = {
+        clusterMode: Boolean,
+        maxCoresPerExecutor: Option[Int]): Unit = {
       val cpusPerTask = sc.conf.get(CPUS_PER_TASK)
-      if (executorCoreNum < cpusPerTask) {
-        val message = if (clusterMode) {
-          s"${CPUS_PER_TASK.key} must be <= ${EXECUTOR_CORES.key} when run on $master."
-        } else {
-          s"Only $executorCoreNum cores available per executor when run on $master," +
-            s" and ${CPUS_PER_TASK.key} must be <= it."
+      if (clusterMode && sc.conf.contains(EXECUTOR_CORES)) {
+        if (sc.conf.get(EXECUTOR_CORES) < cpusPerTask) {
+          throw new SparkException(s"${CPUS_PER_TASK.key}" +
+            s" must be <= ${EXECUTOR_CORES.key} when run on $master.")
+        }
+      } else if (maxCoresPerExecutor.isDefined) {
+        if (maxCoresPerExecutor.get < cpusPerTask) {
+          throw new SparkException(s"Only ${maxCoresPerExecutor.get} cores available per executor" +
+            s" when run on $master, and ${CPUS_PER_TASK.key} must be <= it.")
         }
-        throw new SparkException(message)
       }
     }
 
     master match {
       case "local" =>
-        checkCpusPerTask(1, clusterMode = false)
+        checkCpusPerTask(clusterMode = false, Some(1))
         val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
         scheduler.initialize(backend)
@@ -2696,7 +2698,7 @@ object SparkContext extends Logging {
         if (threadCount <= 0) {
           throw new SparkException(s"Asked to run locally with $threadCount threads")
         }
-        checkCpusPerTask(threadCount, clusterMode = false)
+        checkCpusPerTask(clusterMode = false, Some(threadCount))
         val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
         scheduler.initialize(backend)
@@ -2707,22 +2709,22 @@ object SparkContext extends Logging {
         // local[*, M] means the number of cores on the computer with M failures
         // local[N, M] means exactly N threads with M failures
         val threadCount = if (threads == "*") localCpuCount else threads.toInt
-        checkCpusPerTask(threadCount, clusterMode = false)
+        checkCpusPerTask(clusterMode = false, Some(threadCount))
         val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
         val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
         scheduler.initialize(backend)
         (backend, scheduler)
 
       case SPARK_REGEX(sparkUrl) =>
-        checkCpusPerTask()
+        checkCpusPerTask(clusterMode = true, None)
         val scheduler = new TaskSchedulerImpl(sc)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
         val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
         scheduler.initialize(backend)
         (backend, scheduler)
 
       case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
-        checkCpusPerTask()
+        checkCpusPerTask(clusterMode = true, Some(coresPerSlave.toInt))
         // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
         val memoryPerSlaveInt = memoryPerSlave.toInt
         if (sc.executorMemory > memoryPerSlaveInt) {
@@ -2743,7 +2745,7 @@ object SparkContext extends Logging {
         (backend, scheduler)
 
       case masterUrl =>
-        checkCpusPerTask()
+        checkCpusPerTask(clusterMode = true, None)
         val cm = getClusterManager(masterUrl) match {
          case Some(clusterMgr) => clusterMgr
          case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
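
The net effect: the CPUS_PER_TASK check now only fires when the executor core count is actually known, either because spark.executor.cores was explicitly set (cluster mode) or because the local/local-cluster master string fixes the thread count. Below is a minimal, self-contained Scala sketch of that decision logic, using a plain Map in place of SparkConf and IllegalArgumentException in place of SparkException (both stand-ins; the object name CheckCpusPerTaskSketch and the main method are hypothetical, for illustration only):

object CheckCpusPerTaskSketch {
  val CpusPerTaskKey = "spark.task.cpus"
  val ExecutorCoresKey = "spark.executor.cores"

  // Mirrors the reworked checkCpusPerTask: in cluster mode, validate only
  // when spark.executor.cores was explicitly set; in local modes, validate
  // against the known per-executor core count passed in by the caller.
  def checkCpusPerTask(
      conf: Map[String, Int],
      master: String,
      clusterMode: Boolean,
      maxCoresPerExecutor: Option[Int]): Unit = {
    val cpusPerTask = conf.getOrElse(CpusPerTaskKey, 1)
    if (clusterMode && conf.contains(ExecutorCoresKey)) {
      if (conf(ExecutorCoresKey) < cpusPerTask) {
        throw new IllegalArgumentException(
          s"$CpusPerTaskKey must be <= $ExecutorCoresKey when run on $master.")
      }
    } else if (maxCoresPerExecutor.exists(_ < cpusPerTask)) {
      throw new IllegalArgumentException(
        s"Only ${maxCoresPerExecutor.get} cores available per executor when run" +
          s" on $master, and $CpusPerTaskKey must be <= it.")
    }
    // If spark.executor.cores is unset and no per-executor core count is
    // known, nothing is validated -- the behavioral change in this commit.
  }

  def main(args: Array[String]): Unit = {
    // Cluster mode without spark.executor.cores: no longer throws.
    checkCpusPerTask(Map(CpusPerTaskKey -> 2), "spark://host:7077",
      clusterMode = true, maxCoresPerExecutor = None)
    println("cluster mode without executor cores: ok")

    // Plain "local" (one thread) with spark.task.cpus = 2: still rejected.
    try {
      checkCpusPerTask(Map(CpusPerTaskKey -> 2), "local",
        clusterMode = false, maxCoresPerExecutor = Some(1))
    } catch {
      case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
    }
  }
}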

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 17 additions & 5 deletions
@@ -77,7 +77,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   }
 
   def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = {
-    val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite")
+    setupSchedulerWithMaster("local", confs: _*)
+  }
+
+  def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = {
+    val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite")
     confs.foreach { case (k, v) => conf.set(k, v) }
     sc = new SparkContext(conf)
     taskScheduler = new TaskSchedulerImpl(sc)
@@ -155,7 +159,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
   test("Scheduler correctly accounts for multiple CPUs per task") {
     val taskCpus = 2
-    val taskScheduler = setupScheduler(config.CPUS_PER_TASK.key -> taskCpus.toString)
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[$taskCpus]",
+      config.CPUS_PER_TASK.key -> taskCpus.toString)
     // Give zero core offers. Should not generate any tasks
     val zeroCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 0),
       new WorkerOffer("executor1", "host1", 0))
@@ -185,7 +191,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
   test("Scheduler does not crash when tasks are not serializable") {
     val taskCpus = 2
-    val taskScheduler = setupScheduler(config.CPUS_PER_TASK.key -> taskCpus.toString)
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[$taskCpus]",
+      config.CPUS_PER_TASK.key -> taskCpus.toString)
     val numFreeCores = 1
     val taskSet = new TaskSet(
       Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
@@ -1241,7 +1249,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
   test("don't schedule for a barrier taskSet if available slots are less than pending tasks") {
     val taskCpus = 2
-    val taskScheduler = setupScheduler(config.CPUS_PER_TASK.key -> taskCpus.toString)
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[$taskCpus]",
+      config.CPUS_PER_TASK.key -> taskCpus.toString)
 
     val numFreeCores = 3
     val workerOffers = IndexedSeq(
@@ -1258,7 +1268,9 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
   test("schedule tasks for a barrier taskSet if all tasks can be launched together") {
     val taskCpus = 2
-    val taskScheduler = setupScheduler(config.CPUS_PER_TASK.key -> taskCpus.toString)
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[$taskCpus]",
+      config.CPUS_PER_TASK.key -> taskCpus.toString)
 
     val numFreeCores = 3
     val workerOffers = IndexedSeq(
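
The test changes follow from the stricter local-mode check: with spark.task.cpus = 2, creating a SparkContext on plain "local" (one core) now fails at startup, so these suites request local[$taskCpus] via the new setupSchedulerWithMaster helper. A minimal standalone sketch of the configuration that helper builds, assuming a Spark build on the classpath (the object name SchedulerSetupSketch is hypothetical):

import org.apache.spark.{SparkConf, SparkContext}

object SchedulerSetupSketch {
  def main(args: Array[String]): Unit = {
    val taskCpus = 2
    // The master is parameterized so a suite can guarantee at least
    // spark.task.cpus cores per executor and pass checkCpusPerTask.
    val conf = new SparkConf()
      .setMaster(s"local[$taskCpus]")
      .setAppName("TaskSchedulerImplSuite")
      .set("spark.task.cpus", taskCpus.toString)
    val sc = new SparkContext(conf)
    try {
      println(s"defaultParallelism = ${sc.defaultParallelism}")
    } finally {
      sc.stop()
    }
  }
}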
