diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 83ce14a0a806a..514bddd9cba09 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -55,7 +55,7 @@ private[deploy] object DeployMessages {
     extends DeployMessage
 
   case class DriverStateChanged(
-      driverId: String,
+      driverID: String,
       state: DriverState,
       exception: Option[Exception])
     extends DeployMessage
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 33377931d6993..0f33952b28cb9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -33,4 +33,9 @@ private[spark] class DriverInfo(
   @transient var exception: Option[Exception] = None
   /* Most recent worker assigned to this driver */
   @transient var worker: Option[WorkerInfo] = None
+
+  /**
+   * The number of times the master has retried launching this in-cluster driver.
+   */
+  @transient var retriedcountOnMaster = 0
 }
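The master-side cap on these retries (applied in the DriverStateChanged hunk below) is min(spark.driver.maxRetry, number of workers) when the property is set, and the number of workers otherwise. A minimal sketch of that predicate; the helper name `shouldRelaunch` is ours, the patch inlines this logic in the handler:

```scala
// Sketch of the relaunch predicate the Master applies for ERROR/FAILED drivers.
// (Helper name is hypothetical; the patch writes this inline.)
def shouldRelaunch(retriedCount: Int, maxRetry: Int, numWorkers: Int): Boolean =
  if (maxRetry != 0) retriedCount < math.min(maxRetry, numWorkers)
  else retriedCount < numWorkers // maxRetry == 0: bounded only by cluster size

// e.g. with 4 workers and spark.driver.maxRetry = 2:
// shouldRelaunch(0, 2, 4) == true; shouldRelaunch(2, 2, 4) == false
```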
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 95bd62e88db2b..e6e4267b03fd4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -71,8 +71,8 @@ private[spark] class Master(
   var nextAppNumber = 0
 
   val appIdToUI = new HashMap[String, SparkUI]
-
-  val drivers = new HashSet[DriverInfo]
+  val drivers = new HashMap[String, DriverInfo] // driverId -> DriverInfo
+  val driverAssignments = new HashMap[String, HashSet[String]] // driverId -> HashSet[workerId]
   val completedDrivers = new ArrayBuffer[DriverInfo]
   val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling
   var nextDriverNumber = 0
@@ -209,12 +209,9 @@ private[spark] class Master(
       val driver = createDriver(description)
       persistenceEngine.addDriver(driver)
       waitingDrivers += driver
-      drivers.add(driver)
+      drivers += driver.id -> driver
+      driverAssignments(driver.id) = new HashSet[String]
       schedule()
-
-      // TODO: It might be good to instead have the submission client poll the master to determine
-      // the current status of the driver. For now it's simply "fire and forget".
-
       sender ! SubmitDriverResponse(true, Some(driver.id),
         s"Driver successfully submitted as ${driver.id}")
     }
@@ -226,12 +223,12 @@ private[spark] class Master(
         sender ! KillDriverResponse(driverId, success = false, msg)
       } else {
         logInfo("Asked to kill driver " + driverId)
-        val driver = drivers.find(_.id == driverId)
+        val driver = drivers.get(driverId)
         driver match {
           case Some(d) =>
             if (waitingDrivers.contains(d)) {
               waitingDrivers -= d
-              self ! DriverStateChanged(driverId, DriverState.KILLED, None)
+              self ! DriverStateChanged(d.id, DriverState.KILLED, None)
             } else {
               // We just notify the worker to kill the driver here. The final bookkeeping occurs
@@ -254,7 +251,7 @@ private[spark] class Master(
     }
 
     case RequestDriverStatus(driverId) => {
-      (drivers ++ completedDrivers).find(_.id == driverId) match {
+      (drivers.values ++ completedDrivers).find(_.id == driverId) match {
         case Some(driver) =>
           sender ! DriverStatusResponse(found = true, Some(driver.state),
             driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)
@@ -305,12 +302,25 @@ private[spark] class Master(
       }
     }
 
-    case DriverStateChanged(driverId, state, exception) => {
+    case DriverStateChanged(driverID, state, exception) => {
       state match {
-        case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED =>
-          removeDriver(driverId, state, exception)
+        case DriverState.FINISHED | DriverState.KILLED =>
+          removeDriver(driverID, state, exception)
+        case DriverState.ERROR | DriverState.FAILED =>
+          drivers.get(driverID) match {
+            case Some(driver) =>
+              val maxRetry = conf.getInt("spark.driver.maxRetry", 0)
+              if ((maxRetry != 0 && driver.retriedcountOnMaster < Math.min(maxRetry, workers.size)) ||
+                  (maxRetry == 0 && driver.retriedcountOnMaster < workers.size)) {
+                // Recover the driver by relaunching it
+                relaunchDriver(driver)
+              } else {
+                removeDriver(driver.id, state, exception)
+              }
+            case None =>
+          }
         case _ =>
-          throw new Exception(s"Received unexpected state update for driver $driverId: $state")
+          throw new Exception(s"Received unexpected state update for driver $driverID: $state")
       }
     }
@@ -350,7 +360,7 @@ private[spark] class Master(
     }
 
     for (driverId <- driverIds) {
-      drivers.find(_.id == driverId).foreach { driver =>
+      drivers.get(driverId).foreach { driver =>
         driver.worker = Some(worker)
         driver.state = DriverState.RUNNING
         worker.drivers(driverId) = driver
@@ -373,7 +383,7 @@ private[spark] class Master(
     case RequestMasterState => {
       sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
-        drivers.toArray, completedDrivers.toArray, state)
+        drivers.values.toArray, completedDrivers.toArray, state)
     }
 
     case CheckForWorkerTimeOut => {
@@ -405,7 +415,7 @@ private[spark] class Master(
     for (driver <- storedDrivers) {
       // Here we just read in the list of drivers. Any drivers associated with now-lost workers
       // will be re-launched when we detect that the worker is missing.
-      drivers += driver
+      drivers += (driver.id -> driver)
     }
 
     for (worker <- storedWorkers) {
@@ -432,14 +442,14 @@ private[spark] class Master(
     apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
 
     // Reschedule drivers which were not claimed by any workers
-    drivers.filter(_.worker.isEmpty).foreach { d =>
-      logWarning(s"Driver ${d.id} was not found after master recovery")
-      if (d.desc.supervise) {
-        logWarning(s"Re-launching ${d.id}")
-        relaunchDriver(d)
+    drivers.filter(_._2.worker.isEmpty).foreach { d =>
+      logWarning(s"Driver ${d._1} was not found after master recovery")
+      if (d._2.desc.supervise) {
+        logWarning(s"Re-launching ${d._1}")
+        relaunchDriver(d._2)
       } else {
-        removeDriver(d.id, DriverState.ERROR, None)
-        logWarning(s"Did not re-launch ${d.id} because it was not supervised")
+        removeDriver(d._1, DriverState.ERROR, None)
+        logWarning(s"Did not re-launch ${d._1} because it was not supervised")
       }
     }
@@ -468,7 +478,9 @@ private[spark] class Master(
     val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
     for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
       for (driver <- waitingDrivers) {
-        if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
+        if (worker.memoryFree >= driver.desc.mem &&
+            worker.coresFree >= driver.desc.cores &&
+            !driverAssignments(driver.id).contains(worker.id)) {
           launchDriver(worker, driver)
           waitingDrivers -= driver
         }
@@ -582,6 +594,8 @@ private[spark] class Master(
   def relaunchDriver(driver: DriverInfo) {
     driver.worker = None
     driver.state = DriverState.RELAUNCHING
+    // Incremented for both worker failures and program failures
+    driver.retriedcountOnMaster += 1
     waitingDrivers += driver
     schedule()
   }
@@ -720,16 +734,18 @@ private[spark] class Master(
   def launchDriver(worker: WorkerInfo, driver: DriverInfo) {
     logInfo("Launching driver " + driver.id + " on worker " + worker.id)
     worker.addDriver(driver)
+    driverAssignments(driver.id) += worker.id
     driver.worker = Some(worker)
     worker.actor ! LaunchDriver(driver.id, driver.desc)
     driver.state = DriverState.RUNNING
   }
 
   def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) {
-    drivers.find(d => d.id == driverId) match {
+    drivers.get(driverId) match {
       case Some(driver) =>
         logInfo(s"Removing driver: $driverId")
-        drivers -= driver
+        drivers -= driverId
+        driverAssignments -= driver.id
         completedDrivers += driver
         persistenceEngine.removeDriver(driver)
         driver.state = finalState
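Taken together, the schedule() and launchDriver changes above mean a relaunched driver is never placed on a worker already recorded in driverAssignments for it. A standalone sketch of that placement rule, with simplified stand-ins for WorkerInfo/DriverInfo (record and value names here are illustrative, not from the patch):

```scala
// Standalone sketch of the placement filter schedule() now applies.
case class W(id: String, memoryFree: Int, coresFree: Int)
case class D(id: String, mem: Int, cores: Int)

def eligible(w: W, d: D, tried: Set[String]): Boolean =
  w.memoryFree >= d.mem &&
  w.coresFree >= d.cores &&
  !tried.contains(w.id) // skip workers where this driver already ran

// A driver that failed on worker-1 is only offered worker-2:
val tried = Set("worker-1")
val workers = Seq(W("worker-1", 8192, 4), W("worker-2", 8192, 4))
val d = D("driver-0", 4096, 2)
assert(workers.filter(eligible(_, d, tried)).map(_.id) == Seq("worker-2"))
```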
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index b4df1a0dd4718..e33d63efbb8fc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -28,10 +28,10 @@ import com.google.common.io.Files
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileUtil, Path}
 
-import org.apache.spark.Logging
-import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.deploy.{DriverDescription, Command}
 import org.apache.spark.deploy.DeployMessages.DriverStateChanged
-import org.apache.spark.deploy.master.DriverState
+import org.apache.spark.deploy.master.{DriverInfo, DriverState}
 import org.apache.spark.deploy.master.DriverState.DriverState
 
 /**
@@ -39,13 +39,18 @@ import org.apache.spark.deploy.master.DriverState.DriverState
  */
 private[spark] class DriverRunner(
     val driverId: String,
+    val driverDesc: DriverDescription,
     val workDir: File,
     val sparkHome: File,
-    val driverDesc: DriverDescription,
     val worker: ActorRef,
-    val workerUrl: String)
+    val workerUrl: String,
+    val conf: SparkConf)
   extends Logging {
 
+  class FailedTooManyTimesException(retryN: Int) extends Exception {
+    var retryCount = retryN
+  }
+
   @volatile var process: Option[Process] = None
   @volatile var killed = false
@@ -54,6 +59,9 @@ private[spark] class DriverRunner(
   var finalException: Option[Exception] = None
   var finalExitCode: Option[Int] = None
 
+  // Retry counter
+  var retryNum = 0
+
   // Decoupled for testing
   private[deploy] def setClock(_clock: Clock) = clock = _clock
   private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper
@@ -97,7 +105,6 @@ private[spark] class DriverRunner(
         }
 
         finalState = Some(state)
-        worker ! DriverStateChanged(driverId, state, finalException)
       }
     }.start()
@@ -184,7 +191,6 @@ private[spark] class DriverRunner(
     val successfulRunDuration = 5
 
     var keepTrying = !killed
-
     while (keepTrying) {
       logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\""))
@@ -199,15 +205,19 @@ private[spark] class DriverRunner(
       if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) {
         waitSeconds = 1
       }
-
-      if (supervise && exitCode != 0 && !killed) {
+      val maxRetry = conf.getInt("spark.driver.maxRetry", 0)
+      keepTrying = supervise && exitCode != 0 && !killed
+      if (keepTrying) {
+        retryNum += 1
+        if (retryNum > maxRetry)
+          throw new FailedTooManyTimesException(retryNum)
+        // Sleep only when we are going to retry
         logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
         sleeper.sleep(waitSeconds)
         waitSeconds = waitSeconds * 2 // exponential back-off
+      } else {
+        finalExitCode = Some(exitCode)
       }
-
-      keepTrying = supervise && exitCode != 0 && !killed
-      finalExitCode = Some(exitCode)
     }
   }
 }
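On the worker side, runCommandWithRetry now counts local restarts: the sleep starts at 1 s, doubles after each failed attempt (resetting to 1 s once a run survives longer than successfulRunDuration = 5 s), and once retryNum exceeds spark.driver.maxRetry the runner throws FailedTooManyTimesException instead of sleeping again. A small sketch of the resulting back-off schedule; the helper name is ours:

```scala
// Sleeps between local restarts, ignoring the 5 s reset for simplicity.
def backoffSchedule(maxRetry: Int, initial: Int = 1): Seq[Int] =
  Iterator.iterate(initial)(_ * 2).take(maxRetry).toSeq

// With spark.driver.maxRetry = 4 the sleeps are 1, 2, 4, 8 s; the fifth
// failure raises FailedTooManyTimesException instead of sleeping.
assert(backoffSchedule(4) == Seq(1, 2, 4, 8))
```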
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 8a71ddda4cb5e..fcfabac91b909 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -265,7 +265,7 @@ private[spark] class Worker(
     case LaunchDriver(driverId, driverDesc) => {
       logInfo(s"Asked to launch driver $driverId")
-      val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
+      val driver = new DriverRunner(driverId, driverDesc, workDir, sparkHome, self, akkaUrl, conf)
       drivers(driverId) = driver
       driver.start()
 
@@ -286,11 +286,11 @@ private[spark] class Worker(
     case DriverStateChanged(driverId, state, exception) => {
       state match {
         case DriverState.ERROR =>
-          logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
+          logWarning(s"Driver ${driverId} failed with unrecoverable exception: ${exception.get}")
         case DriverState.FINISHED =>
-          logInfo(s"Driver $driverId exited successfully")
+          logInfo(s"Driver ${driverId} exited successfully")
         case DriverState.KILLED =>
-          logInfo(s"Driver $driverId was killed by user")
+          logInfo(s"Driver ${driverId} was killed by user")
       }
       masterLock.synchronized {
         master ! DriverStateChanged(driverId, state, exception)
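Both Master and DriverRunner read the new property from their own SparkConf, with default 0: on the master that default means "retry up to the number of workers", while in DriverRunner the first failure already exceeds the cap, so there are no local restarts. A minimal, hypothetical example of setting the property; where the value must actually live (daemon config vs. application config) depends on how the master and workers load their conf:

```scala
import org.apache.spark.SparkConf

// Hypothetical usage: allow up to 3 restarts of a supervised driver.
val conf = new SparkConf()
  .set("spark.driver.maxRetry", "3")
```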
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index 85200ab0e102d..1213f46eda593 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -22,7 +22,7 @@ import scala.xml.Node
 import akka.pattern.ask
 import javax.servlet.http.HttpServletRequest
 
-import org.json4s.JValue
+import net.liftweb.json.JsonAST.JValue
 
 import org.apache.spark.deploy.JsonProtocol
 import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
@@ -144,7 +144,7 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
   def driverRow(driver: DriverRunner): Seq[Node] = {