Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

package org.apache.spark

import scala.language.implicitConversions

import java.io._
import java.net.URI
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.language.implicitConversions
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
Expand Down Expand Up @@ -836,18 +838,22 @@ class SparkContext(config: SparkConf) extends Logging {
}

/**
* Return pools for fair scheduler
* TODO(xiajunluan): We should take nested pools into account
* :: DeveloperApi ::
* Return pools for fair scheduler
*/
def getAllPools: ArrayBuffer[Schedulable] = {
taskScheduler.rootPool.schedulableQueue
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}

/**
* :: DeveloperApi ::
* Return the pool associated with the given name, if one exists
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}

/**
Expand Down
31 changes: 16 additions & 15 deletions core/src/main/scala/org/apache/spark/scheduler/Pool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.scheduler

import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap

import org.apache.spark.Logging
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
Expand All @@ -35,18 +37,15 @@ private[spark] class Pool(
extends Schedulable
with Logging {

var schedulableQueue = new ArrayBuffer[Schedulable]
var schedulableNameToSchedulable = new HashMap[String, Schedulable]

val schedulableQueue = new ConcurrentLinkedQueue[Schedulable]
val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable]
var weight = initWeight
var minShare = initMinShare
var runningTasks = 0

var priority = 0

// A pool's stage id is used to break the tie in scheduling.
var stageId = -1

var name = poolName
var parent: Pool = null

Expand All @@ -60,19 +59,20 @@ private[spark] class Pool(
}

override def addSchedulable(schedulable: Schedulable) {
schedulableQueue += schedulable
schedulableNameToSchedulable(schedulable.name) = schedulable
require(schedulable != null)
schedulableQueue.add(schedulable)
schedulableNameToSchedulable.put(schedulable.name, schedulable)
schedulable.parent = this
}

override def removeSchedulable(schedulable: Schedulable) {
schedulableQueue -= schedulable
schedulableNameToSchedulable -= schedulable.name
schedulableQueue.remove(schedulable)
schedulableNameToSchedulable.remove(schedulable.name)
}

override def getSchedulableByName(schedulableName: String): Schedulable = {
if (schedulableNameToSchedulable.contains(schedulableName)) {
return schedulableNameToSchedulable(schedulableName)
if (schedulableNameToSchedulable.containsKey(schedulableName)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, man, good catch

return schedulableNameToSchedulable.get(schedulableName)
}
for (schedulable <- schedulableQueue) {
val sched = schedulable.getSchedulableByName(schedulableName)
Expand All @@ -95,11 +95,12 @@ private[spark] class Pool(
shouldRevive
}

override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator)
val sortedSchedulableQueue =
schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator)
for (schedulable <- sortedSchedulableQueue) {
sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue
}
sortedTaskSetQueue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.scheduler

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
Expand All @@ -28,7 +30,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
private[spark] trait Schedulable {
var parent: Pool
// child queues
def schedulableQueue: ArrayBuffer[Schedulable]
def schedulableQueue: ConcurrentLinkedQueue[Schedulable]
def schedulingMode: SchedulingMode
def weight: Int
def minShare: Int
Expand All @@ -42,5 +44,5 @@ private[spark] trait Schedulable {
def getSchedulableByName(name: String): Schedulable
def executorLost(executorId: String, host: String): Unit
def checkSpeculatableTasks(): Boolean
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ private[spark] class TaskSchedulerImpl(
// Build a list of tasks to assign to each worker.
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
val sortedTaskSets = rootPool.getSortedTaskSetQueue
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
taskSet.parent.name, taskSet.name, taskSet.runningTasks))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
}

def resourceOffer(rootPool: Pool): Int = {
val taskSetQueue = rootPool.getSortedTaskSetQueue()
val taskSetQueue = rootPool.getSortedTaskSetQueue
/* Just for Test*/
for (manager <- taskSetQueue) {
logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
Expand Down