diff --git a/core/src/main/scala/org/apache/spark/ExtResource.scala b/core/src/main/scala/org/apache/spark/ExtResource.scala new file mode 100644 index 000000000000..fe6a25e46a49 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ExtResource.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +// import java.io.Serializable + + +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} + +case class ExtResourceInfo(slaveHostname: String, executorId: String, + name: String, timestamp: Long, sharable: Boolean, + partitionAffined: Boolean, instanceCount: Int, + instanceUseCount: Int) { + override def toString = { + ("host: %s\texecutor: %s\tname: %s\ttimestamp: %d\tsharable: %s\tpartitionAffined: " + + "%s\tinstanceCount %d\tinstances in use%d").format(slaveHostname, executorId, name, timestamp, + sharable.toString, partitionAffined.toString, instanceCount, instanceUseCount) + } +} + + +object ExternalResourceManager { + private lazy val taskToUsedResources = new HashMap[Long, HashSet[ExtResource[_]]] + + def cleanupResourcesPerTask(taskId: Long): Unit = { + // mark the resources by this task as unused + synchronized { +// taskToUsedResources.get(taskId).get.foreach(_.putInstances(taskId)) + val res = taskToUsedResources.get(taskId) + res.isDefined match { + case true => { + res.get.foreach(_.putInstances(taskId)) + taskToUsedResources -= taskId + } + //sma: debug +// case _ => print(s"\n +++++ cleanupResourcesPerTask : taskId ($taskId) not exist!") + case _ => + } + } + } + + def addResource(taskId: Long, res: ExtResource[_]) = { + synchronized { + taskToUsedResources.getOrElseUpdate(taskId, new HashSet[ExtResource[_]]()) += res + } + } +} + +/** record of number of uses of a shared resource instance per partition + */ +class ResourceRefCountPerPartition[T] (var refCnt: Int = 0, val instance: T) + +/** + * An external resource + */ +case class ExtResource[T]( + name: String, + shared: Boolean = false, + params: Seq[_], + init: (Int, Seq[_]) => T = null, // Initialization function + term: (Int, T, Seq[_]) => Unit = null, // Termination function + partitionAffined: Boolean = false, // partition speficication preferred + expiration: Int = -1 // optional expiration time, default to none; + // 0 for one-time use + ) extends Serializable { + + + private var instances: Any = null + + def getInstancesStat(shared: Boolean, + partitionAffined: Boolean): Any ={ + + def instInit(): Any ={ + println("init extResources instances") + (shared, partitionAffined) match{ + case (true, true) =>{ + instances = new HashMap[Int, ResourceRefCountPerPartition[T]] // map of partition to (use count, instance) + println("++++ TT instance type: "+ instances.getClass.getName) + instances + } + case (true, 
false) => + instances = init(-1, params) + case (false, true) => + instances = new HashMap[Int, ArrayBuffer[T]] + case (false, false) => + // large number of tasks per executor may deterioate modification performance + instances = ArrayBuffer[T]() + } + } + + Option(instances) match { + case None => instInit() + case _ => instances + } + instances + } + + + private var instancesInUse : Any = null + def getInstancesInUseStat (shared: Boolean, + partitionAffined: Boolean) : Any ={ + + def instInUseInit(): Unit ={ + (shared, partitionAffined) match { + case (true, true) => + instancesInUse = new HashMap[Long, Int]() // map from task id to partition + case (true, false) => instancesInUse = 0 // use count + case (false, true) => + instancesInUse = new HashMap[Long, Pair[Int, ArrayBuffer[T]]]() // map of task id to (partition, instances in use) + case (false, false) => + instancesInUse = new HashMap[Long, ArrayBuffer[T]]() // map of task id to instances in use + } + } + + Option(instancesInUse) match{ + case None => instInUseInit() + case _ => instancesInUse + } + instancesInUse + } + + + override def hashCode: Int = name.hashCode + + override def equals(other: Any): Boolean = other match { + case o: ExtResource[T] => + name.equals(o.name) + case _ => + false + } + + def getResourceInfo(host: String, executorId: String, timestamp: Long) + : ExtResourceInfo = { + synchronized { + instances = getInstancesStat(shared, partitionAffined) + instancesInUse = getInstancesInUseStat(shared, partitionAffined) + + println("++++ instance type: "+ instances.getClass.getName) + + (shared, partitionAffined) match { + case (true, true) => { + val instanceCnt = instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]].size + val instanceUseCnt = instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]].values.map(_.refCnt).foldLeft(0)(_ + _) +// val instanceUseCnt = instanceCnt match { +// case 0 => 0 +// case _ => instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]].values.map(_.refCnt).reduce(_ + _) +// } + ExtResourceInfo(host, executorId, name, timestamp, true, true, instanceCnt, instanceUseCnt) + } + case (true, false) => { + ExtResourceInfo(host, executorId, name, timestamp, true, false, 1, instancesInUse.asInstanceOf[Int]) + } + case (false, true) => + val usedCount = instancesInUse.asInstanceOf[HashMap[Long, Pair[Int, ArrayBuffer[T]]]].values.map(_._2.size).foldLeft(0)(_ + _) + ExtResourceInfo(host, executorId, name, timestamp, false, true + , instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]].values.map(_.size).foldLeft(0)(_ + _) + usedCount, usedCount) + case (false, false) => + val usedCount = instancesInUse.asInstanceOf[HashMap[Long, ArrayBuffer[T]]].values.map(_.size).foldLeft(0)(_ + _) + ExtResourceInfo(host, executorId, name, timestamp, false, false + , instances.asInstanceOf[ArrayBuffer[T]].size + usedCount, usedCount) + } + } + } + + // Grab a newly established instance or from pool + def getInstance(split: Int, taskId: Long): T = { + synchronized { + // TODO: too conservative a locking: finer granular ones hoped + instances = getInstancesStat(shared, partitionAffined) + instancesInUse = getInstancesInUseStat(shared, partitionAffined) + + var result : T = { + (shared, partitionAffined) match { + case (false, false) => + val l = instances.asInstanceOf[ArrayBuffer[T]] + if (l.isEmpty) + init(split, params) + else + l.remove(0) + case (false, true) => + val hml = instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]] + var resList = hml.getOrElseUpdate(split, 
ArrayBuffer(init(split, params))) + if (resList.isEmpty) + init(split, params) + else + resList.remove(0) + case (true, true) => + val res = instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]] + .getOrElseUpdate(split, new ResourceRefCountPerPartition[T](instance=init(split, params))) + res.refCnt += 1 + res.instance + case (true, false) => + if(instances != null) + instances + else + instances = init(-1, params) + instances.asInstanceOf[T] + } + } + + (shared, partitionAffined) match { + case (true, true) => + instancesInUse.asInstanceOf[HashMap[Long, Int]].put(taskId, split) + case (true, false) => + instancesInUse = instancesInUse.asInstanceOf[Int] + 1 + case (false, true) => + // add to the in-use instance list for non-sharable resources + val hml=instancesInUse.asInstanceOf[HashMap[Long, Pair[Int, ArrayBuffer[T]]]] + hml.getOrElseUpdate(taskId, (split, ArrayBuffer[T]()))._2 += result + case (false, false) => + val hm = instancesInUse.asInstanceOf[HashMap[Long, ArrayBuffer[T]]] + hm.getOrElseUpdate(taskId, ArrayBuffer[T]()) += result + } + ExternalResourceManager.addResource(taskId, this) + result + } + } + + // return instance to the pool; called by executor at task's termination + def putInstances(taskId: Long) : Unit = { + synchronized { + // TODO: too conservative a locking: finer granular ones hoped + instances = getInstancesStat(shared, partitionAffined) + instancesInUse = getInstancesInUseStat(shared, partitionAffined) + + (shared, partitionAffined) match { + case (true, true) => + instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]] + .get(instancesInUse.asInstanceOf[HashMap[Long, Int]].get(taskId).get).get.refCnt -= 1 + case (true, false) => + instancesInUse = instancesInUse.asInstanceOf[Int] - 1 + case (false, true) => + val hml = instancesInUse.asInstanceOf[HashMap[Long, Pair[Int, ArrayBuffer[T]]]] +// val p = hml.get(taskId).get +// instances.asInstanceOf[HashMap[Int, ArrayBuffer[Any]]] +// .getOrElseUpdate(p._1, ArrayBuffer[Any]()) ++= p._2 + hml.get(taskId).map(p => instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]] + .getOrElseUpdate(p._1, ArrayBuffer[T]()) ++= p._2) + hml -= taskId + //sma : debug + instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]].foreach(hm => hm._2.foreach( + ab => println("++++ sma: debug: putInstances type: "+ab.getClass +"\n++++ ab value: "+ab + +"\n++++ instances after put: "+instances) + + )) + case (false, false) => + val hm = instancesInUse.asInstanceOf[HashMap[Long, ArrayBuffer[T]]] + hm.get(taskId).map(instances.asInstanceOf[ArrayBuffer[T]] ++= _) + hm -= taskId + } + } + } + + def cleanup(slaveHostname: String, executorId: String): String = { + val errorString + = "Executor %s at %s : External Resource %s has instances in use and can't be cleaned up now".format(executorId, slaveHostname, name) + val successString + = "Executor %s at %s : External Resource %s cleanup succeeds".format(executorId, slaveHostname, name) + synchronized { + instances = getInstancesStat(shared, partitionAffined) + instancesInUse = getInstancesInUseStat(shared, partitionAffined) + + (shared, partitionAffined) match { + case (true, true) => + // an all-or-nothing cleanup mechanism + if (instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]].values.exists(_.refCnt >0)) + return errorString + else { + instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]] + .foreach(r=>term(r._1, r._2.instance, params)) + instances.asInstanceOf[HashMap[Int, ResourceRefCountPerPartition[T]]].clear + } + case (true, 
false) => + if (instancesInUse.asInstanceOf[Int] > 0) + // an all-or-nothing cleanup mechanism + return errorString + else { + if (instances != null) + term(-1, instances.asInstanceOf[T], params) + instances = null + } + case (false, true) => + if (!instancesInUse.asInstanceOf[HashMap[Long, Pair[Int, ArrayBuffer[T]]]].isEmpty) + // an all-or-nothing cleanup mechanism + return errorString + else { + instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]].foreach(l =>l._2.foreach + (e => { + println("++++ cleanup extRsc: " + e) //e.asInstanceOf[Connection] + term(l._1, e, params)})) + instances.asInstanceOf[HashMap[Int, ArrayBuffer[T]]].clear + } + case (false, false) => + if (!instancesInUse.asInstanceOf[HashMap[Long, ArrayBuffer[T]]].isEmpty) + // an all-or-nothing cleanup mechanism + return errorString + else { + instances.asInstanceOf[ArrayBuffer[T]].foreach(term(-1, _, params)) + instances.asInstanceOf[ArrayBuffer[T]].clear + } + } + successString + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e132955f0f85..3bb8e48c1123 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -214,6 +214,9 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] val addedFiles = HashMap[String, Long]() private[spark] val addedJars = HashMap[String, Long]() + private val nextExtResourceId = new AtomicInteger(0) + private[spark] val addedExtResources = HashMap[ExtResource[_], Long]() + // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] private[spark] val metadataCleaner = @@ -833,6 +836,41 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() } + /** + * Add an (external) resource to be used with this Spark job on every node; + * overwrite if the resource already exists + */ + def addOrReplaceResource(res: ExtResource[_]) { + val ts = System.currentTimeMillis + addedExtResources(res) = ts + + logInfo("Added resource " + res.name + " with timestamp " + ts) + + postEnvironmentUpdate() + } + + def listResourceRDD() : RDD[ExtResourceInfo] = new ExtResourceListRDD(this) + // cleanup all outstanding resources + def cleanupAllResourceRDD() : RDD[String] = new ExtResourceCleanupRDD(this) + def cleanupResourceRDD(resourceName: String) : RDD[String] + = new ExtResourceCleanupRDD(this, Some(resourceName)) + + /** + * Add an (external) resource to be used with this Spark job on every node. + */ + def addResource(res: ExtResource[_]) { + val ts: Long= System.currentTimeMillis + if (addedExtResources.containsKey(res)) { + logError("Error adding resource (" + res.name + "): already added ") + } else { + addedExtResources(res) = ts + + logInfo("Added resource " + res.name + " with timestamp " + ts) + + postEnvironmentUpdate() + } + } + /** * :: DeveloperApi :: * Register a listener to receive up-calls from events that happen during execution. @@ -1132,6 +1170,9 @@ class SparkContext(config: SparkConf) extends Logging { * Run a job on all partitions in an RDD and return the results in an array. 
*/ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { + //sma : test + getExecutorsAndLocations + print("++++ sma : getExecutorsAndLocations") runJob(rdd, func, 0 until rdd.partitions.size, false) } @@ -1293,8 +1334,9 @@ class SparkContext(config: SparkConf) extends Logging { val schedulingMode = getSchedulingMode.toString val addedJarPaths = addedJars.keys.toSeq val addedFilePaths = addedFiles.keys.toSeq + val addedExtResourceNames = addedExtResources.keys.map(_.name).toSeq val environmentDetails = - SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths) + SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths, addedExtResourceNames) val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails) listenerBus.post(environmentUpdate) } @@ -1304,6 +1346,12 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def cleanup(cleanupTime: Long) { persistentRdds.clearOldValues(cleanupTime) } + + def getExecutorsAndLocations(): Seq[TaskLocation] = { + // all supported task schedulers are actually TaskSchedulerImpl ? + taskScheduler.asInstanceOf[TaskSchedulerImpl].getExecutorIdsAndLocations() +// Nil + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72716567ca99..6c9ecfcf99d4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -298,7 +298,8 @@ object SparkEnv extends Logging { conf: SparkConf, schedulingMode: String, addedJars: Seq[String], - addedFiles: Seq[String]): Map[String, Seq[(String, String)]] = { + addedFiles: Seq[String], + addedExtResourceNames: Seq[String]): Map[String, Seq[(String, String)]] = { import Properties._ val jvmInformation = Seq( @@ -330,11 +331,13 @@ object SparkEnv extends Logging { .map((_, "System Classpath")) val addedJarsAndFiles = (addedJars ++ addedFiles).map((_, "Added By User")) val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted + val addedExtResourceNames2 = addedExtResourceNames.map((_, "Added By User")) Map[String, Seq[(String, String)]]( "JVM Information" -> jvmInformation, "Spark Properties" -> sparkProperties, "System Properties" -> otherProperties, - "Classpath Entries" -> classPaths) + "Classpath Entries" -> classPaths, + "External Resources" -> addedExtResourceNames2) } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2b99b8a5af25..3d25f2594d86 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -17,7 +17,9 @@ package org.apache.spark -import scala.collection.mutable.ArrayBuffer +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -40,8 +42,10 @@ class TaskContext( val partitionId: Int, val attemptId: Long, val runningLocally: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable { + val resources: Option[ConcurrentHashMap[String, Pair[ExtResource[_], Long]]] = None, + val executorId: Option[String] = None, + val slaveHostname: Option[String] = None, + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)extends Serializable { @deprecated("use partitionId", 
"0.8.1") def splitId = partitionId @@ -111,4 +115,50 @@ class TaskContext( private[spark] def markInterrupted(): Unit = { interrupted = true } + + def getExtResourceUsageInfo() : Iterator[ExtResourceInfo] = { + synchronized { + //sma : debug +// println("++++ sma : getExtResourceUsageInfo") +// if(!resources.isDefined) println("++++ !resources.isDefined") +// if(!slaveHostname.isDefined) println("++++ slaveHostname.isDefined") +// if(!executorId.isDefined) println("++++ executorId.isDefined") + + if (resources.isDefined && slaveHostname.isDefined + && executorId.isDefined){ +// if (resources.isDefined){ + val res = resources.get.size + println(s"++++ sma : resources size : $res") + val smap = mapAsScalaMap(resources.get) + smap.map(r=>r._2._1.getResourceInfo(slaveHostname.get, + executorId.get, r._2._2)).toIterator + } + else{ + //sma : debug + println(s"++++ sma : resources or slaveHostname or executorId is not defined") + ArrayBuffer[ExtResourceInfo]().toIterator + } + } + } + + def cleanupResources(resourceName: Option[String]) : Iterator[String] = { + synchronized { + if (!resources.isDefined) + ArrayBuffer[String]("No external resources available to tasks for Executor %s at %s" + .format(executorId, slaveHostname)).toIterator + else if (resources.get.isEmpty) { + ArrayBuffer[String]("No external resources registered for Executor %s at %s" + .format(executorId, slaveHostname)).toIterator + } else if (resourceName.isDefined) { + if (resources.get.contains(resourceName.get)) + ArrayBuffer[String](resources.get.get(resourceName.get) + ._1.cleanup(slaveHostname.get, executorId.get)).toIterator + else + ArrayBuffer[String]("No external resources %s registered for Executor %s at %s" + .format(resourceName.get, executorId.get, slaveHostname.get)).toIterator + } else { + resources.get.map(_._2._1.cleanup(slaveHostname.get, executorId.get)).toIterator + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2f76e532aeb7..068bde59ce9d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -1,376 +1,409 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.executor - -import java.io.File -import java.lang.management.ManagementFactory -import java.nio.ByteBuffer -import java.util.concurrent._ - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap} - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{AkkaUtils, Utils} - -/** - * Spark executor used with Mesos, YARN, and the standalone scheduler. - */ -private[spark] class Executor( - executorId: String, - slaveHostname: String, - properties: Seq[(String, String)], - isLocal: Boolean = false) - extends Logging -{ - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. - private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) - - @volatile private var isStopped = false - - // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") - // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) - - // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) - - // Set spark.* properties from executor arg - val conf = new SparkConf(true) - conf.setAll(properties) - - if (!isLocal) { - // Setup an uncaught exception handler for non-local mode. - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler) - } - - val executorSource = new ExecutorSource(this, executorId) - - // Initialize Spark environment (using system properties read above) - private val env = { - if (!isLocal) { - val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, - isDriver = false, isLocal = false) - SparkEnv.set(_env) - _env.metricsSystem.registerSource(executorSource) - _env - } else { - SparkEnv.get - } - } - - // Create our ClassLoader - // do this after SparkEnv creation so can access the SecurityManager - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - - // Set the classloader for serializer - env.serializer.setDefaultClassLoader(urlClassLoader) - - // Akka's message frame size. If task result is bigger than this, we use the block manager - // to send the result back. - private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - - // Start worker thread pool - val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") - - // Maintains the list of running tasks. 
- private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] - - startDriverHeartbeater() - - def launchTask( - context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { - val tr = new TaskRunner(context, taskId, taskName, serializedTask) - runningTasks.put(taskId, tr) - threadPool.execute(tr) - } - - def killTask(taskId: Long, interruptThread: Boolean) { - val tr = runningTasks.get(taskId) - if (tr != null) { - tr.kill(interruptThread) - } - } - - def stop() { - env.metricsSystem.report() - isStopped = true - threadPool.shutdown() - } - - class TaskRunner( - execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) - extends Runnable { - - @volatile private var killed = false - @volatile var task: Task[Any] = _ - @volatile var attemptedTask: Option[Task[Any]] = None - - def kill(interruptThread: Boolean) { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true - if (task != null) { - task.kill(interruptThread) - } - } - - override def run() { - val startTime = System.currentTimeMillis() - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(replClassLoader) - val ser = SparkEnv.get.closureSerializer.newInstance() - logInfo(s"Running $taskName (TID $taskId)") - execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 - def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum - val startGCTime = gcTime - - try { - SparkEnv.set(env) - Accumulators.clear() - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) - updateDependencies(taskFiles, taskJars) - task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) - - // If this task has been killed before we deserialized it, let's quit now. Otherwise, - // continue executing the task. - if (killed) { - // Throw an exception rather than returning, because returning within a try{} block - // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl - // exception will be caught by the catch block, leading to an incorrect ExceptionFailure - // for the task. - throw new TaskKilledException - } - - attemptedTask = Some(task) - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) - - // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() - val value = task.run(taskId.toInt) - val taskFinish = System.currentTimeMillis() - - // If the task has been killed, let's fail it. 
- if (task.killed) { - throw new TaskKilledException - } - - val resultSer = SparkEnv.get.serializer.newInstance() - val beforeSerialization = System.currentTimeMillis() - val valueBytes = resultSer.serialize(value) - val afterSerialization = System.currentTimeMillis() - - for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime - m.executorRunTime = taskFinish - taskStart - m.jvmGCTime = gcTime - startGCTime - m.resultSerializationTime = afterSerialization - beforeSerialization - } - - val accumUpdates = Accumulators.values - - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) - val serializedDirectResult = ser.serialize(directResult) - val resultSize = serializedDirectResult.limit - - // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val blockId = TaskResultBlockId(taskId) - env.blockManager.putBytes( - blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) - } else { - (serializedDirectResult, true) - } - } - - execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } - } catch { - case ffe: FetchFailedException => { - val reason = ffe.toTaskEndReason - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - - case _: TaskKilledException | _: InterruptedException if task.killed => { - logInfo(s"Executor killed $taskName (TID $taskId)") - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - } - - case t: Throwable => { - // Attempt to exit cleanly by informing the driver of our failure. - // If anything goes wrong (or this was a fatal exception), we will delegate to - // the default uncaught exception handler, which will terminate the Executor. - logError(s"Exception in $taskName (TID $taskId)", t) - - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime - m.jvmGCTime = gcTime - startGCTime - } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - - // Don't forcibly exit unless the exception was inherently fatal, to avoid - // stopping other tasks unnecessarily. - if (Utils.isFatalError(t)) { - ExecutorUncaughtExceptionHandler.uncaughtException(t) - } - } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - runningTasks.remove(taskId) - } - } - } - - /** - * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes - * created by the interpreter to the search path - */ - private def createClassLoader(): MutableURLClassLoader = { - val currentLoader = Utils.getContextOrSparkClassLoader - - // For each of the jars in the jarSet, add them to the class loader. - // We assume each of the files has already been fetched. 
- val urls = currentJars.keySet.map { uri => - new File(uri.split("/").last).toURI.toURL - }.toArray - val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) - userClassPathFirst match { - case true => new ChildExecutorURLClassLoader(urls, currentLoader) - case false => new ExecutorURLClassLoader(urls, currentLoader) - } - } - - /** - * If the REPL is in use, add another ClassLoader that will read - * new classes defined by the REPL as the user types code - */ - private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { - val classUri = conf.get("spark.repl.class.uri", null) - if (classUri != null) { - logInfo("Using REPL class URI: " + classUri) - val userClassPathFirst: java.lang.Boolean = - conf.getBoolean("spark.files.userClassPathFirst", false) - try { - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader], - classOf[Boolean]) - constructor.newInstance(classUri, parent, userClassPathFirst) - } catch { - case _: ClassNotFoundException => - logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") - System.exit(1) - null - } - } else { - parent - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. - */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) - } - } - } - } - - def startDriverHeartbeater() { - val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) - - val t = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) - - while (!isStopped) { - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() - for (taskRunner <- runningTasks.values()) { - if (!taskRunner.attemptedTask.isEmpty) { - Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => - metrics.updateShuffleReadMetrics - tasksMetrics += ((taskRunner.taskId, metrics)) - } - } - } - - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if 
(response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() - } - Thread.sleep(interval) - } - } - } - t.setDaemon(true) - t.setName("Driver Heartbeater") - t.start() - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.io.File +import java.lang.management.ManagementFactory +import java.nio.ByteBuffer +import java.util.concurrent._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * Spark executor used with Mesos, YARN, and the standalone scheduler. + */ +private[spark] class Executor( + executorId: String, + slaveHostname: String, + properties: Seq[(String, String)], + isLocal: Boolean = false) + extends Logging +{ + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + private val currentResources: ConcurrentHashMap[String, Pair[ExtResource[_], Long]] = new ConcurrentHashMap[String, Pair[ExtResource[_], Long]]() + + private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) + + @volatile private var isStopped = false + + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. + assert (0 == Utils.parseHostPort(slaveHostname)._2) + + // Make sure the local hostname we report matches the cluster scheduler's name for this host + Utils.setCustomHostname(slaveHostname) + + // Set spark.* properties from executor arg + val conf = new SparkConf(true) + conf.setAll(properties) + + if (!isLocal) { + // Setup an uncaught exception handler for non-local mode. + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. 
+ Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler) + } + + val executorSource = new ExecutorSource(this, executorId) + + // Initialize Spark environment (using system properties read above) + private val env = { + if (!isLocal) { + val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, + isDriver = false, isLocal = false) + SparkEnv.set(_env) + _env.metricsSystem.registerSource(executorSource) + _env + } else { + SparkEnv.get + } + } + + // Create our ClassLoader + // do this after SparkEnv creation so can access the SecurityManager + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + + // Akka's message frame size. If task result is bigger than this, we use the block manager + // to send the result back. + private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + + // Start worker thread pool + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + + // Maintains the list of running tasks. + private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + + startDriverHeartbeater() + + def launchTask( + context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { + //sma : debug + println("++++ scheduledbackend : LaunchTask(data) -> executor.launchTask") + val tr = new TaskRunner(context, taskId, taskName, serializedTask) + runningTasks.put(taskId, tr) + threadPool.execute(tr) + } + + def killTask(taskId: Long, interruptThread: Boolean) { + val tr = runningTasks.get(taskId) + if (tr != null) { + tr.kill(interruptThread) + } + } + + def stop() { + env.metricsSystem.report() + isStopped = true + threadPool.shutdown() + // terminate live external resources + currentResources.foreach(_._2._1.cleanup(slaveHostname, executorId)) + } + + class TaskRunner( + execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) + extends Runnable { + + @volatile private var killed = false + @volatile var task: Task[Any] = _ + @volatile var attemptedTask: Option[Task[Any]] = None + + def kill(interruptThread: Boolean) { + logInfo(s"Executor is trying to kill $taskName (TID $taskId)") + killed = true + if (task != null) { + task.kill(interruptThread) + } + } + + override def run() { + val startTime = System.currentTimeMillis() + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(replClassLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() + logInfo(s"Running $taskName (TID $taskId)") + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + var taskStart: Long = 0 + def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + val startGCTime = gcTime + + try { + SparkEnv.set(env) + Accumulators.clear() + println("########## before deserializeWithDependencies") + val (taskFiles, taskJars, taskResources, taskBytes) = Task.deserializeWithDependencies(serializedTask) + updateDependencies(taskFiles, taskJars, taskResources) + task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + // continue executing the task. + if (killed) { + // Throw an exception rather than returning, because returning within a try{} block + // causes a NonLocalReturnControl exception to be thrown. 
The NonLocalReturnControl + // exception will be caught by the catch block, leading to an incorrect ExceptionFailure + // for the task. + throw new TaskKilledException + } + + attemptedTask = Some(task) + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.updateEpoch(task.epoch) + task.resources = Some(currentResources) + task.executorId = Some(executorId) + task.slaveHostname = Some(slaveHostname) + + //sma : debug + var rsDefinded = false + if (task.resources.isDefined) { + rsDefinded = true + task.resources.get.foreach(e => println("++++ Executor : resource name : "+ e._1)) + } + println(s"++++ check if excutor get resources : $rsDefinded") + + // Run the actual task and measure its runtime. + taskStart = System.currentTimeMillis() + val value = task.run(taskId.toInt) + val taskFinish = System.currentTimeMillis() + + // If the task has been killed, let's fail it. + if (task.killed) { + throw new TaskKilledException + } + + val resultSer = SparkEnv.get.serializer.newInstance() + val beforeSerialization = System.currentTimeMillis() + val valueBytes = resultSer.serialize(value) + val afterSerialization = System.currentTimeMillis() + + for (m <- task.metrics) { + m.executorDeserializeTime = taskStart - startTime + m.executorRunTime = taskFinish - taskStart + m.jvmGCTime = gcTime - startGCTime + m.resultSerializationTime = afterSerialization - beforeSerialization + } + + val accumUpdates = Accumulators.values + + val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) + val serializedDirectResult = ser.serialize(directResult) + val resultSize = serializedDirectResult.limit + + // directSend = sending directly back to the driver + val (serializedResult, directSend) = { + if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val blockId = TaskResultBlockId(taskId) + env.blockManager.putBytes( + blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) + (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + } else { + (serializedDirectResult, true) + } + } + + execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + + if (directSend) { + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + } else { + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + } + } catch { + case ffe: FetchFailedException => { + val reason = ffe.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + } + + case _: TaskKilledException | _: InterruptedException if task.killed => { + logInfo(s"Executor killed $taskName (TID $taskId)") + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + } + + case t: Throwable => { + // Attempt to exit cleanly by informing the driver of our failure. + // If anything goes wrong (or this was a fatal exception), we will delegate to + // the default uncaught exception handler, which will terminate the Executor. 
+ logError(s"Exception in $taskName (TID $taskId)", t) + + val serviceTime = System.currentTimeMillis() - taskStart + val metrics = attemptedTask.flatMap(t => t.metrics) + for (m <- metrics) { + m.executorRunTime = serviceTime + m.jvmGCTime = gcTime - startGCTime + } + val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + + // Don't forcibly exit unless the exception was inherently fatal, to avoid + // stopping other tasks unnecessarily. + if (Utils.isFatalError(t)) { + ExecutorUncaughtExceptionHandler.uncaughtException(t) + } + } + } finally { + // Release all external resources used + ExternalResourceManager.cleanupResourcesPerTask(taskId) + // Release memory used by this thread for shuffles + env.shuffleMemoryManager.releaseMemoryForThisThread() + // Release memory used by this thread for unrolling blocks + env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() + runningTasks.remove(taskId) + } + } + } + + /** + * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes + * created by the interpreter to the search path + */ + private def createClassLoader(): MutableURLClassLoader = { + val currentLoader = Utils.getContextOrSparkClassLoader + println(s"#################") + + // For each of the jars in the jarSet, add them to the class loader. + // We assume each of the files has already been fetched. + val urls = currentJars.keySet.map { uri => + println(s"################# $uri") + new File(uri.split("/").last).toURI.toURL + }.toArray + val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) + userClassPathFirst match { + case true => new ChildExecutorURLClassLoader(urls, currentLoader) + case false => new ExecutorURLClassLoader(urls, currentLoader) + } + } + + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { + val classUri = conf.get("spark.repl.class.uri", null) + if (classUri != null) { + logInfo("Using REPL class URI: " + classUri) + val userClassPathFirst: java.lang.Boolean = + conf.getBoolean("spark.files.userClassPathFirst", false) + try { + val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader], + classOf[Boolean]) + constructor.newInstance(classUri, parent, userClassPathFirst) + } catch { + case _: ClassNotFoundException => + logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null + } + } else { + parent + } + } + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. 
+ */ + private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long], newResources: HashMap[ExtResource[_], Long]) { + synchronized { + //sma : debug + println("++++ updateDependencies") + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + + //sma : debug + println("++++ executor.updateDependencies : Adding " + url + " to class loader") + } + } + for ((resource, timestamp) <- newResources + if currentResources.getOrElse(resource.name, (null, -1L))._2 < timestamp) { + //sma : debug + println("++++ updateDependencies : currentResources: "+ resource.name) + logInfo("Initializing " + resource.name + " with timestamp " + timestamp) + currentResources(resource.name) = (resource, timestamp) + } + } + } + + def startDriverHeartbeater() { + val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + + val t = new Thread() { + override def run() { + // Sleep a random interval so the heartbeats don't end up in sync + Thread.sleep(interval + (math.random * interval).asInstanceOf[Int]) + + while (!isStopped) { + val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + for (taskRunner <- runningTasks.values()) { + if (!taskRunner.attemptedTask.isEmpty) { + Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + metrics.updateShuffleReadMetrics + tasksMetrics += ((taskRunner.taskId, metrics)) + } + } + } + + val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + Thread.sleep(interval) + } + } + } + t.setDaemon(true) + t.setName("Driver Heartbeater") + t.start() + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/AdminRDD.scala b/core/src/main/scala/org/apache/spark/rdd/AdminRDD.scala new file mode 100644 index 000000000000..8a2630b5aab8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/AdminRDD.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark._ +import org.apache.spark.scheduler.TaskLocation +import org.apache.spark.annotation.DeveloperApi + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * A Spark split class that is per executor + */ +private[spark] class ExecutorPartition(idx: Int, val taskLocation: TaskLocation) + extends Partition { + + override val index: Int = idx +} + +/** + * :: DeveloperApi :: + * An abstract RDD that provides basics to dispatch executor-specific task + * (mostly for administration purposes) + * + * @param sc The SparkContext to associate the RDD with. + */ +@DeveloperApi +abstract class AdminRDD[T: ClassTag]( + @transient sc: SparkContext) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + val executors : Seq[TaskLocation] = sc.getExecutorsAndLocations + print(s"value of executors : $executors") + val array = new Array[Partition](executors.size) + for (i <- 0 until executors.size) { + array(i) = new ExecutorPartition(i, executors(i)) + } + array + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val adminSplit = split.asInstanceOf[ExecutorPartition] + ArrayBuffer[String](adminSplit.taskLocation.host) + } +} + + +class AdminTestRDD( + sc: SparkContext) + extends AdminRDD[Char](sc) { + override def compute(split: Partition, context: TaskContext): Iterator[Char] = +// context.getExtResourceUsageInfo + "ABC".toCharArray.toIterator +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/rdd/AdminRDDKen.scala b/core/src/main/scala/org/apache/spark/rdd/AdminRDDKen.scala new file mode 100644 index 000000000000..a842144db337 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/AdminRDDKen.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark._ +import org.apache.spark.scheduler.TaskLocation +import org.apache.spark.annotation.DeveloperApi + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * A Spark split class that is per executor + */ +private[spark] class ExecutorPartitionKen(idx: Int, @transient val taskLocation: TaskLocation) + extends Partition { + + override val index: Int = idx +} + +/** + * :: DeveloperApi :: + * An abstract RDD that provides basics to dispatch executor-specific task + * (mostly for administration purposes) + * + * @param sc The SparkContext to associate the RDD with. + */ +@DeveloperApi +abstract class AdminRDDKen[T: ClassTag]( + sc: SparkContext) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + val executors : Seq[TaskLocation] = sc.getExecutorsAndLocations + val array = new Array[Partition](executors.size) + for (i <- 0 until executors.size) { + array(i) = new ExecutorPartition(i, executors(i)) + } + array + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val adminSplit = split.asInstanceOf[ExecutorPartition] + ArrayBuffer[String](adminSplit.taskLocation.host) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ExtResourceAdminRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ExtResourceAdminRDD.scala new file mode 100644 index 000000000000..83422ffe4373 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ExtResourceAdminRDD.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler.TaskLocation +import scala.collection.mutable.ArrayBuffer + + + +class ExtResourceListRDD( + sc: SparkContext) + extends AdminRDD[ExtResourceInfo](sc) { + override def compute(split: Partition, context: TaskContext): Iterator[ExtResourceInfo] = + context.getExtResourceUsageInfo +} + +class ExtResourceCleanupRDD( + sc: SparkContext, + resourceName: Option[String] = None) + extends AdminRDD[String](sc) { + override def compute(split: Partition, context: TaskContext): Iterator[String]= + context.cleanupResources(resourceName) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDDExtRsc.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDDExtRsc.scala new file mode 100644 index 000000000000..38636bf7bae2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDDExtRsc.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.sql.{Connection, ResultSet} + +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.util.{Utils, NextIterator} + +private[spark] class JdbcPartitionExt(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} +// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private +/** + * An RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JdbcRDDSuite. + * + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. 
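+ *
+ * A rough usage sketch (the resource name "mysql", the JDBC URL and the query below are
+ * illustrative placeholders, not part of this patch). The resource is registered on the
+ * driver with `SparkContext.addResource`, and each task then obtains a `Connection`
+ * instance through its `TaskContext`:
+ * {{{
+ *   import java.sql.{Connection, DriverManager}
+ *
+ *   val jdbcRes = ExtResource[Connection](
+ *     name = "mysql",
+ *     shared = false,
+ *     params = Seq("jdbc:mysql://dbhost/test"),
+ *     init = (split: Int, ps: Seq[_]) => DriverManager.getConnection(ps.head.toString),
+ *     term = (split: Int, conn: Connection, ps: Seq[_]) => conn.close())
+ *   sc.addResource(jdbcRes)
+ *
+ *   val rows = new JdbcRDDExtRsc[Array[Object]](
+ *     sc, "mysql", "select title, author from books where ? <= id and id <= ?",
+ *     1, 100, 3, JdbcRDDExtRsc.resultSetToObjectArray _)
+ * }}}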
+ */ +class JdbcRDDExtRsc[T: ClassTag]( + sc: SparkContext, + extRscName: String, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDDExtRsc.resultSetToObjectArray _) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartitionExt(i, start, end) + }).toArray + } + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + val x = Class.forName("com.mysql.jdbc.Driver", true, Utils.getContextOrSparkClassLoader) + println(x.toString) + println("get driver class "+ x) + + context.addTaskCompletionListener{ context => closeIfNeeded() } + val part = thePart.asInstanceOf[JdbcPartitionExt] + // val conn = getConnection() + //Todo: sma : exception handling + if (!context.resources.isDefined) throw new Exception("No available ExtResources") + + val extRsc = context.resources.get.get(extRscName) + if (extRsc==null) throw new Exception(s"No such resource : $extRscName") + val rsc = extRsc._1 + println("++++ rdd compute: param size : "+rsc.params.size ) + println("Object Id = :" + rsc) +// Thread.dumpStack() + //sma : debug + println("before extRsc: ins: " +rsc.getInstancesStat(rsc.shared, rsc.partitionAffined)) + println("before extRsc useinstance: " + rsc.getInstancesInUseStat(rsc.shared, rsc.partitionAffined)) + + + //Todo: sma : exception handling + val conn = rsc.getInstance(context.partitionId, context.attemptId).asInstanceOf[Connection] + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + + //sma : debug + println("after extRsc: ins: " +rsc.getInstancesStat(rsc.shared, rsc.partitionAffined)) + println("after extRsc useinstance: " + rsc.getInstancesInUseStat(rsc.shared, rsc.partitionAffined)) + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. + // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } + + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + if (null != rs && ! rs.isClosed()) { + rs.close() + } + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) { + stmt.close() + } + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } +// try { +// if (null != conn && ! 
conn.isClosed()) { +//// conn.close() +// rsc.putInstances(context.attemptId) +// } +// logInfo("closed connection") +// } catch { +// case e: Exception => logWarning("Exception closing connection", e) +// } + } + } +} + +object JdbcRDDExtRsc { + def resultSetToObjectArray(rs: ResultSet): Array[Object] = { + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } +} + diff --git a/core/src/main/scala/org/apache/spark/rdd/extRscTestRDD.scala b/core/src/main/scala/org/apache/spark/rdd/extRscTestRDD.scala new file mode 100644 index 000000000000..795de0407fd8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/extRscTestRDD.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.sql.{Connection, ResultSet} + +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.util.NextIterator + + +class extRscTestRDD[T: ClassTag]( + sc: SparkContext, + extRscName: String) + extends RDD[T](sc, Nil) with Logging { + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](4) + for (i <- 0 until array.size){ + array(i) = new Partition { override def index: Int = i } + } + array + } + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + throw new UnsupportedOperationException("empty RDD") + //Todo: sma : exception handling +// if (!context.resources.isDefined) throw new Exception("No available ExtResources") +// +// val extRsc = context.resources.get.get(extRscName) +// if (extRsc==null) throw new Exception(s"No such resource : $extRscName") +// val rsc = extRsc._1 +// println("++++ rdd compute: param size : "+rsc.params.size ) +// println("Object Id = :" + rsc) +// // Thread.dumpStack() +// +// rsc.getInstance(context.partitionId, context.attemptId).asInstanceOf[String].toIterator + + } +} + + + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2ccc27324ac8..efd8a6005053 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{AdminRDD, ExecutorPartition} import org.apache.spark.storage._ import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -745,6 +746,8 @@ class DAGScheduler( 
listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) } else { + //sma : debug + println("++++ DAGScheduler.handleJobSubmitted : nonlocal") jobIdToActiveJob(jobId) = job activeJobs += job finalStage.resultOfJob = Some(job) @@ -1288,6 +1291,15 @@ class DAGScheduler( // Nil has already been returned for previously visited partitions. return Nil } + + // For an admin RDD, the TaskLocations are explicitly set + // This needs to be called before checking on cached locations + // since the locations and executors could be dynamic and fluid + // for failed nodes etc. + if (rdd.isInstanceOf[AdminRDD[_]]) { + return rdd.partitions.map(_.asInstanceOf[ExecutorPartition].taskLocation) + } + // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (!cached.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6aa0cca06878..698f18bd96e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,13 +17,16 @@ package org.apache.spark.scheduler -import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, + ObjectOutputStream, ObjectInputStream} import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.HashMap import org.apache.spark.TaskContext -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.ExtResource +import org.apache.spark.executor.{TaskMetrics, Executor} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -45,7 +48,12 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + //sma : debug + println(s"++++ Task instance : $this") + val rs = this.resources.getOrElse(new ConcurrentHashMap[String, Pair[ExtResource[_], Long]]()).size + println(s"++++ # resources in task is $rs") + context = new TaskContext(stageId, partitionId, attemptId, + false, this.resources, this.executorId, this.slaveHostname) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -73,6 +81,10 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // initialized when kill() is invoked. @volatile @transient private var _killed = false + @transient var resources : Option[ConcurrentHashMap[String, Pair[ExtResource[_], Long]]] = None + @transient var executorId : Option[String] = None + @transient var slaveHostname : Option[String] = None + /** * Whether the task has been killed. 
   */
@@ -110,11 +122,12 @@ private[spark] object Task {
       task: Task[_],
       currentFiles: HashMap[String, Long],
       currentJars: HashMap[String, Long],
+      currentExtResources: HashMap[ExtResource[_], Long],
       serializer: SerializerInstance)
     : ByteBuffer = {
 
     val out = new ByteArrayOutputStream(4096)
-    val dataOut = new DataOutputStream(out)
+    val dataOut = new ObjectOutputStream(out)
 
     // Write currentFiles
     dataOut.writeInt(currentFiles.size)
@@ -130,6 +143,14 @@ private[spark] object Task {
       dataOut.writeLong(timestamp)
     }
 
+    // Write currentExtResources
+    // If the init/term closures are large, per-task serde won't be efficient
+    dataOut.writeInt(currentExtResources.size)
+    for ((resource, timestamp) <- currentExtResources) {
+      dataOut.writeObject(resource)
+      dataOut.writeLong(timestamp)
+    }
+
     // Write the task itself and finish
     dataOut.flush()
     val taskBytes = serializer.serialize(task).array()
@@ -145,10 +166,11 @@ private[spark] object Task {
-   * @return (taskFiles, taskJars, taskBytes)
+   * @return (taskFiles, taskJars, taskResources, taskBytes)
    */
   def deserializeWithDependencies(serializedTask: ByteBuffer)
-    : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+    : (HashMap[String, Long], HashMap[String, Long],
+       HashMap[ExtResource[_], Long], ByteBuffer) = {
 
     val in = new ByteBufferInputStream(serializedTask)
-    val dataIn = new DataInputStream(in)
+    val dataIn = new ObjectInputStream(in)
 
     // Read task's files
     val taskFiles = new HashMap[String, Long]()
@@ -164,8 +186,17 @@ private[spark] object Task {
       taskJars(dataIn.readUTF()) = dataIn.readLong()
     }
 
+    // Read task's external resources
+    // If the init/term closures are large, per-task serde won't be efficient
+    val taskResources = new HashMap[ExtResource[_], Long]()
+    val numResources = dataIn.readInt()
+    for (i <- 0 until numResources) {
+      val resource = dataIn.readObject().asInstanceOf[ExtResource[_]]
+      taskResources(resource) = dataIn.readLong()
+    }
+
     // Create a sub-buffer for the rest of the data, which is the serialized Task object
     val subBuffer = serializedTask.slice()  // ByteBufferInputStream will have read just up to task
-    (taskFiles, taskJars, subBuffer)
+    (taskFiles, taskJars, taskResources, subBuffer)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ad051e59af86..248f793fc25f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -93,7 +93,7 @@ private[spark] class TaskSchedulerImpl(
 
   protected val hostsByRack = new HashMap[String, HashSet[String]]
 
-  protected val executorIdToHost = new HashMap[String, String]
+  protected val executorIdToHosts = new HashMap[String, String]
 
   // Listener object to pass upcalls into
   var dagScheduler: DAGScheduler = null
@@ -177,6 +177,7 @@ private[spark] class TaskSchedulerImpl(
       }
       hasReceivedTask = true
     }
+    println(s"++++ TaskSchedulerImpl.backend : $backend")
     backend.reviveOffers()
   }
 
@@ -222,7 +223,7 @@ private[spark] class TaskSchedulerImpl(
     // Also track if new executor is added
     var newExecAvail = false
     for (o <- offers) {
-      executorIdToHost(o.executorId) = o.host
+      executorIdToHosts(o.executorId) = o.host
       if (!executorsByHost.contains(o.host)) {
         executorsByHost(o.host) = new HashSet[String]()
         executorAdded(o.executorId, o.host)
@@ -420,7 +421,7 @@ private[spark] class TaskSchedulerImpl(
 
     synchronized {
       if (activeExecutorIds.contains(executorId)) {
-        val hostPort = executorIdToHost(executorId)
+        val hostPort = 
executorIdToHosts(executorId)
        logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
        removeExecutor(executorId)
        failedExecutor = Some(executorId)
@@ -442,7 +443,7 @@ private[spark] class TaskSchedulerImpl(
 
   /** Remove an executor from all our data structures and mark it as lost */
   private def removeExecutor(executorId: String) {
     activeExecutorIds -= executorId
-    val host = executorIdToHost(executorId)
+    val host = executorIdToHosts(executorId)
     val execs = executorsByHost.getOrElse(host, new HashSet)
     execs -= executorId
     if (execs.isEmpty) {
@@ -454,7 +455,7 @@ private[spark] class TaskSchedulerImpl(
        }
      }
    }
-    executorIdToHost -= executorId
+    executorIdToHosts -= executorId
     rootPool.executorLost(executorId, host)
   }
 
@@ -491,6 +492,13 @@ private[spark] class TaskSchedulerImpl(
       }
     }
   }
+
+  def getExecutorIdsAndLocations(): Seq[TaskLocation] = {
+    executorIdToHosts.foreach(str => println(s"+++++ getExecutorIdsAndLocations : ${str._1} and ${str._2}"))
+    val eSize = executorIdToHosts.size
+    println(s"++++ sma: executorIdToHosts size : $eSize")
+    executorIdToHosts.map(e => TaskLocation(e._2, e._1)).toSeq  // TaskLocation expects (host, executorId)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843f..18f960dac51d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -436,7 +436,8 @@ private[spark] class TaskSetManager(
       // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
       // we assume the task can be serialized without exceptions.
       val serializedTask = Task.serializeWithDependencies(
-        task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+        task, sched.sc.addedFiles, sched.sc.addedJars,
+        sched.sc.addedExtResources, ser)
       if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
           !emittedTaskSizeWarning) {
         emittedTaskSizeWarning = true
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 7e18f45de7b5..06a1c2ee0ce2 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -73,6 +73,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
 
     tmpFile = textFile
     tmpJarUrl = jarFile.toURI.toURL.toString
+    println(tmpJarUrl)
   }
 
   override def afterAll() {
@@ -135,8 +136,11 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
     sc.parallelize(testData).foreach { x =>
       if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) {
         throw new SparkException("jar not added")
+      } else {
+        println(Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt"))
       }
     }
+    println("done!")
   }
 
   test("Distributing files on a standalone cluster") {
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 0b6511a80df1..81cb2441ff46 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -30,6 +30,8 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
   var conf = new SparkConf(false)
 
   override def beforeAll() {
+//    _sc = new SparkContext("spark://127.0.0.1:7077", "test", conf)
+//    _sc = new SparkContext("local-cluster[3, 1, 512]", "test", conf)
    _sc = new SparkContext("local", 
"test", conf) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/rdd/JDBCexample.scala b/core/src/test/scala/org/apache/spark/rdd/JDBCexample.scala new file mode 100644 index 000000000000..1748e79d0ebd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/JDBCexample.scala @@ -0,0 +1,41 @@ +/** + * Created by ken on 10/22/14. + */ +package org.apache.spark.rdd + +import java.sql.DriverManager +import java.sql.Connection + + +object ScalaJdbcConnectSelect { + + def main(args: Array[String]) { + // connect to the database named "mysql" on the localhost + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://localhost/mysql" + val username = "ken" + val password = "km" + + // there's probably a better way to do this + var connection:Connection = null + + try { + // make the connection + Class.forName(driver) + connection = DriverManager.getConnection(url, username, password) + + // create the statement, and run the select query + val statement = connection.createStatement() + val resultSet = statement.executeQuery("SELECT host, user FROM user") + while ( resultSet.next() ) { + val host = resultSet.getString("host") + val user = resultSet.getString("user") + println("host, user = " + host + ", " + user) + } + } catch { + case e: Throwable => e.printStackTrace + } + connection.close() + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 926d4fecb5b9..1f860938f39a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -32,6 +32,18 @@ import org.apache.spark.rdd.RDDSuiteUtils._ class RDDSuite extends FunSuite with SharedSparkContext { + + test("test add jar") { + + val driver = "com.mysql.jdbc.Driver" + sc.addJar("file:///usr/share/java/mysql-connector-java.jar") + sc.parallelize((1 to 40), 4).foreach { iter => + val x = Thread.currentThread.getContextClassLoader.loadClass("com.mysql.jdbc.Driver") + println(x.toString) + } + + } + test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuiteKen.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteKen.scala new file mode 100644 index 000000000000..474fc68870e4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuiteKen.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.net.{URLClassLoader, URL} +import java.sql.{Driver, ResultSet, DriverManager, Connection} + +import org.apache.spark.util.Utils._ + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.util.Utils + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.rdd.RDDSuiteUtils._ + +class RDDSuiteKen extends FunSuite with SharedSparkContext { + test("AdminRdd operation") { + val arr = new ArrayBuffer[String]() + arr += "ABC" + + val eSrc = new ExtResource("Ken's test", true, arr, null , null, true) + val rsInfo = eSrc.getResourceInfo("Hostname", "executorId", 12345) + println(s"ResourceInfo instanceCount: "+ rsInfo.instanceCount + + "\nResourceInfo instanceUseCount: "+ rsInfo.instanceUseCount + ) + + sc.addOrReplaceResource(eSrc) + + val adRdd = new ExtResourceListRDD(sc) + val v = adRdd.collect().size + print(s"\nsize of adminRdd : $v") + + val cnt = adRdd.count() + print(s"\nCount of adminRdd : $cnt") + + + adRdd.collect().foreach{ + + e => println("admin RDD : "+e.name + " : "+ e.slaveHostname ) + } + + } + + test("add mysql jdbc connection as ExtResource") { + + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://127.0.0.1/mysql" + val username = "ken" + val password = "km" + + var myparams=Array(driver, url, username, password) + + val myinit = (split: Int, params: Seq[_]) => { + + require(params.size>3, s"parameters error, current param size: "+ params.size) + val p = params + val driver = p(0).toString + val url = p(1).toString + val username = p(2).toString + val password = p(3).toString + + var connection:Connection = null + try { + // make the connection + val cl = Thread.currentThread.getContextClassLoader + println("++++ cl.loadClass : "+cl.loadClass(driver).newInstance()) + + val loader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader) + + val clV = driver.replaceAll("\\.", "/")+".class" + println(cl.getResource(clV)) +// Class.forName(driver) +// val sqlCl = Class.forName(driver, true, Thread.currentThread.getContextClassLoader).newInstance() + val sqlCl = Class.forName(driver, true, loader) +// val sqlCl = cl.loadClass(driver).newInstance() + DriverManager.registerDriver(sqlCl.asInstanceOf[Driver]) + connection = DriverManager.getConnection(url, username, password) + } catch { + case e: Throwable => e.printStackTrace + } + connection + } + + + val myterm = (split: Int, conn: Any, params: Seq[_]) => { + require(Option(conn) != None, "Connection error") + try{ + val c = conn.asInstanceOf[Connection] + c.close() + }catch { + case e: Throwable => e.printStackTrace + } + } + + val eSrc = new ExtResource("mysql ExtRsc test", + shared=false, params = myparams, init=myinit , term=myterm, partitionAffined=false) + + //1. add extRsc to sparkContext + sc.addOrReplaceResource(eSrc) + sc.addJar("file:///usr/share/java/mysql-connector-java.jar") + + + + //2. create rdd + val rdd = new JdbcRDDExtRsc( + sc, + eSrc.name, + "SELECT host, user FROM user limit ?, ?", + 1, 6, 3, + (r: ResultSet) => { r.getString(1) } ).cache() + + //3. 
output + println("# of rows (rdd.count): "+rdd.count) +// println("# of row (rdd.reduce(_+_)): "+rdd.count) +// assert(rdd.count === 100) +// assert(rdd.reduce(_+_) === 10100) + + } + + test("test ..."){ + + } + + test("test add jar") { + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://127.0.0.1/mysql" + val username = "ken" + val password = "km" + + sc.addJar("file:///usr/share/java/mysql-connector-java.jar") + sc.parallelize((1 to 40), 4).foreach { iter => + val x = Class.forName(driver, true, Thread.currentThread.getContextClassLoader).newInstance() +// val x = Class.forName(driver).newInstance() + println(x.toString) + x.isInstanceOf[Driver] match { + case true => { + println("get driver class "+ x) + + var connection:Connection = null + try { +// DriverManager.registerDriver(x.asInstanceOf[Driver]) + connection = DriverManager.getConnection(url, username, password) + println("successfully create connection: "+ connection) + } catch { + case e: Throwable => e.printStackTrace + }finally { + if (connection !=null) connection.close() + } + } + case _ => println("get driver class fail! "+ x) + } + } + + } + + test("jdbc test"){ + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://localhost/mysql" + val username = "ken" + val password = "km" + + // there's probably a better way to do this + var connection:Connection = null + + try { + // make the connection + Class.forName(driver) + connection = DriverManager.getConnection(url, username, password) + + // create the statement, and run the select query + val statement = connection.createStatement() + val resultSet = statement.executeQuery("SELECT host, user FROM user") + while ( resultSet.next() ) { + val host = resultSet.getString("host") + val user = resultSet.getString("user") + println("host, user = " + host + ", " + user) + } + } catch { + case e: Throwable => e.printStackTrace + } + connection.close() + } + + + test("Ken : basic operations") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.collect().toList === List(1, 2, 3, 4)) + assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) + } + + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/Test.scala b/examples/src/main/scala/org/apache/spark/examples/Test.scala new file mode 100644 index 000000000000..71024bb7e186 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/Test.scala @@ -0,0 +1,79 @@ +package org.apache.spark.examples + +import java.net.{URLClassLoader, URL} +import java.sql.{Driver, ResultSet, DriverManager, Connection} +import scala.collection.mutable.{ArrayBuffer, HashMap} +import org.apache.spark.rdd._ +import org.apache.spark._ +/** + * Created by ken on 11/10/14. 
+ */ +object Test { + def main (args: Array[String]) { + val cf = new SparkConf().setAppName("wfwfwf").setMaster("local") + val sc = new SparkContext(cf) + + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://127.0.0.1/mysql" + val username = "ken" + val password = "km" + + var myparams=Array(driver, url, username, password) + + def myinit(split: Int, params: Seq[_]): Connection = { + require(params.size>3, s"parameters error, current param size: "+ params.size) + val p = params + val driver = p(0).toString + val url = p(1).toString + val username = p(2).toString + val password = p(3).toString +// +// var connection:Connection = null +// try { +// val loader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader) +// +//// val x = Class.forName(driver, true, loader) +// val x = "wf" +// println(x.toString) +// println("get driver class "+ x) +// connection = DriverManager.getConnection(url, username, password) +// } catch { +// case e: Throwable => e.printStackTrace +// } +// connection + null + } + + +// val myterm = (split: Int, conn: Any, params: Seq[_]) => { +// require(Option(conn) != None, "Connection error") +// try{ +// val c = conn.asInstanceOf[Connection] +// c.close() +// }catch { +// case e: Throwable => e.printStackTrace +// } +// } + + val eSrc = new ExtResource("mysql ExtRsc test", false, myparams, null , null, false) + + sc.addJar("file:///usr/share/java/mysql-connector-java.jar") + Thread.sleep(1000) + //1. add extRsc to sparkContext + sc.addOrReplaceResource(eSrc) + + + + //2. create rdd + val rdd = new JdbcRDDExtRsc( + sc, + eSrc.name, + "SELECT host, user FROM user limit ?, ?", + 1, 6, 3, + (r: ResultSet) => { r.getString(1) } ).cache() + + //3. output + println("# of rows (rdd.count): "+rdd.count) + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/Test1.scala b/examples/src/main/scala/org/apache/spark/examples/Test1.scala new file mode 100644 index 000000000000..e59079a6955d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/Test1.scala @@ -0,0 +1,49 @@ +package org.apache.spark.examples + +import java.sql.{ResultSet, DriverManager, Connection} + +import org.apache.spark.examples.Test._ +import org.apache.spark.rdd.JdbcRDDExtRsc +import org.apache.spark.util.Utils +import org.apache.spark.{ExtResource, SparkContext, SparkConf} + +/** + * Created by ken on 11/10/14. + */ +object Test1 { + def main (args: Array[String]) { + val cf = new SparkConf().setAppName("wfwfwf").setMaster("local") + val sc = new SparkContext(cf) + + val driver = "com.mysql.jdbc.Driver" + val url = "jdbc:mysql://127.0.0.1/mysql" + val username = "ken" + val password = "km" + + sc.addJar("file:///usr/share/java/mysql-connector-java.jar") + Thread.sleep(1000) + + + sc.parallelize((1 to 40), 4).foreach { iter => + val x = Class.forName(driver, true, Utils.getContextOrSparkClassLoader) + println(x.toString) + println("get driver class " + x) + var connection: Connection = null + try { + connection = DriverManager.getConnection(url, username, password) + println("successfully create connection: " + connection) + } catch { + case e: Throwable => e.printStackTrace + } finally { + if (connection != null) connection.close() + } + } +// val rdd = new JdbcRDDExtRsc( +// sc, +// "ken", +// "SELECT host, user FROM user limit ?, ?", +// 1, 6, 3, +// (r: ResultSet) => { r.getString(1) } ).cache() + + } +}
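
A minimal driver-side sketch of how the admin RDDs introduced above (ExtResourceListRDD / ExtResourceCleanupRDD) would be used to inspect and release external resources across executors; the resource name "mysql-conn" is an assumed example, not something defined by this patch:

// List every registered ExtResource on every executor (one partition per executor)
val usage: Array[ExtResourceInfo] = new ExtResourceListRDD(sc).collect()
usage.foreach(info => println(info))  // host, executor, resource name, instance counts, ...

// Release the instances of one named resource on all executors ...
new ExtResourceCleanupRDD(sc, Some("mysql-conn")).collect().foreach(println)

// ... or clean up every registered resource (resourceName defaults to None)
new ExtResourceCleanupRDD(sc).collect().foreach(println)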