@@ -41,13 +41,6 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
if (retval != null) Some(retval) else None
}

// By default, if rack is unknown, return nothing
override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
if (rack == None || rack == null) return None

YarnAllocationHandler.fetchCachedHostsForRack(rack)
}

override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
62 changes: 31 additions & 31 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -64,11 +64,11 @@ private[spark] class MapOutputTracker extends Logging {

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private val generationLock = new java.lang.Object
private var epoch: Long = 0
private val epochLock = new java.lang.Object

// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@@ -108,10 +108,10 @@ private[spark] class MapOutputTracker extends Logging {
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeGeneration: Boolean = false) {
changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeGeneration) {
incrementGeneration()
if (changeEpoch) {
incrementEpoch()
}
}

@@ -124,7 +124,7 @@ private[spark] class MapOutputTracker extends Logging {
array(mapId) = null
}
}
incrementGeneration()
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
@@ -206,58 +206,58 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}

// Called on master to increment the generation number
def incrementGeneration() {
generationLock.synchronized {
generation += 1
logDebug("Increasing generation to " + generation)
// Called on master to increment the epoch number
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}

// Called on master or workers to get current generation number
def getGeneration: Long = {
generationLock.synchronized {
return generation
// Called on master or workers to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
}
}

// Called on workers to update the generation number, potentially clearing old outputs
// because of a fetch failure. (Each Mesos task calls this with the latest generation
// Called on workers to update the epoch number, potentially clearing old outputs
// because of a fetch failure. (Each worker task calls this with the latest epoch
// number on the master at the time it was created.)
def updateGeneration(newGen: Long) {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
generation = newGen
epoch = newEpoch
}
}
}

def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
var epochGotten: Long = -1
epochLock.synchronized {
if (epoch > cacheEpoch) {
cachedSerializedStatuses.clear()
cacheGeneration = generation
cacheEpoch = epoch
}
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
statuses = mapStatuses(shuffleId)
generationGotten = generation
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
if (epoch == epochGotten) {
cachedSerializedStatuses(shuffleId) = bytes
}
}
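The generation-to-epoch rename above is mechanical, but the underlying pattern is worth restating: the master bumps an epoch counter whenever map outputs change, and a worker drops its cached output locations as soon as it runs a task created under a newer epoch. A self-contained sketch of that pattern (illustrative names only, not Spark's MapOutputTracker):

import scala.collection.mutable

// Standalone sketch of the epoch-guarded cache pattern; not Spark's implementation.
class EpochCacheSketch {
  private var epoch: Long = 0
  private val epochLock = new Object
  // Pretend cache of shuffleId -> map output locations.
  val cache = mutable.Map[Int, String]()

  // Master side: bump the epoch whenever registered outputs change (e.g. a fetch failure).
  def incrementEpoch(): Unit = epochLock.synchronized { epoch += 1 }

  def getEpoch: Long = epochLock.synchronized { epoch }

  // Worker side: called with the epoch the task was created under; a newer epoch
  // means cached locations may be stale, so drop them.
  def updateEpoch(newEpoch: Long): Unit = epochLock.synchronized {
    if (newEpoch > epoch) {
      cache.clear()
      epoch = newEpoch
    }
  }
}

object EpochCacheSketch {
  def main(args: Array[String]): Unit = {
    val master = new EpochCacheSketch
    val worker = new EpochCacheSketch
    worker.cache(0) = "hostA"

    worker.updateEpoch(master.getEpoch) // epochs equal: cache kept
    println(worker.cache.nonEmpty)      // true

    master.incrementEpoch()             // outputs changed on the master
    worker.updateEpoch(master.getEpoch) // worker sees newer epoch: cache cleared
    println(worker.cache.isEmpty)       // true
  }
}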
4 changes: 2 additions & 2 deletions core/src/main/scala/spark/RDD.scala
@@ -221,8 +221,8 @@ abstract class RDD[T: ClassManifest](
}

/**
* Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
* Get the preferred locations of a partition (as hostnames), taking into account whether the
* RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
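As context for the comment fix, here is a minimal sketch (illustrative names, not Spark's RDD hierarchy) of the dispatch it documents: once an RDD is checkpointed, preferred locations come from the checkpoint data instead of the RDD's own getPreferredLocations.

// Illustrative sketch of checkpoint-aware preferred locations; not Spark's RDD class.
abstract class MiniRdd {
  // Subclasses report locality for a partition index; hostnames only, no ports.
  protected def getPreferredLocations(partition: Int): Seq[String] = Nil
  // Set once the RDD's data has been checkpointed somewhere else.
  var checkpointLocations: Option[Int => Seq[String]] = None

  final def preferredLocations(partition: Int): Seq[String] =
    checkpointLocations.map(_(partition)).getOrElse(getPreferredLocations(partition))
}

object MiniRddDemo {
  def main(args: Array[String]): Unit = {
    val rdd = new MiniRdd {
      override protected def getPreferredLocations(p: Int) = Seq("hostA")
    }
    println(rdd.preferredLocations(0)) // List(hostA)
    rdd.checkpointLocations = Some(_ => Seq("hostB"))
    println(rdd.preferredLocations(0)) // List(hostB)
  }
}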
19 changes: 2 additions & 17 deletions core/src/main/scala/spark/SparkEnv.scala
@@ -54,11 +54,7 @@ class SparkEnv (
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
// To be set only as part of initialization of SparkContext.
// (executorId, defaultHostPort) => executorHostPort
// If executorId is NOT found, return defaultHostPort
var executorIdToHostPort: Option[(String, String) => String]) {
val metricsSystem: MetricsSystem) {

private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()

@@ -83,16 +79,6 @@ class SparkEnv (
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
}
}

def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
val env = SparkEnv.get
if (env.executorIdToHostPort.isEmpty) {
// default to using host, not host port. Relevant to non cluster modes.
return defaultHostPort
}

env.executorIdToHostPort.get(executorId, defaultHostPort)
}
}

object SparkEnv extends Logging {
@@ -236,7 +222,6 @@ object SparkEnv extends Logging {
connectionManager,
httpFileServer,
sparkFilesDir,
metricsSystem,
None)
metricsSystem)
}
}
31 changes: 2 additions & 29 deletions core/src/main/scala/spark/Utils.scala
@@ -393,41 +393,14 @@ private object Utils extends Logging {
retval
}

/*
// Used by DEBUG code : remove when all testing done
private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
def checkHost(host: String, message: String = "") {
// Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous !
// if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
if (ipPattern.matcher(host).matches()) {
Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
}
if (Utils.parseHostPort(host)._2 != 0){
Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
}
assert(host.indexOf(':') == -1, message)
}

// Used by DEBUG code : remove when all testing done
def checkHostPort(hostPort: String, message: String = "") {
val (host, port) = Utils.parseHostPort(hostPort)
checkHost(host)
if (port <= 0){
Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
}
assert(hostPort.indexOf(':') != -1, message)
}

// Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
// temp code for debug
System.exit(-1)
}
*/

// Once testing is complete in various modes, replace with this ?
def checkHost(host: String, message: String = "") {}
def checkHostPort(hostPort: String, message: String = "") {}

// Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
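The checkHost/checkHostPort stubs kept above guard the convention the rest of this PR moves toward: plain hostnames almost everywhere, host:port strings only where a port is genuinely needed. A standalone sketch of that kind of validation (assumed helper names, not the real Utils implementation):

// Standalone sketch of host vs host:port validation; not Spark's Utils code.
object HostCheckSketch {
  // Split "host:port" into its parts; a bare hostname gets port 0.
  def parseHostPort(hostPort: String): (String, Int) = {
    val idx = hostPort.lastIndexOf(':')
    if (idx < 0) (hostPort, 0) else (hostPort.take(idx), hostPort.drop(idx + 1).toInt)
  }

  // A "host" must not carry a port.
  def checkHost(host: String, message: String = ""): Unit =
    assert(host.indexOf(':') == -1, message)

  // A "hostPort" must carry one.
  def checkHostPort(hostPort: String, message: String = ""): Unit =
    assert(hostPort.indexOf(':') != -1, message)

  def main(args: Array[String]): Unit = {
    checkHost("worker1", "plain hostname expected")
    checkHostPort("worker1:7078", "host:port expected")
    println(parseHostPort("worker1:7078")) // (worker1,7078)
  }
}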
1 change: 1 addition & 0 deletions core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -80,6 +80,7 @@ private[deploy] object DeployMessages {

case class RegisteredApplication(appId: String) extends DeployMessage

// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
Utils.checkHostPort(hostPort, "Required hostport")
}
6 changes: 2 additions & 4 deletions core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -40,13 +40,11 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
val hostPort: String,
val host: String,
val sparkHome: File,
[Inline review thread on this line]
Contributor: I am not very sure about this - iirc this needs to be hostport - since we can have multiple executors in the same node.
Author (Member): It's only used to set the host below in this file (no other references), so it should be fine.
Contributor: can we remove references to host also? iirc we can infer the host from the executorId right? This will remove all references to host when we can avoid it ...
val workDir: File)
extends Logging {

Utils.checkHostPort(hostPort, "Expected hostport")

val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@@ -92,7 +90,7 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
case "{{HOSTNAME}}" => host
case "{{CORES}}" => cores.toString
case other => other
}
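The {{HOSTNAME}} case above was the only remaining use of hostPort in this file, which is why passing just host is enough. A standalone sketch of the substitution pattern (placeholder argument names and values, not the real launch command):

// Standalone sketch of ExecutorRunner-style template substitution; values are made up.
object SubstituteSketch {
  def substituteVariables(execId: Int, host: String, cores: Int)(argument: String): String =
    argument match {
      case "{{EXECUTOR_ID}}" => execId.toString
      case "{{HOSTNAME}}"    => host
      case "{{CORES}}"       => cores.toString
      case other             => other
    }

  def main(args: Array[String]): Unit = {
    val template = Seq("--executor-id", "{{EXECUTOR_ID}}", "--hostname", "{{HOSTNAME}}", "--cores", "{{CORES}}")
    println(template.map(substituteVariables(7, "worker1", 4)))
    // List(--executor-id, 7, --hostname, worker1, --cores, 4)
  }
}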
2 changes: 1 addition & 1 deletion core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -132,7 +132,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
12 changes: 8 additions & 4 deletions core/src/main/scala/spark/executor/Executor.scala
@@ -32,8 +32,12 @@ import spark._
/**
* The Mesos executor for Spark.
*/
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {

private[spark] class Executor(
executorId: String,
slaveHostname: String,
properties: Seq[(String, String)])
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -125,8 +129,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
attemptedTask = Some(task)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
logInfo("Its epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
7 changes: 3 additions & 4 deletions core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -28,13 +28,12 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {

@transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)

override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
}).toArray


override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
@@ -45,8 +44,8 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}

override def getPreferredLocations(split: Partition): Seq[String] =
override def getPreferredLocations(split: Partition): Seq[String] = {
locations_(split.asInstanceOf[BlockRDDPartition].blockId)

}
}

2 changes: 1 addition & 1 deletion core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -64,7 +64,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](

override def getPreferredLocations(split: Partition): Seq[String] = {
val currSplit = split.asInstanceOf[CartesianPartition]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
(rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
}

override def compute(split: Partition, context: TaskContext) = {
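Both parents of a cartesian partition often prefer the same hosts, so concatenating their location lists can hand the scheduler duplicates; the added .distinct collapses them. A trivial illustration with made-up hostnames:

// Tiny illustration of why .distinct matters when merging parent locality; sample hosts only.
object CartesianLocalitySketch {
  def main(args: Array[String]): Unit = {
    val locs1 = Seq("hostA", "hostB")
    val locs2 = Seq("hostB", "hostC")
    println(locs1 ++ locs2)             // List(hostA, hostB, hostB, hostC)
    println((locs1 ++ locs2).distinct)  // List(hostA, hostB, hostC)
  }
}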
30 changes: 9 additions & 21 deletions core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -55,27 +55,15 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
}

override def getPreferredLocations(s: Partition): Seq[String] = {
// Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below
// become diminishingly small : so we might need to look at alternate strategies to alleviate this.
// If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the
// cluster - paying with n/w and cache cost.
// Maybe pick a node which figures max amount of time ?
// Choose node which is hosting 'larger' of some subset of blocks ?
// Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible)
val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
val rddSplitZip = rdds.zip(splits)

// exact match.
val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2))
val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))

// Remove exact match and then do host local match.
val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
.reduce((x, y) => x.intersect(y))
val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }

otherNodeLocalLocations ++ exactMatchLocations
val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
// Check whether there are any hosts that match all RDDs; otherwise return the union
val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
if (!exactMatchLocations.isEmpty) {
exactMatchLocations
} else {
prefs.flatten.distinct
}
}
[Inline review thread on this change]
Contributor: If there are both "host:port" and "host" preferred locations, this will result in losing the "host" preferred locations - right? Will that happen - I am not sure, but we should be defending against future changes imo. Same applies to other similar changes ...
Author (Member): No, I've decided not to allow host:port in preferred locations. Instead, tasks that have a preferred executor can pass an executorID as part of their TaskLocation object.
Contributor: ah, that is an excellent change! had not noticed that
override def clearDependencies() {
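Pulled out as a standalone function (illustrative, not the RDD code itself), the new preference-merging rule reads: keep only the hosts preferred by every parent RDD; if that intersection is empty, fall back to the distinct union of all parents' preferences.

// Standalone restatement of the merging rule above; not the ZippedPartitionsBaseRDD code.
object ZipLocalitySketch {
  def mergedPreferredLocations(prefs: Seq[Seq[String]]): Seq[String] = {
    val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
    if (exactMatchLocations.nonEmpty) exactMatchLocations else prefs.flatten.distinct
  }

  def main(args: Array[String]): Unit = {
    println(mergedPreferredLocations(Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"))))
    // List(hostB): hostB is preferred by both parents
    println(mergedPreferredLocations(Seq(Seq("hostA"), Seq("hostC"))))
    // List(hostA, hostC): no common host, fall back to the union
  }
}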