"
- var startTime = -1L
- var endTime = -1L
- var viewAcls = ""
- var adminAcls = ""
-
- def applicationStarted = startTime != -1
-
- def applicationCompleted = endTime != -1
-
- def applicationDuration: Long = {
- val difference = endTime - startTime
- if (applicationStarted && applicationCompleted && difference > 0) difference else -1L
- }
+ var appName: Option[String] = None
+ var appId: Option[String] = None
+ var sparkUser: Option[String] = None
+ var startTime: Option[Long] = None
+ var endTime: Option[Long] = None
+ var viewAcls: Option[String] = None
+ var adminAcls: Option[String] = None
override def onApplicationStart(applicationStart: SparkListenerApplicationStart) {
- appName = applicationStart.appName
- startTime = applicationStart.time
- sparkUser = applicationStart.sparkUser
+ appName = Some(applicationStart.appName)
+ appId = applicationStart.appId
+ startTime = Some(applicationStart.time)
+ sparkUser = Some(applicationStart.sparkUser)
}
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) {
- endTime = applicationEnd.time
+ endTime = Some(applicationEnd.time)
}
override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
synchronized {
val environmentDetails = environmentUpdate.environmentDetails
val allProperties = environmentDetails("Spark Properties").toMap
- viewAcls = allProperties.getOrElse("spark.ui.view.acls", "")
- adminAcls = allProperties.getOrElse("spark.admin.acls", "")
+ viewAcls = allProperties.get("spark.ui.view.acls")
+ adminAcls = allProperties.get("spark.admin.acls")
}
}
}
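
A minimal sketch, not part of the patch, of how a consumer might adapt to the Option-valued fields above. It assumes the surrounding class is the history server's ApplicationEventListener and reproduces the logic of the removed applicationDuration helper; the `describe` helper itself is hypothetical.

// Hypothetical helper; `listener` is an instance of the class shown above.
def describe(listener: ApplicationEventListener): String = {
  val name = listener.appName.getOrElse("<unknown>")
  val duration: Option[Long] = for {
    start <- listener.startTime
    end <- listener.endTime
    if end > start              // mirrors the old applicationDuration guard
  } yield end - start
  duration match {
    case Some(ms) => s"$name ran for $ms ms"
    case None => s"$name (duration unknown)"
  }
}
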
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b86cfbfa48fbe..b2774dfc47553 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -164,7 +164,7 @@ class DAGScheduler(
*/
def executorHeartbeatReceived(
execId: String,
- taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics)
+ taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stageAttemptId, metrics)
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
implicit val timeout = Timeout(600 seconds)
@@ -241,9 +241,9 @@ class DAGScheduler(
callSite: CallSite)
: Stage =
{
+ val parentStages = getParentStages(rdd, jobId)
val id = nextStageId.getAndIncrement()
- val stage =
- new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
stage
@@ -507,11 +507,16 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
+ val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
- case JobSucceeded => {}
+ case JobSucceeded => {
+ logInfo("Job %d finished: %s, took %f s".format
+ (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
+ }
case JobFailed(exception: Exception) =>
- logInfo("Failed to run " + callSite.shortForm)
+ logInfo("Job %d failed: %s, took %f s".format
+ (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
throw exception
}
}
@@ -677,7 +682,10 @@ class DAGScheduler(
}
private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
- listenerBus.post(SparkListenerTaskStart(task.stageId, taskInfo))
+ // Note that there is a chance that this task is launched after the stage is cancelled.
+ // In that case, we wouldn't have the stage anymore in stageIdToStage.
+ val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1)
+ listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo))
submitWaitingStages()
}
@@ -695,8 +703,8 @@ class DAGScheduler(
// is in the process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
runningStages.foreach { stage =>
- stage.info.stageFailed(stageFailedMessage)
- listenerBus.post(SparkListenerStageCompleted(stage.info))
+ stage.latestInfo.stageFailed(stageFailedMessage)
+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
@@ -781,7 +789,16 @@ class DAGScheduler(
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
- var tasks = ArrayBuffer[Task[_]]()
+
+ // First figure out the indexes of partition ids to compute.
+ val partitionsToCompute: Seq[Int] = {
+ if (stage.isShuffleMap) {
+ (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil)
+ } else {
+ val job = stage.resultOfJob.get
+ (0 until job.numPartitions).filter(id => !job.finished(id))
+ }
+ }
val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
@@ -795,7 +812,8 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
- listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+ stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+ listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
@@ -826,20 +844,19 @@ class DAGScheduler(
return
}
- if (stage.isShuffleMap) {
- for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
- val locs = getPreferredLocs(stage.rdd, p)
- val part = stage.rdd.partitions(p)
- tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
+ val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
+ partitionsToCompute.map { id =>
+ val locs = getPreferredLocs(stage.rdd, id)
+ val part = stage.rdd.partitions(id)
+ new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
- // This is a final stage; figure out its job's missing partitions
val job = stage.resultOfJob.get
- for (id <- 0 until job.numPartitions if !job.finished(id)) {
+ partitionsToCompute.map { id =>
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
- tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
+ new ResultTask(stage.id, taskBinary, part, locs, id)
}
}
@@ -869,11 +886,11 @@ class DAGScheduler(
logDebug("New pending tasks: " + stage.pendingTasks)
taskScheduler.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- stage.info.submissionTime = Some(clock.getTime())
+ stage.latestInfo.submissionTime = Some(clock.getTime())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
- listenerBus.post(SparkListenerStageCompleted(stage.info))
+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
runningStages -= stage
@@ -892,8 +909,9 @@ class DAGScheduler(
// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
if (event.reason != Success) {
- listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
- event.taskMetrics))
+ val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1)
+ listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
+ event.taskInfo, event.taskMetrics))
}
if (!stageIdToStage.contains(task.stageId)) {
@@ -902,14 +920,19 @@ class DAGScheduler(
}
val stage = stageIdToStage(task.stageId)
- def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.info.submissionTime match {
+ def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = {
+ val serviceTime = stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
case _ => "Unknown"
}
- logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.info.completionTime = Some(clock.getTime())
- listenerBus.post(SparkListenerStageCompleted(stage.info))
+ if (errorMessage.isEmpty) {
+ logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
+ stage.latestInfo.completionTime = Some(clock.getTime())
+ } else {
+ stage.latestInfo.stageFailed(errorMessage.get)
+ logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
+ }
+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
runningStages -= stage
}
event.reason match {
@@ -924,7 +947,7 @@ class DAGScheduler(
val name = acc.name.get
val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
val stringValue = Accumulators.stringifyValue(acc.value)
- stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue)
+ stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
event.taskInfo.accumulables +=
AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
}
@@ -935,8 +958,8 @@ class DAGScheduler(
logError(s"Failed to update accumulators for $task", e)
}
}
- listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
- event.taskMetrics))
+ listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
+ event.reason, event.taskInfo, event.taskMetrics))
stage.pendingTasks -= task
task match {
case rt: ResultTask[_, _] =>
@@ -1027,30 +1050,39 @@ class DAGScheduler(
stage.pendingTasks += task
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- // Mark the stage that the reducer was in as unrunnable
val failedStage = stageIdToStage(task.stageId)
- runningStages -= failedStage
- // TODO: Cancel running tasks in the stage
- logInfo("Marking " + failedStage + " (" + failedStage.name +
- ") for resubmision due to a fetch failure")
- // Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
- if (mapId != -1) {
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is possible
+ // the fetch failure has already been handled by the scheduler.
+ if (runningStages.contains(failedStage)) {
+ logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+ s"due to a fetch failure from $mapStage (${mapStage.name})")
+ markStageAsFinished(failedStage, Some("Fetch failure"))
+ runningStages -= failedStage
}
- logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name +
- "); marking it for resubmission")
+
if (failedStages.isEmpty && eventProcessActor != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
+ // TODO: Cancel running tasks in the stage
import env.actorSystem.dispatcher
+ logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+ s"$failedStage (${failedStage.name}) due to fetch failure")
env.actorSystem.scheduler.scheduleOnce(
RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
}
failedStages += failedStage
failedStages += mapStage
+
+ // Mark the map whose fetch failed as broken in the map stage
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
+
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, Some(task.epoch))
@@ -1142,7 +1174,7 @@ class DAGScheduler(
}
val dependentJobs: Seq[ActiveJob] =
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
- failedStage.info.completionTime = Some(clock.getTime())
+ failedStage.latestInfo.completionTime = Some(clock.getTime())
for (job <- dependentJobs) {
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
}
@@ -1182,8 +1214,8 @@ class DAGScheduler(
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
- stage.info.stageFailed(failureReason)
- listenerBus.post(SparkListenerStageCompleted(stage.info))
+ stage.latestInfo.stageFailed(failureReason)
+ listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 370fcd85aa680..64b32ae0edaac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.SPARK_VERSION
import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
/**
@@ -44,11 +45,14 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
private[spark] class EventLoggingListener(
appName: String,
sparkConf: SparkConf,
- hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration())
+ hadoopConf: Configuration)
extends SparkListener with Logging {
import EventLoggingListener._
+ def this(appName: String, sparkConf: SparkConf) =
+ this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf))
+
private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false)
private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false)
private val testing = sparkConf.getBoolean("spark.eventLog.testing", false)
@@ -83,7 +87,7 @@ private[spark] class EventLoggingListener(
sparkConf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC)
logger.newFile(COMPRESSION_CODEC_PREFIX + codec)
}
- logger.newFile(SPARK_VERSION_PREFIX + SparkContext.SPARK_VERSION)
+ logger.newFile(SPARK_VERSION_PREFIX + SPARK_VERSION)
logger.newFile(LOG_PREFIX + logger.fileIndex)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index e9bfee2248e5b..29879b374b801 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -23,7 +23,7 @@ package org.apache.spark.scheduler
*/
private[spark] class JobWaiter[T](
dagScheduler: DAGScheduler,
- jobId: Int,
+ val jobId: Int,
totalTasks: Int,
resultHandler: (Int, T) => Unit)
extends JobListener {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index e41e0a9841691..a0be8307eff27 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -31,4 +31,12 @@ private[spark] trait SchedulerBackend {
def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
throw new UnsupportedOperationException
def isReady(): Boolean = true
+
+ /**
+ * The application ID associated with the job, if any.
+ *
+ * @return The application ID, or None if the backend does not provide an ID.
+ */
+ def applicationId(): Option[String] = None
+
}
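
A hedged sketch of a backend plugging into the new applicationId() hook. The class and ID value are illustrative, and it assumes the trait's other abstract members are start/stop/reviveOffers/defaultParallelism; SparkDeploySchedulerBackend later in this diff uses the same Option(appId) pattern.

// Illustrative backend: appId stays null (so applicationId() is None) until registration.
private[spark] class StubSchedulerBackend extends SchedulerBackend {
  @volatile private var appId: String = _

  override def start(): Unit = { appId = "app-stub-0001" } // normally assigned by the cluster manager
  override def stop(): Unit = {}
  override def reviveOffers(): Unit = {}
  override def defaultParallelism(): Int = 1

  override def applicationId(): Option[String] = Option(appId)
}
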
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index d01d318633877..86afe3bd5265f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -39,7 +39,8 @@ case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Propert
case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent
+case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: TaskInfo)
+ extends SparkListenerEvent
@DeveloperApi
case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent
@@ -47,6 +48,7 @@ case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListe
@DeveloperApi
case class SparkListenerTaskEnd(
stageId: Int,
+ stageAttemptId: Int,
taskType: String,
reason: TaskEndReason,
taskInfo: TaskInfo,
@@ -65,25 +67,30 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerBlockManagerAdded(blockManagerId: BlockManagerId, maxMem: Long)
+case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long)
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId)
+case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId)
extends SparkListenerEvent
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+/**
+ * Periodic updates from executors.
+ * @param execId executor id
+ * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics)
+ */
@DeveloperApi
case class SparkListenerExecutorMetricsUpdate(
execId: String,
- taskMetrics: Seq[(Long, Int, TaskMetrics)])
+ taskMetrics: Seq[(Long, Int, Int, TaskMetrics)])
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String)
- extends SparkListenerEvent
+case class SparkListenerApplicationStart(appName: String, appId: Option[String], time: Long,
+ sparkUser: String) extends SparkListenerEvent
@DeveloperApi
case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
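
The stageAttemptId added to the task events lets listeners distinguish retries of the same stage. A short sketch, assuming the standard onTaskStart/onTaskEnd callbacks on SparkListener; the listener class is illustrative.

import scala.collection.mutable

// Illustrative listener that counts running tasks per (stageId, stageAttemptId).
class AttemptAwareListener extends SparkListener {
  private val running = mutable.HashMap.empty[(Int, Int), Int]

  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
    val key = (taskStart.stageId, taskStart.stageAttemptId)
    running(key) = running.getOrElse(key, 0) + 1
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    val key = (taskEnd.stageId, taskEnd.stageAttemptId)
    running(key) = running.getOrElse(key, 0) - 1
  }
}
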
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 800905413d145..071568cdfb429 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -43,6 +43,9 @@ import org.apache.spark.util.CallSite
* stage, the callSite gives the user code that created the RDD being shuffled. For a result
* stage, the callSite gives the user code that executes the associated action (e.g. count()).
*
+ * A single stage can consist of multiple attempts. In that case, the latestInfo field will
+ * be updated for each attempt.
+ *
*/
private[spark] class Stage(
val id: Int,
@@ -71,8 +74,8 @@ private[spark] class Stage(
val name = callSite.shortForm
val details = callSite.longForm
- /** Pointer to the [StageInfo] object, set by DAGScheduler. */
- var info: StageInfo = StageInfo.fromStage(this)
+ /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */
+ var latestInfo: StageInfo = StageInfo.fromStage(this)
def isAvailable: Boolean = {
if (!isShuffleMap) {
@@ -116,6 +119,7 @@ private[spark] class Stage(
}
}
+ /** Return a new attempt id, starting with 0. */
def newAttemptId(): Int = {
val id = nextAttemptId
nextAttemptId += 1
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 2a407e47a05bd..c6dc3369ba5cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -29,6 +29,7 @@ import org.apache.spark.storage.RDDInfo
@DeveloperApi
class StageInfo(
val stageId: Int,
+ val attemptId: Int,
val name: String,
val numTasks: Int,
val rddInfos: Seq[RDDInfo],
@@ -56,9 +57,15 @@ private[spark] object StageInfo {
* shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
* sequence of narrow dependencies should also be associated with this Stage.
*/
- def fromStage(stage: Stage): StageInfo = {
+ def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
- new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details)
+ new StageInfo(
+ stage.id,
+ stage.attemptId,
+ stage.name,
+ numTasks.getOrElse(stage.numTasks),
+ rddInfos,
+ stage.details)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1a0b877c8a5e1..1c1ce666eab0f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -64,4 +64,12 @@ private[spark] trait TaskScheduler {
*/
def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
blockManagerId: BlockManagerId): Boolean
+
+ /**
+ * The application ID associated with the job, if any.
+ *
+ * @return The application ID, or None if the backend does not provide an ID.
+ */
+ def applicationId(): Option[String] = None
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 6c0d1b2752a81..633e892554c50 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -333,12 +333,12 @@ private[spark] class TaskSchedulerImpl(
execId: String,
taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
blockManagerId: BlockManagerId): Boolean = {
- val metricsWithStageIds = taskMetrics.flatMap {
- case (id, metrics) => {
+
+ val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
+ taskMetrics.flatMap { case (id, metrics) =>
taskIdToTaskSetId.get(id)
.flatMap(activeTaskSets.get)
- .map(_.stageId)
- .map(x => (id, x, metrics))
+ .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -491,6 +491,9 @@ private[spark] class TaskSchedulerImpl(
}
}
}
+
+ override def applicationId(): Option[String] = backend.applicationId()
+
}
@@ -535,4 +538,5 @@ private[spark] object TaskSchedulerImpl {
retval.toList
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index 613fa7850bb25..c3ad325156f53 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,9 +31,5 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
- def kill(interruptThread: Boolean) {
- tasks.foreach(_.kill(interruptThread))
- }
-
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 2a3711ae2a78c..9a0cb1c6c6ccd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -51,12 +51,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
val conf = scheduler.sc.conf
private val timeout = AkkaUtils.askTimeout(conf)
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
- // Submit tasks only after (registered resources / total expected resources)
+ // Submit tasks only after (registered resources / total expected resources)
// is equal to at least this value, that is double between 0 and 1.
var minRegisteredRatio =
math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0))
// Submit tasks after maxRegisteredWaitingTime milliseconds
- // if minRegisteredRatio has not yet been reached
+ // if minRegisteredRatio has not yet been reached
val maxRegisteredWaitingTime =
conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000)
val createTime = System.currentTimeMillis()
@@ -292,7 +292,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
conf.set("spark.ui.filters", filterName)
conf.set(s"spark.$filterName.params", filterParams)
- JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf)
+ scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index d99c76117c168..ee10aa061f4e9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -17,10 +17,10 @@
package org.apache.spark.scheduler.cluster
-import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.{Logging, SparkContext, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
private[spark] class SimrSchedulerBackend(
@@ -38,22 +38,25 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
- sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port"),
+ val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ SparkEnv.driverActorSystemName,
+ sc.conf.get("spark.driver.host"),
+ sc.conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
val fs = FileSystem.get(conf)
+ val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
logInfo("Writing to HDFS file: " + driverFilePath)
logInfo("Writing Akka address: " + driverUrl)
- logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
+ logInfo("Writing Spark UI Address: " + appUIAddress)
// Create temporary file to prevent race condition where executors get empty driverUrl file
val temp = fs.create(tmpPath, true)
temp.writeUTF(driverUrl)
temp.writeInt(maxCores)
- temp.writeUTF(sc.ui.appUIAddress)
+ temp.writeUTF(appUIAddress)
temp.close()
// "Atomic" rename
@@ -61,9 +64,10 @@ private[spark] class SimrSchedulerBackend(
}
override def stop() {
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
val fs = FileSystem.get(conf)
fs.delete(new Path(driverFilePath), false)
super.stop()
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 589dba2e40d20..5c5ecc8434d78 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
@@ -34,6 +34,10 @@ private[spark] class SparkDeploySchedulerBackend(
var client: AppClient = null
var stopping = false
var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
+ @volatile var appId: String = _
+
+ val registrationLock = new Object()
+ var registrationDone = false
val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
val totalExpectedCores = maxCores.getOrElse(0)
@@ -42,8 +46,10 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
- conf.get("spark.driver.host"), conf.get("spark.driver.port"),
+ val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ SparkEnv.driverActorSystemName,
+ conf.get("spark.driver.host"),
+ conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
@@ -61,11 +67,15 @@ private[spark] class SparkDeploySchedulerBackend(
val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
+ val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
+ val eventLogDir = sc.eventLogger.map(_.logDir)
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
- sc.ui.appUIAddress, sc.eventLogger.map(_.logDir))
+ appUIAddress, eventLogDir)
client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
+
+ waitForRegistration()
}
override def stop() {
@@ -79,15 +89,19 @@ private[spark] class SparkDeploySchedulerBackend(
override def connected(appId: String) {
logInfo("Connected to Spark cluster with app ID " + appId)
+ this.appId = appId
+ notifyContext()
}
override def disconnected() {
+ notifyContext()
if (!stopping) {
logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
}
}
override def dead(reason: String) {
+ notifyContext()
if (!stopping) {
logError("Application has been killed. Reason: " + reason)
scheduler.error(reason)
@@ -114,4 +128,22 @@ private[spark] class SparkDeploySchedulerBackend(
override def sufficientResourcesRegistered(): Boolean = {
totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio
}
+
+ override def applicationId(): Option[String] = Option(appId)
+
+ private def waitForRegistration() = {
+ registrationLock.synchronized {
+ while (!registrationDone) {
+ registrationLock.wait()
+ }
+ }
+ }
+
+ private def notifyContext() = {
+ registrationLock.synchronized {
+ registrationDone = true
+ registrationLock.notifyAll()
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 9f45400bcf852..64568409dbafd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -28,7 +28,7 @@ import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
-import org.apache.spark.{Logging, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
@@ -71,9 +71,6 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
- val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
- "Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0)
@@ -110,6 +107,11 @@ private[spark] class CoarseMesosSchedulerBackend(
}
def createCommand(offer: Offer, numCores: Int): CommandInfo = {
+ val executorSparkHome = conf.getOption("spark.mesos.executor.home")
+ .orElse(sc.getSparkHome())
+ .getOrElse {
+ throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
+ }
val environment = Environment.newBuilder()
val extraClassPath = conf.getOption("spark.executor.extraClassPath")
extraClassPath.foreach { cp =>
@@ -122,6 +124,12 @@ private[spark] class CoarseMesosSchedulerBackend(
val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p")
val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ environment.addVariables(
+ Environment.Variable.newBuilder()
+ .setName("SPARK_EXECUTOR_OPTS")
+ .setValue(extraOpts)
+ .build())
+
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
.setName(key)
@@ -130,25 +138,26 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
- val runScript = new File(sparkHome, "./bin/spark-class").getCanonicalPath
+ val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d".format(
- runScript, extraOpts, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
+ runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
("cd %s*; " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %s %d")
- .format(basename, extraOpts, driverUrl, offer.getSlaveId.getValue,
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d")
+ .format(basename, driverUrl, offer.getSlaveId.getValue,
offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
@@ -300,4 +309,5 @@ private[spark] class CoarseMesosSchedulerBackend(
logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
slaveLost(d, s)
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index c717e7c621a8f..a9ef126f5de0e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -86,10 +86,26 @@ private[spark] class MesosSchedulerBackend(
}
def createExecutorInfo(execId: String): ExecutorInfo = {
- val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
- "Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor"))
+ val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
+ .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
+ .getOrElse {
+ throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
+ }
val environment = Environment.newBuilder()
+ sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
+ environment.addVariables(
+ Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
+ }
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
+ val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp =>
+ s"-Djava.library.path=$lp"
+ }
+ val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ environment.addVariables(
+ Environment.Variable.newBuilder()
+ .setName("SPARK_EXECUTOR_OPTS")
+ .setValue(extraOpts)
+ .build())
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
.setName(key)
@@ -100,7 +116,7 @@ private[spark] class MesosSchedulerBackend(
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
if (uri == null) {
- command.setValue(new File(sparkHome, "/sbin/spark-executor").getCanonicalPath)
+ command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath)
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
@@ -333,4 +349,5 @@ private[spark] class MesosSchedulerBackend(
// TODO: query Mesos for number of cores
override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8)
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index bec9502f20466..9ea25c2bc7090 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -114,4 +114,5 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
localActor ! StatusUpdate(taskId, state, serializedData)
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 87ef9bb0b43c6..d6386f8c06fff 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.spark._
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
similarity index 82%
rename from core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
rename to core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index b8f5d3a5b02aa..439981d232349 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -15,22 +15,23 @@
* limitations under the License.
*/
-package org.apache.spark.storage
+package org.apache.spark.shuffle
import java.io.File
+import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
+import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
+import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
-import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.executor.ShuffleWriteMetrics
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -61,20 +62,18 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
-// TODO: Factor this into a separate class for each ShuffleManager implementation
+
private[spark]
-class ShuffleBlockManager(blockManager: BlockManager,
- shuffleManager: ShuffleManager) extends Logging {
- def conf = blockManager.conf
+class FileShuffleBlockManager(conf: SparkConf)
+ extends ShuffleBlockManager with Logging {
+
+ private lazy val blockManager = SparkEnv.get.blockManager
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
- val consolidateShuffleFiles =
+ private val consolidateShuffleFiles =
conf.getBoolean("spark.shuffle.consolidateFiles", false)
- // Are we using sort-based shuffle?
- val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager]
-
private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
/**
@@ -93,22 +92,11 @@ class ShuffleBlockManager(blockManager: BlockManager,
val completedMapTasks = new ConcurrentLinkedQueue[Int]()
}
- type ShuffleId = Int
private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
- /**
- * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle
- * because it just writes a single file by itself.
- */
- def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
- shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
- val shuffleState = shuffleStates(shuffleId)
- shuffleState.completedMapTasks.add(mapId)
- }
-
/**
* Get a ShuffleWriterGroup for the given map task, which will register it as complete
* when the writers are closed successfully
@@ -168,7 +156,7 @@ class ShuffleBlockManager(blockManager: BlockManager,
val filename = physicalFileName(shuffleId, bucketId, fileId)
blockManager.diskBlockManager.getFile(filename)
}
- val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+ val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files)
shuffleState.allFileGroups.add(fileGroup)
fileGroup
}
@@ -179,19 +167,28 @@ class ShuffleBlockManager(blockManager: BlockManager,
}
}
- /**
- * Returns the physical file segment in which the given BlockId is located.
- * This function should only be called if shuffle file consolidation is enabled, as it is
- * an error condition if we don't find the expected block.
- */
- def getBlockLocation(id: ShuffleBlockId): FileSegment = {
- // Search all file groups associated with this shuffle.
- val shuffleState = shuffleStates(id.shuffleId)
- for (fileGroup <- shuffleState.allFileGroups) {
- val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
- if (segment.isDefined) { return segment.get }
+ override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+ val segment = getBlockData(blockId)
+ Some(segment.nioByteBuffer())
+ }
+
+ override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+ if (consolidateShuffleFiles) {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(blockId.shuffleId)
+ val iter = shuffleState.allFileGroups.iterator
+ while (iter.hasNext) {
+ val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
+ if (segmentOpt.isDefined) {
+ val segment = segmentOpt.get
+ return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ }
+ }
+ throw new IllegalStateException("Failed to find shuffle block: " + blockId)
+ } else {
+ val file = blockManager.diskBlockManager.getFile(blockId)
+ new FileSegmentManagedBuffer(file, 0, file.length)
}
- throw new IllegalStateException("Failed to find shuffle block: " + id)
}
/** Remove all the blocks / files and metadata related to a particular shuffle. */
@@ -207,14 +204,7 @@ class ShuffleBlockManager(blockManager: BlockManager,
private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
shuffleStates.get(shuffleId) match {
case Some(state) =>
- if (sortBasedShuffle) {
- // There's a single block ID for each map, plus an index file for it
- for (mapId <- state.completedMapTasks) {
- val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
- blockManager.diskBlockManager.getFile(blockId).delete()
- blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
- }
- } else if (consolidateShuffleFiles) {
+ if (consolidateShuffleFiles) {
for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
file.delete()
}
@@ -240,13 +230,13 @@ class ShuffleBlockManager(blockManager: BlockManager,
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}
- def stop() {
+ override def stop() {
metadataCleaner.cancel()
}
}
private[spark]
-object ShuffleBlockManager {
+object FileShuffleBlockManager {
/**
* A group of shuffle files, one per reducer.
* A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
new file mode 100644
index 0000000000000..4ab34336d3f01
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io._
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import org.apache.spark.storage._
+
+/**
+ * Create and maintain the shuffle blocks' mapping between logical blocks and physical file locations.
+ * Data of shuffle blocks from the same map task are stored in a single consolidated data file.
+ * The offsets of the data blocks in the data file are stored in a separate index file.
+ *
+ * We use the name of the shuffle data's shuffleBlockId with reduce ID set to 0 and add ".data"
+ * as the filename postfix for the data file, and ".index" as the filename postfix for the index file.
+ *
+ */
+private[spark]
+class IndexShuffleBlockManager extends ShuffleBlockManager {
+
+ private lazy val blockManager = SparkEnv.get.blockManager
+
+ /**
+ * Map a (shuffleId, mapId) pair to a single ShuffleBlockId with reduce ID 0.
+ */
+ def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = {
+ ShuffleBlockId(shuffleId, mapId, 0)
+ }
+
+ def getDataFile(shuffleId: Int, mapId: Int): File = {
+ blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0))
+ }
+
+ private def getIndexFile(shuffleId: Int, mapId: Int): File = {
+ blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0))
+ }
+
+ /**
+ * Remove the data file and index file that contain the output data from one map.
+ */
+ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = {
+ var file = getDataFile(shuffleId, mapId)
+ if (file.exists()) {
+ file.delete()
+ }
+
+ file = getIndexFile(shuffleId, mapId)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+
+ /**
+ * Write an index file with the offsets of each block, plus a final offset at the end for the
+ * end of the output file. This will be used by getBlockData to figure out where each block
+ * begins and ends.
+ */
+ def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = {
+ val indexFile = getIndexFile(shuffleId, mapId)
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+ try {
+ // We take in lengths of each block, need to convert it to offsets.
+ var offset = 0L
+ out.writeLong(offset)
+
+ for (length <- lengths) {
+ offset += length
+ out.writeLong(offset)
+ }
+ } finally {
+ out.close()
+ }
+ }
+
+ override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+ Some(getBlockData(blockId).nioByteBuffer())
+ }
+
+ override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+ // The block is actually going to be a range of a single map output file for this map, so
+ // find out the consolidated file, then the offset within that from our index
+ val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
+
+ val in = new DataInputStream(new FileInputStream(indexFile))
+ try {
+ in.skip(blockId.reduceId * 8)
+ val offset = in.readLong()
+ val nextOffset = in.readLong()
+ new FileSegmentManagedBuffer(
+ getDataFile(blockId.shuffleId, blockId.mapId),
+ offset,
+ nextOffset - offset)
+ } finally {
+ in.close()
+ }
+ }
+
+ override def stop() = {}
+}
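
A worked sketch, not part of the patch, of the arithmetic behind the index file: writeIndexFile stores cumulative offsets (one long per block plus a trailing total), so getBlockData for reducer r skips r * 8 bytes and reads two longs to bound the segment.

// Lengths of three reduce partitions produce four offsets in the index file.
val lengths = Array(10L, 0L, 25L)
val offsets = lengths.scanLeft(0L)(_ + _) // Array(0, 10, 10, 35) -- what writeIndexFile writes
val reduceId = 2
val offset = offsets(reduceId)            // after skipping reduceId * 8 bytes: readLong() == 10
val nextOffset = offsets(reduceId + 1)    // second readLong() == 35
val segmentLength = nextOffset - offset   // 25 bytes served for ShuffleBlockId(shuffleId, mapId, 2)
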
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
new file mode 100644
index 0000000000000..63863cc0250a3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.storage.ShuffleBlockId
+
+private[spark]
+trait ShuffleBlockManager {
+ type ShuffleId = Int
+
+ /**
+ * Get shuffle block data managed by the local ShuffleBlockManager.
+ * @return Some(ByteBuffer) if block found, otherwise None.
+ */
+ def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer]
+
+ def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
+
+ def stop(): Unit
+}
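
A small hedged usage sketch of the trait contract above; the helper name is hypothetical, and either implementation from this patch (FileShuffleBlockManager or IndexShuffleBlockManager) could be passed in.

// Returns the size of a locally managed shuffle block, or -1 if it is not found.
def localBlockSize(blocks: ShuffleBlockManager, blockId: ShuffleBlockId): Long = {
  blocks.getBytes(blockId) match {
    case Some(buf) => buf.remaining().toLong
    case None => -1L
  }
}
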
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 9c859b8b4a118..801ae54086053 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -49,8 +49,13 @@ private[spark] trait ShuffleManager {
endPartition: Int,
context: TaskContext): ShuffleReader[K, C]
- /** Remove a shuffle's metadata from the ShuffleManager. */
- def unregisterShuffle(shuffleId: Int)
+ /**
+ * Remove a shuffle's metadata from the ShuffleManager.
+ * @return true if the metadata was removed successfully, otherwise false.
+ */
+ def unregisterShuffle(shuffleId: Int): Boolean
+
+ def shuffleBlockManager: ShuffleBlockManager
/** Shut down this ShuffleManager. */
def stop(): Unit
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 12b475658e29d..6cf9305977a3c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[hash] object BlockStoreShuffleFetcher extends Logging {
@@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer,
- shuffleMetrics: ShuffleReadMetrics)
+ serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
+ val blockFetcherItr = new ShuffleBlockFetcherIterator(
+ context,
+ SparkEnv.get.blockTransferService,
+ blockManager,
+ blocksByAddress,
+ serializer,
+ SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index df98d18fa8193..62e0629b34400 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -25,6 +25,9 @@ import org.apache.spark.shuffle._
* mapper (possibly reusing these across waves of tasks).
*/
private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+
+ private val fileShuffleBlockManager = new FileShuffleBlockManager(conf)
+
/* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
override def registerShuffle[K, V, C](
shuffleId: Int,
@@ -49,12 +52,21 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
: ShuffleWriter[K, V] = {
- new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ new HashShuffleWriter(
+ shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
}
/** Remove a shuffle's metadata from the ShuffleManager. */
- override def unregisterShuffle(shuffleId: Int): Unit = {}
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ shuffleBlockManager.removeShuffle(shuffleId)
+ }
+
+ override def shuffleBlockManager: FileShuffleBlockManager = {
+ fileShuffleBlockManager
+ }
/** Shut down this ShuffleManager. */
- override def stop(): Unit = {}
+ override def stop(): Unit = {
+ shuffleBlockManager.stop()
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 7bed97a63f0f6..88a5f1e5ddf58 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C](
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val ser = Serializer.getSerializer(dep.serializer)
- val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
- readMetrics)
+ val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 51e454d9313c9..4b9454d75abb7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,14 +17,15 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
-import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
-import org.apache.spark.storage.{BlockObjectWriter}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.storage.BlockObjectWriter
private[spark] class HashShuffleWriter[K, V](
+ shuffleBlockManager: FileShuffleBlockManager,
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext)
@@ -43,7 +44,6 @@ private[spark] class HashShuffleWriter[K, V](
metrics.shuffleWriteMetrics = Some(writeMetrics)
private val blockManager = SparkEnv.get.blockManager
- private val shuffleBlockManager = blockManager.shuffleBlockManager
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
writeMetrics)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 6dcca47ea7c0c..b727438ae7e47 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -17,14 +17,17 @@
package org.apache.spark.shuffle.sort
-import java.io.{DataInputStream, FileInputStream}
+import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency}
import org.apache.spark.shuffle._
-import org.apache.spark.{TaskContext, ShuffleDependency}
import org.apache.spark.shuffle.hash.HashShuffleReader
-import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}
-private[spark] class SortShuffleManager extends ShuffleManager {
+private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
+
+ private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+ private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
@@ -52,29 +55,29 @@ private[spark] class SortShuffleManager extends ShuffleManager {
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
: ShuffleWriter[K, V] = {
- new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
+ shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
+ new SortShuffleWriter(
+ shuffleBlockManager, baseShuffleHandle, mapId, context)
}
/** Remove a shuffle's metadata from the ShuffleManager. */
- override def unregisterShuffle(shuffleId: Int): Unit = {}
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ if (shuffleMapNumber.containsKey(shuffleId)) {
+ val numMaps = shuffleMapNumber.remove(shuffleId)
+ (0 until numMaps).map { mapId =>
+ shuffleBlockManager.removeDataByMap(shuffleId, mapId)
+ }
+ }
+ true
+ }
- /** Shut down this ShuffleManager. */
- override def stop(): Unit = {}
+ override def shuffleBlockManager: IndexShuffleBlockManager = {
+ indexShuffleBlockManager
+ }
- /** Get the location of a block in a map output file. Uses the index file we create for it. */
- def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
- // The block is actually going to be a range of a single map output file for this map, so
- // figure out the ID of the consolidated file, then the offset within that from our index
- val consolidatedId = blockId.copy(reduceId = 0)
- val indexFile = diskManager.getFile(consolidatedId.name + ".index")
- val in = new DataInputStream(new FileInputStream(indexFile))
- try {
- in.skip(blockId.reduceId * 8)
- val offset = in.readLong()
- val nextOffset = in.readLong()
- new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
- } finally {
- in.close()
- }
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {
+ shuffleBlockManager.stop()
}
}
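The deleted getBlockLocation above is the clearest description of the sort-shuffle index format that IndexShuffleBlockManager now owns: the index file is a flat array of 8-byte offsets, and reduce partition r occupies [offset(r), offset(r + 1)) of the single per-map data file. A self-contained restatement of that read path:

```scala
import java.io.{DataInputStream, File, FileInputStream}

object ShuffleIndexReadSketch {
  /** Returns (offset, length) of reduce partition `reduceId` within the map's data file. */
  def readSegment(indexFile: File, reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      in.skip(reduceId * 8L)          // each index entry is one 8-byte offset
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset)
    } finally {
      in.close()
    }
  }
}
```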
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 22f656fa371ea..89a78d6982ba0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,34 +17,25 @@
package org.apache.spark.shuffle.sort
-import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream}
-
import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{IndexShuffleBlockManager, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
+ shuffleBlockManager: IndexShuffleBlockManager,
handle: BaseShuffleHandle[K, V, C],
mapId: Int,
context: TaskContext)
extends ShuffleWriter[K, V] with Logging {
private val dep = handle.dependency
- private val numPartitions = dep.partitioner.numPartitions
private val blockManager = SparkEnv.get.blockManager
- private val ser = Serializer.getSerializer(dep.serializer.orNull)
-
- private val conf = SparkEnv.get.conf
- private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
private var sorter: ExternalSorter[K, V, _] = null
- private var outputFile: File = null
- private var indexFile: File = null
// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -74,17 +65,10 @@ private[spark] class SortShuffleWriter[K, V, C](
sorter.insertAll(records)
}
- // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
- // serve different ranges of this file using an index file that we create at the end.
- val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
-
- outputFile = blockManager.diskBlockManager.getFile(blockId)
- indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index")
-
- val partitionLengths = sorter.writePartitionedFile(blockId, context)
-
- // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
- blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+ val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
+ val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
+ val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
+ shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
mapStatus = new MapStatus(blockManager.blockManagerId,
partitionLengths.map(MapOutputTracker.compressSize))
@@ -100,13 +84,8 @@ private[spark] class SortShuffleWriter[K, V, C](
if (success) {
return Option(mapStatus)
} else {
- // The map task failed, so delete our output file if we created one
- if (outputFile != null) {
- outputFile.delete()
- }
- if (indexFile != null) {
- indexFile.delete()
- }
+ // The map task failed, so delete our output data.
+ shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId)
return None
}
} finally {
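The writer now calls shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) instead of managing the data and index files itself. The body of writeIndexFile is not shown in this patch; assuming it mirrors the read format sketched after the SortShuffleManager hunk (a cumulative-offset table), a plausible sketch would be:

```scala
import java.io.{DataOutputStream, File, FileOutputStream}

object ShuffleIndexWriteSketch {
  /** Hypothetical index writer: cumulative offsets, a leading 0, then one entry per partition end. */
  def writeIndex(indexFile: File, partitionLengths: Array[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    try {
      var offset = 0L
      out.writeLong(offset)            // partition 0 starts at byte 0 of the data file
      for (length <- partitionLengths) {
        offset += length
        out.writeLong(offset)          // end of this partition / start of the next
      }
    } finally {
      out.close()
    }
  }
}
```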
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
new file mode 100644
index 0000000000000..5b6d086630834
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+
+/**
+ * An interface for providing data for blocks.
+ *
+ * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
+ *
+ * Aside from unit tests, [[BlockManager]] is the main class that implements this.
+ */
+private[spark] trait BlockDataProvider {
+ def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
+}
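A quick illustration of how an implementation of this trait might be shaped; the segment type below is a simplified stand-in so the snippet stays self-contained, not Spark's FileSegment:

```scala
import java.io.File
import java.nio.ByteBuffer

// Simplified stand-in types, for illustration only.
case class FileSegmentSketch(file: File, offset: Long, length: Long)

trait BlockDataProviderSketch {
  def getBlockData(blockId: String): Either[FileSegmentSketch, ByteBuffer]
}

// A toy in-memory provider. A disk-backed provider would instead return
// Left(segment) so the transport layer can do a zero-copy file send.
class InMemoryProviderSketch(blocks: Map[String, Array[Byte]]) extends BlockDataProviderSketch {
  def getBlockData(blockId: String): Either[FileSegmentSketch, ByteBuffer] =
    Right(ByteBuffer.wrap(blocks(blockId)))
}
```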
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
deleted file mode 100644
index 5f44f5f3197fd..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ /dev/null
@@ -1,339 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
-import scala.util.{Failure, Success}
-
-import io.netty.buffer.ByteBuf
-
-import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.executor.ShuffleReadMetrics
-import org.apache.spark.network.BufferMessage
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.network.netty.ShuffleCopier
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
-/**
- * A block fetcher iterator interface. There are two implementations:
- *
- * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
- * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
- *
- * Eventually we would like the two to converge and use a single NIO-based communication layer,
- * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
- * NIO would perform poorly and thus the need for the Netty OIO one.
- */
-
-private[storage]
-trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
- def initialize()
-}
-
-
-private[storage]
-object BlockFetcherIterator {
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
- val size = blocks.map(_._2).sum
- }
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- class BasicBlockFetcherIterator(
- private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
- readMetrics: ShuffleReadMetrics)
- extends BlockFetcherIterator {
-
- import blockManager._
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
-
- // Total number blocks fetched (local + remote). Also number of FetchResults expected
- protected var _numBlocksToFetch = 0
-
- protected var startTime = System.currentTimeMillis
-
- // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
-
- // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[BlockId]()
-
- // A queue to hold our results.
- protected val results = new LinkedBlockingQueue[FetchResult]
-
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- private val fetchRequests = new Queue[FetchRequest]
-
- // Current bytes in flight from our requests
- private var bytesInFlight = 0L
-
- protected def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
- val cmId = new ConnectionManagerId(req.address.host, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onComplete {
- case Success(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- val networkSize = blockMessage.getData.limit()
- results.put(new FetchResult(blockId, sizeMap(blockId),
- () => dataDeserialize(blockId, blockMessage.getData, serializer)))
- // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
- // be incrementing bytes read at the same time (SPARK-2625).
- readMetrics.remoteBytesRead += networkSize
- readMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
- }
- }
- case Failure(exception) => {
- logError("Could not get block(s) from " + cmId, exception)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
- }
- }
- }
-
- protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
- // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
-
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- var totalBlocks = 0
- for ((address, blockInfos) <- blocksByAddress) {
- totalBlocks += blockInfos.size
- if (address == blockManagerId) {
- // Filter out zero-sized blocks
- localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
- _numBlocksToFetch += localBlocksToFetch.size
- } else {
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(BlockId, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- // Skip empty blocks
- if (size > 0) {
- curBlocks += ((blockId, size))
- remoteBlocksToFetch += blockId
- _numBlocksToFetch += 1
- curRequestSize += size
- } else if (size < 0) {
- throw new BlockException(blockId, "Negative block size " + size)
- }
- if (curRequestSize >= targetRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curBlocks = new ArrayBuffer[(BlockId, Long)]
- logDebug(s"Creating fetch request of $curRequestSize at $address")
- curRequestSize = 0
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
- }
- }
- logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
- totalBlocks + " blocks")
- remoteRequests
- }
-
- protected def getLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocksToFetch) {
- try {
- // getLocalFromDisk never return None but throws BlockException
- val iter = getLocalFromDisk(id, serializer).get
- // Pass 0 as size since it's not in flight
- readMetrics.localBlocksFetched += 1
- results.put(new FetchResult(id, 0, () => iter))
- logDebug("Got local block " + id)
- } catch {
- case e: Exception => {
- logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
- return
- }
- }
- }
- }
-
- override def initialize() {
- // Split local and remote blocks.
- val remoteRequests = splitLocalRemoteBlocks()
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
-
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
-
- val numFetches = remoteRequests.size - fetchRequests.size
- logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
-
- // Get Local Blocks
- startTime = System.currentTimeMillis
- getLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- }
-
- // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
- // as they arrive.
- @volatile protected var resultsGotten = 0
-
- override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
-
- override def next(): (BlockId, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val startFetchWait = System.currentTimeMillis()
- val result = results.take()
- val stopFetchWait = System.currentTimeMillis()
- readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (! result.failed) bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
- }
- // End of BasicBlockFetcherIterator
-
- class NettyBlockFetcherIterator(
- blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
- readMetrics: ShuffleReadMetrics)
- extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {
-
- import blockManager._
-
- val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
-
- private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
- (for ( i <- Range(0,numCopiers) ) yield {
- val copier = new Thread {
- override def run(){
- try {
- while(!isInterrupted && !fetchRequestsSync.isEmpty) {
- sendRequest(fetchRequestsSync.take())
- }
- } catch {
- case x: InterruptedException => logInfo("Copier Interrupted")
- // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
- }
- }
- }
- copier.start
- copier
- }).toList
- }
-
- // keep this to interrupt the threads when necessary
- private def stopCopiers() {
- for (copier <- copiers) {
- copier.interrupt()
- }
- }
-
- override protected def sendRequest(req: FetchRequest) {
-
- def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
- val fetchResult = new FetchResult(blockId, blockSize,
- () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
- results.put(fetchResult)
- }
-
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.bytesToString(req.size), req.address.host))
- val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
- val cpier = new ShuffleCopier(blockManager.conf)
- cpier.getBlocks(cmId, req.blocks, putResult)
- logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
- }
-
- private var copiers: List[_ <: Thread] = null
-
- override def initialize() {
- // Split Local Remote Blocks and set numBlocksToFetch
- val remoteRequests = splitLocalRemoteBlocks()
- // Add the remote requests into our queue in a random order
- for (request <- Utils.randomize(remoteRequests)) {
- fetchRequestsSync.put(request)
- }
-
- copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
- logInfo("Started " + fetchRequestsSync.size + " remote fetches in " +
- Utils.getUsedTimeMs(startTime))
-
- // Get Local Blocks
- startTime = System.currentTimeMillis
- getLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
- }
-
- override def next(): (BlockId, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val result = results.take()
- // If all the results has been retrieved, copiers will exit automatically
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
- }
- // End of NettyBlockFetcherIterator
-}
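Even though this file is deleted wholesale, the request-splitting policy documented in splitLocalRemoteBlocks is worth keeping in mind: remote blocks were grouped into requests of at most maxBytesInFlight / 5 so that up to five fetches could run in parallel. The replacement ShuffleBlockFetcherIterator is handed the same maxBytesInFlight value in the BlockStoreShuffleFetcher hunk above, so presumably applies a similar policy. A self-contained restatement of the grouping for one remote address:

```scala
import scala.collection.mutable.ArrayBuffer

object FetchRequestSplitSketch {
  case class FetchRequest(address: String, blocks: Seq[(String, Long)]) {
    val size: Long = blocks.map(_._2).sum
  }

  /** Group (blockId, size) pairs into requests of roughly maxBytesInFlight / 5 each. */
  def split(address: String, blocks: Seq[(String, Long)], maxBytesInFlight: Long): Seq[FetchRequest] = {
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = ArrayBuffer[FetchRequest]()
    var curBlocks = Vector.empty[(String, Long)]
    var curRequestSize = 0L
    for ((blockId, size) <- blocks if size > 0) {   // zero-sized blocks are skipped
      curBlocks :+= ((blockId, size))
      curRequestSize += size
      if (curRequestSize >= targetRequestSize) {
        requests += FetchRequest(address, curBlocks)
        curBlocks = Vector.empty
        curRequestSize = 0L
      }
    }
    if (curBlocks.nonEmpty) requests += FetchRequest(address, curBlocks)
    requests.toSeq
  }
}
```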
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index c1756ac905417..a83a3f468ae5f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -58,6 +58,11 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends Blo
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
+@DeveloperApi
+case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
+}
+
@DeveloperApi
case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
@@ -92,6 +97,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
@@ -104,6 +110,8 @@ object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
+ ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
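A small demonstration of how the added SHUFFLE_DATA pattern coexists with the existing SHUFFLE pattern in BlockId.apply: Scala regex extractors require a full-string match, so the ".data" suffix cleanly separates the two cases even when SHUFFLE is tried first:

```scala
object BlockIdParseSketch extends App {
  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
  val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r

  "shuffle_3_7_0.data" match {
    case SHUFFLE(shuffleId, mapId, reduceId) =>
      println(s"plain shuffle block: $shuffleId/$mapId/$reduceId")  // not taken: no full match
    case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
      println(s"shuffle data block: $shuffleId/$mapId/$reduceId")   // taken
    case other =>
      println(s"unrecognized block id: $other")
  }
}
```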
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e4c3d58905e7f..d1bee3d2c033c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,12 +20,14 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
+import scala.concurrent.ExecutionContext.Implicits.global
+
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
-import akka.actor.{ActorSystem, Cancellable, Props}
+import akka.actor.{ActorSystem, Props}
import sun.nio.ch.DirectBuffer
import org.apache.spark._
@@ -36,6 +38,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._
+
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -57,19 +60,14 @@ private[spark] class BlockManager(
defaultSerializer: Serializer,
maxMemory: Long,
val conf: SparkConf,
- securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker,
- shuffleManager: ShuffleManager)
- extends Logging {
+ shuffleManager: ShuffleManager,
+ blockTransferService: BlockTransferService)
+ extends BlockDataManager with Logging {
- private val port = conf.getInt("spark.blockManager.port", 0)
- val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
- val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
- conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
- val connectionManager =
- new ConnectionManager(port, conf, securityManager, "Connection manager for block manager")
+ blockTransferService.init(this)
- implicit val futureExecContext = connectionManager.futureExecContext
+ val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -83,24 +81,13 @@ private[spark] class BlockManager(
val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}"
val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998")
val tachyonBlockManager =
- new TachyonBlockManager(shuffleBlockManager, tachyonStorePath, tachyonMaster)
+ new TachyonBlockManager(this, tachyonStorePath, tachyonMaster)
tachyonInitialized = true
new TachyonStore(this, tachyonBlockManager)
}
- // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
- private val nettyPort: Int = {
- val useNetty = conf.getBoolean("spark.shuffle.use.netty", false)
- val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0)
- if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
- }
-
val blockManagerId = BlockManagerId(
- executorId, connectionManager.id.host, connectionManager.id.port, nettyPort)
-
- // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
- // for receiving shuffle outputs)
- val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024
+ executorId, blockTransferService.hostName, blockTransferService.port)
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
@@ -143,11 +130,11 @@ private[spark] class BlockManager(
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
- securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker,
- shuffleManager: ShuffleManager) = {
+ shuffleManager: ShuffleManager,
+ blockTransferService: BlockTransferService) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, securityManager, mapOutputTracker, shuffleManager)
+ conf, mapOutputTracker, shuffleManager, blockTransferService)
}
/**
@@ -156,7 +143,6 @@ private[spark] class BlockManager(
*/
private def initialize(): Unit = {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
- BlockManagerWorker.startBlockManagerWorker(this)
}
/**
@@ -219,6 +205,33 @@ private[spark] class BlockManager(
}
}
+ /**
+ * Interface to get local block data.
+ *
+ * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ */
+ override def getBlockData(blockId: String): Option[ManagedBuffer] = {
+ val bid = BlockId(blockId)
+ if (bid.isShuffle) {
+ Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ } else {
+ val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ if (blockBytesOpt.isDefined) {
+ val buffer = blockBytesOpt.get
+ Some(new NioByteBufferManagedBuffer(buffer))
+ } else {
+ None
+ }
+ }
+ }
+
+ /**
+ * Put the block locally, using the given storage level.
+ */
+ override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ }
+
/**
* Get the BlockStatus for the block identified by the given ID, if it exists.
* NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
@@ -326,10 +339,10 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- diskStore.getValues(blockId, serializer).orElse {
- throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be")
- }
+ def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+ val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+ val is = wrapForCompression(blockId, buf.inputStream())
+ Some(serializer.newInstance().deserializeStream(is).asIterator)
}
/**
@@ -348,7 +361,8 @@ private[spark] class BlockManager(
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
- diskStore.getBytes(blockId) match {
+ val shuffleBlockManager = shuffleManager.shuffleBlockManager
+ shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match {
case Some(bytes) =>
Some(bytes)
case None =>
@@ -499,8 +513,9 @@ private[spark] class BlockManager(
val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ val data = blockTransferService.fetchBlockSync(
+ loc.host, loc.port, blockId.toString).nioByteBuffer()
+
if (data != null) {
if (asBlockResult) {
return Some(new BlockResult(
@@ -534,28 +549,6 @@ private[spark] class BlockManager(
None
}
- /**
- * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
- * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
- * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
- * so that we can control the maxMegabytesInFlight for the fetch.
- */
- def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- serializer: Serializer,
- readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
- val iter =
- if (conf.getBoolean("spark.shuffle.use.netty", false)) {
- new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer,
- readMetrics)
- } else {
- new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
- readMetrics)
- }
- iter.initialize()
- iter
- }
-
def putIterator(
blockId: BlockId,
values: Iterator[Any],
@@ -808,12 +801,15 @@ private[spark] class BlockManager(
data.rewind()
logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " +
s"To node: $peer")
- val putBlock = PutBlock(blockId, data, tLevel)
- val cmId = new ConnectionManagerId(peer.host, peer.port)
- val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId)
- if (!syncPutBlockSuccess) {
- logError(s"Failed to call syncPutBlock to $peer")
+
+ try {
+ blockTransferService.uploadBlockSync(
+ peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
+ } catch {
+ case e: Exception =>
+ logError(s"Failed to replicate block to $peer", e)
}
+
logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes."
.format(blockId, (System.nanoTime - start) / 1e6, data.limit()))
}
@@ -1038,31 +1034,12 @@ private[spark] class BlockManager(
bytes: ByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
-
- def getIterator: Iterator[Any] = {
- val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
- serializer.newInstance().deserializeStream(stream).asIterator
- }
-
- if (blockId.isShuffle) {
- /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators
- * at the beginning. The wrapping will cost some memory (compression instance
- * initialization, etc.). Reducer reads shuffle blocks one by one so we could do the
- * wrapping lazily to save memory. */
- class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] {
- lazy val proxy = f
- override def hasNext: Boolean = proxy.hasNext
- override def next(): Any = proxy.next()
- }
- new LazyProxyIterator(getIterator)
- } else {
- getIterator
- }
+ val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
+ serializer.newInstance().deserializeStream(stream).asIterator
}
def stop(): Unit = {
- connectionManager.stop()
- shuffleBlockManager.stop()
+ blockTransferService.stop()
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
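The new getBlockData override encodes a simple dispatch: shuffle blocks are answered by the shuffle manager's block manager, everything else comes from the local stores and is wrapped in a buffer. A hedged, simplified model of that dispatch, using toy types rather than Spark's ManagedBuffer API:

```scala
import java.nio.ByteBuffer

object GetBlockDataSketch {
  sealed trait Buffer
  case class NioBuffer(buf: ByteBuffer) extends Buffer            // like NioByteBufferManagedBuffer
  case class ShuffleServedBuffer(blockId: String) extends Buffer  // served by the shuffle block manager

  def getBlockData(
      blockId: String,
      isShuffle: String => Boolean,
      getLocalBytes: String => Option[ByteBuffer]): Option[Buffer] = {
    if (isShuffle(blockId)) {
      Some(ShuffleServedBuffer(blockId))
    } else {
      getLocalBytes(blockId).map(bytes => NioBuffer(bytes))
    }
  }
}
```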
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index b1585bd8199d1..d4487fce49ab6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -36,11 +36,10 @@ import org.apache.spark.util.Utils
class BlockManagerId private (
private var executorId_ : String,
private var host_ : String,
- private var port_ : Int,
- private var nettyPort_ : Int
- ) extends Externalizable {
+ private var port_ : Int)
+ extends Externalizable {
- private def this() = this(null, null, 0, 0) // For deserialization only
+ private def this() = this(null, null, 0) // For deserialization only
def executorId: String = executorId_
@@ -60,32 +59,28 @@ class BlockManagerId private (
def port: Int = port_
- def nettyPort: Int = nettyPort_
-
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
- out.writeInt(nettyPort_)
}
override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
- nettyPort_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
+ override def toString = s"BlockManagerId($executorId, $host, $port)"
- override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort
+ override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
override def equals(that: Any) = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort
+ executorId == id.executorId && port == id.port && host == id.host
case _ =>
false
}
@@ -100,11 +95,10 @@ private[spark] object BlockManagerId {
* @param execId ID of the executor.
* @param host Host name of the block manager.
* @param port Port of the block manager.
- * @param nettyPort Optional port for the Netty-based shuffle sender.
* @return A new [[org.apache.spark.storage.BlockManagerId]].
*/
- def apply(execId: String, host: String, port: Int, nettyPort: Int) =
- getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort))
+ def apply(execId: String, host: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
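For reference, the Externalizable contract is why dropping nettyPort touches both writeExternal and readExternal at once: fields must be written and read back in the same order, and a no-arg constructor is required for deserialization. A self-contained example with the same three remaining fields (not Spark's BlockManagerId):

```scala
import java.io.{Externalizable, ObjectInput, ObjectOutput}

class EndpointId(var executorId: String, var host: String, var port: Int) extends Externalizable {
  def this() = this(null, null, 0)  // no-arg constructor used during deserialization

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeUTF(executorId)
    out.writeUTF(host)
    out.writeInt(port)
  }

  override def readExternal(in: ObjectInput): Unit = {
    executorId = in.readUTF()
    host = in.readUTF()
    port = in.readInt()
  }

  override def toString = s"EndpointId($executorId, $host, $port)"
}
```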
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 669307765d1fa..2e262594b3538 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -27,7 +27,11 @@ import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.AkkaUtils
private[spark]
-class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging {
+class BlockManagerMaster(
+ var driverActor: ActorRef,
+ conf: SparkConf,
+ isDriver: Boolean)
+ extends Logging {
private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf)
private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf)
@@ -101,7 +105,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
def removeRdd(rddId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
future.onFailure {
- case e: Throwable => logError("Failed to remove RDD " + rddId, e)
+ case e: Exception =>
+ logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
}
if (blocking) {
Await.result(future, timeout)
@@ -112,7 +117,8 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
def removeShuffle(shuffleId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
- case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+ case e: Exception =>
+ logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
}
if (blocking) {
Await.result(future, timeout)
@@ -124,9 +130,9 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
val future = askDriverWithReply[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
- case e: Throwable =>
- logError("Failed to remove broadcast " + broadcastId +
- " with removeFromMaster = " + removeFromMaster, e)
+ case e: Exception =>
+ logWarning(s"Failed to remove broadcast $broadcastId" +
+ s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}")
}
if (blocking) {
Await.result(future, timeout)
@@ -194,7 +200,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
- if (driverActor != null) {
+ if (driverActor != null && isDriver) {
tell(StopBlockManagerMaster)
driverActor = null
logInfo("BlockManagerMaster stopped")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 3ab07703b6f85..1a6c7cb24f9ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -203,7 +203,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockLocations.remove(blockId)
}
}
- listenerBus.post(SparkListenerBlockManagerRemoved(blockManagerId))
+ listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
}
private def expireDeadHosts() {
@@ -325,6 +325,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
@@ -340,9 +341,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
id.hostPort, Utils.bytesToString(maxMemSize)))
blockManagerInfo(id) =
- new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor)
+ new BlockManagerInfo(id, time, maxMemSize, slaveActor)
}
- listenerBus.post(SparkListenerBlockManagerAdded(id, maxMemSize))
+ listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
private def updateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index c194e0fed3367..14ae2f38c5670 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -21,7 +21,7 @@ import scala.concurrent.Future
import akka.actor.{ActorRef, Actor}
-import org.apache.spark.{Logging, MapOutputTracker}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.ActorLogReceive
@@ -55,7 +55,7 @@ class BlockManagerSlaveActor(
if (mapOutputTracker != null) {
mapOutputTracker.unregisterShuffle(shuffleId)
}
- blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+ SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
}
case RemoveBroadcast(broadcastId, tellMaster) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
deleted file mode 100644
index bf002a42d5dc5..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-import org.apache.spark.Logging
-import org.apache.spark.network._
-import org.apache.spark.util.Utils
-
-import scala.concurrent.Await
-import scala.concurrent.duration.Duration
-import scala.util.{Try, Failure, Success}
-
-/**
- * A network interface for BlockManager. Each slave should have one
- * BlockManagerWorker.
- *
- * TODO: Use event model.
- */
-private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
-
- blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
-
- def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
- logDebug("Handling message " + msg)
- msg match {
- case bufferMessage: BufferMessage => {
- try {
- logDebug("Handling as a buffer message " + bufferMessage)
- val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
- logDebug("Parsed as a block message array")
- val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
- Some(new BlockMessageArray(responseMessages).toBufferMessage)
- } catch {
- case e: Exception => {
- logError("Exception handling buffer message", e)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
- }
- }
- }
- case otherMessage: Any => {
- logError("Unknown type message received: " + otherMessage)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
- }
- }
- }
-
- def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
- blockMessage.getType match {
- case BlockMessage.TYPE_PUT_BLOCK => {
- val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
- logDebug("Received [" + pB + "]")
- putBlock(pB.id, pB.data, pB.level)
- None
- }
- case BlockMessage.TYPE_GET_BLOCK => {
- val gB = new GetBlock(blockMessage.getId)
- logDebug("Received [" + gB + "]")
- val buffer = getBlock(gB.id)
- if (buffer == null) {
- return None
- }
- Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
- }
- case _ => None
- }
- }
-
- private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
- val startTimeMs = System.currentTimeMillis()
- logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
- blockManager.putBytes(id, bytes, level)
- logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " with data size: " + bytes.limit)
- }
-
- private def getBlock(id: BlockId): ByteBuffer = {
- val startTimeMs = System.currentTimeMillis()
- logDebug("GetBlock " + id + " started from " + startTimeMs)
- val buffer = blockManager.getLocalBytes(id) match {
- case Some(bytes) => bytes
- case None => null
- }
- logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
- + " and got buffer " + buffer)
- buffer
- }
-}
-
-private[spark] object BlockManagerWorker extends Logging {
- private var blockManagerWorker: BlockManagerWorker = null
-
- def startBlockManagerWorker(manager: BlockManager) {
- blockManagerWorker = new BlockManagerWorker(manager)
- }
-
- def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val blockMessage = BlockMessage.fromPutBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val resultMessage = Try(Await.result(connectionManager.sendMessageReliably(
- toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
- resultMessage.isSuccess
- }
-
- def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
- val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val blockMessage = BlockMessage.fromGetBlock(msg)
- val blockMessageArray = new BlockMessageArray(blockMessage)
- val responseMessage = Try(Await.result(connectionManager.sendMessageReliably(
- toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
- responseMessage match {
- case Success(message) => {
- val bufferMessage = message.asInstanceOf[BufferMessage]
- logDebug("Response message received " + bufferMessage)
- BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
- logDebug("Found " + blockMessage)
- return blockMessage.getData
- })
- }
- case Failure(exception) => logDebug("No response message received")
- }
- null
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
new file mode 100644
index 0000000000000..9ef453605f4f1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+
+class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index adda971fd7b47..9c469370ffe1f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -65,8 +65,6 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
/**
* BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
- * The given write metrics will be updated incrementally, but will not necessarily be current until
- * commitAndClose is called.
*/
private[spark] class DiskBlockObjectWriter(
blockId: BlockId,
@@ -75,6 +73,8 @@ private[spark] class DiskBlockObjectWriter(
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
+ // These write metrics are concurrently shared with other active BlockObjectWriters
+ // that are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
extends BlockObjectWriter(blockId)
with Logging
@@ -94,14 +94,30 @@ private[spark] class DiskBlockObjectWriter(
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
+ private var initialized = false
+
+ /**
+ * Cursors used to represent positions in the file.
+ *
+ * xxxxxxxx|--------|---       |
+ *         ^        ^          ^
+ *         |        |     finalPosition
+ *         |   reportedPosition
+ *    initialPosition
+ *
+ * initialPosition: Offset in the file where we start writing. Immutable.
+ * reportedPosition: Position at the time of the last update to the write metrics.
+ * finalPosition: Offset where we stopped writing. Set on commitAndClose() then never changed.
+ * -----: Current writes to the underlying file.
+ * xxxxx: Existing contents of the file.
+ */
private val initialPosition = file.length()
private var finalPosition: Long = -1
- private var initialized = false
+ private var reportedPosition = initialPosition
/** Calling channel.position() to update the write metrics can be a little bit expensive, so we
* only call it every N writes */
private var writesSinceMetricsUpdate = 0
- private var lastPosition = initialPosition
override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
@@ -140,17 +156,18 @@ private[spark] class DiskBlockObjectWriter(
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
- updateBytesWritten()
close()
}
finalPosition = file.length()
+ // In certain compression codecs, more bytes are written after close() is called
+ writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition)
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition)
+ writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition)
if (initialized) {
objOut.flush()
@@ -189,10 +206,14 @@ private[spark] class DiskBlockObjectWriter(
new FileSegment(file, initialPosition, finalPosition - initialPosition)
}
+ /**
+ * Report the number of bytes written in this writer's shuffle write metrics.
+ * Note that this is only valid before the underlying streams are closed.
+ */
private def updateBytesWritten() {
val pos = channel.position()
- writeMetrics.shuffleBytesWritten += (pos - lastPosition)
- lastPosition = pos
+ writeMetrics.shuffleBytesWritten += (pos - reportedPosition)
+ reportedPosition = pos
}
private def callWithTiming(f: => Unit) = {
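A toy model of the cursor bookkeeping described in the comment block above: bytes are reported to the shared metrics as the delta since reportedPosition, so several concurrently active writers can add into one counter without double counting, and commitAndClose reconciles against the final file length because some compression codecs emit trailing bytes on close:

```scala
object WriteMetricsSketch {
  class SharedWriteMetrics { var shuffleBytesWritten = 0L }

  class TrackedWriter(metrics: SharedWriteMetrics, initialPosition: Long) {
    private var position = initialPosition          // stands in for channel.position()
    private var reportedPosition = initialPosition

    def write(numBytes: Long): Unit = position += numBytes

    // Called every N writes: report only the delta since the last report.
    def updateBytesWritten(): Unit = {
      metrics.shuffleBytesWritten += position - reportedPosition
      reportedPosition = position
    }

    // On commit, reconcile against the real final length of the file.
    def commitAndClose(finalLength: Long): Unit = {
      metrics.shuffleBytesWritten += finalLength - reportedPosition
      reportedPosition = finalLength
    }
  }
}
```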
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 4d66ccea211fa..a715594f198c2 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -21,11 +21,9 @@ import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random, UUID}
-import org.apache.spark.{SparkEnv, Logging}
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.Utils
-import org.apache.spark.shuffle.sort.SortShuffleManager
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -33,49 +31,27 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
* However, it is also possible to have a block map to only a segment of a file, by calling
* mapBlockToFileSegment().
*
- * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ * Block files are hashed among the directories listed in spark.local.dir (or in
+ * SPARK_LOCAL_DIRS, if it's set).
*/
-private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String)
- extends PathResolver with Logging {
+private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
+ extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
-
- private val subDirsPerLocalDir =
- shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs()
+ val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
- private var shuffleSender : ShuffleSender = null
addShutdownHook()
- /**
- * Returns the physical file segment in which the given BlockId is located. If the BlockId has
- * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
- * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
- */
- def getBlockLocation(blockId: BlockId): FileSegment = {
- val env = SparkEnv.get // NOTE: can be null in unit tests
- if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) {
- // For sort-based shuffle, let it figure out its blocks
- val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager]
- sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
- } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) {
- // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this
- shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
- } else {
- val file = getFile(blockId.name)
- new FileSegment(file, 0, file.length())
- }
- }
-
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -105,7 +81,7 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
/** Check if disk block manager has a block. */
def containsBlock(blockId: BlockId): Boolean = {
- getBlockLocation(blockId).file.exists()
+ getFile(blockId.name).exists()
}
/** List all the files currently stored on disk by the disk manager. */
@@ -131,10 +107,9 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
(blockId, getFile(blockId))
}
- private def createLocalDirs(): Array[File] = {
- logDebug(s"Creating local directories at root dirs '$rootDirs'")
+ private def createLocalDirs(conf: SparkConf): Array[File] = {
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").flatMap { rootDir =>
+ Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
@@ -186,15 +161,5 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
}
}
}
-
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- shuffleSender = new ShuffleSender(port, this)
- logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}")
- shuffleSender.port
}
}
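
A minimal sketch (outside the patch, with `localDirs` and `subDirsPerLocalDir` as assumed inputs rather than the manager's real state) of the two-level hashing that getFile now uses to place every block file under a local root dir and one of its subdirectories:

    import java.io.File

    object BlockFileHashSketch {
      // Mirrors Utils.nonNegativeHash: a non-negative hash of the file name.
      def nonNegativeHash(s: String): Int = {
        val h = s.hashCode
        if (h != Int.MinValue) math.abs(h) else 0
      }

      // Pick <rootDir>/<subDir>/<filename>, mirroring the two-level layout used by getFile.
      def resolve(localDirs: Array[File], subDirsPerLocalDir: Int, filename: String): File = {
        val hash = nonNegativeHash(filename)
        val dirId = hash % localDirs.length                            // which local root dir
        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir  // which subdirectory inside it
        new File(new File(localDirs(dirId), "%02x".format(subDirId)), filename)
      }
    }
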
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index c83261dd91b36..e9304f6bb45d0 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{FileOutputStream, RandomAccessFile}
+import java.io.{File, FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
@@ -34,7 +34,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
override def getSize(blockId: BlockId): Long = {
- diskManager.getBlockLocation(blockId).length
+ diskManager.getFile(blockId.name).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
@@ -89,25 +89,33 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val segment = diskManager.getBlockLocation(blockId)
- val channel = new RandomAccessFile(segment.file, "r").getChannel
+ private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = {
+ val channel = new RandomAccessFile(file, "r").getChannel
try {
// For small files, directly read rather than memory map
- if (segment.length < minMemoryMapBytes) {
- val buf = ByteBuffer.allocate(segment.length.toInt)
- channel.read(buf, segment.offset)
+ if (length < minMemoryMapBytes) {
+ val buf = ByteBuffer.allocate(length.toInt)
+ channel.read(buf, offset)
buf.flip()
Some(buf)
} else {
- Some(channel.map(MapMode.READ_ONLY, segment.offset, segment.length))
+ Some(channel.map(MapMode.READ_ONLY, offset, length))
}
} finally {
channel.close()
}
}
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val file = diskManager.getFile(blockId.name)
+ getBytes(file, 0, file.length)
+ }
+
+ def getBytes(segment: FileSegment): Option[ByteBuffer] = {
+ getBytes(segment.file, segment.offset, segment.length)
+ }
+
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
@@ -117,24 +125,25 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
* shuffle short-circuit code.
*/
def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+ // TODO: Should bypass getBytes and use a stream based implementation, so that
+ // we won't use a lot of memory during e.g. external sort merge.
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}
override def remove(blockId: BlockId): Boolean = {
- val fileSegment = diskManager.getBlockLocation(blockId)
- val file = fileSegment.file
- if (file.exists() && file.length() == fileSegment.length) {
+ val file = diskManager.getFile(blockId.name)
+ // If consolidation mode is used with the HashShuffleManager, the physical file name for the
+ // block is different from blockId.name, so the file returned here will not exist and we avoid
+ // deleting the whole consolidated file by mistake.
+ if (file.exists()) {
file.delete()
} else {
- if (fileSegment.length < file.length()) {
- logWarning(s"Could not delete block associated with only a part of a file: $blockId")
- }
false
}
}
override def contains(blockId: BlockId): Boolean = {
- val file = diskManager.getBlockLocation(blockId).file
+ val file = diskManager.getFile(blockId.name)
file.exists()
}
}
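
The private getBytes above decides between a plain read and a memory map based on spark.storage.memoryMapThreshold. A self-contained sketch of that decision, with `minMemoryMapBytes` passed in as an assumed parameter:

    import java.io.{File, RandomAccessFile}
    import java.nio.ByteBuffer
    import java.nio.channels.FileChannel.MapMode

    object ReadOrMmapSketch {
      // Read an (offset, length) range of a file: buffer small reads, memory-map large ones.
      def readRange(file: File, offset: Long, length: Long, minMemoryMapBytes: Long): ByteBuffer = {
        val channel = new RandomAccessFile(file, "r").getChannel
        try {
          if (length < minMemoryMapBytes) {
            val buf = ByteBuffer.allocate(length.toInt)  // small: a plain read avoids mmap overhead
            channel.read(buf, offset)
            buf.flip()
            buf
          } else {
            channel.map(MapMode.READ_ONLY, offset, length)  // large: let the OS page it in lazily
          }
        } finally {
          channel.close()
        }
      }
    }
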
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
new file mode 100644
index 0000000000000..d868758a7f549
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import org.apache.spark.{TaskContext, Logging}
+import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.Utils
+
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
+ * manager. For remote blocks, it fetches them using the provided BlockTransferService.
+ *
+ * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
+ * pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight, to avoid
+ * using too much memory.
+ *
+ * @param context [[TaskContext]], used for metrics update
+ * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
+ * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
+ * For each block we also require the size (in bytes as a long field) in
+ * order to throttle the memory usage.
+ * @param serializer serializer used to deserialize the data.
+ * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ */
+private[spark]
+final class ShuffleBlockFetcherIterator(
+ context: TaskContext,
+ blockTransferService: BlockTransferService,
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ serializer: Serializer,
+ maxBytesInFlight: Long)
+ extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+
+ import ShuffleBlockFetcherIterator._
+
+ /**
+ * Total number of blocks to fetch. This can be smaller than the total number of blocks
+ * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
+ *
+ * This should equal localBlocks.size + remoteBlocks.size.
+ */
+ private[this] var numBlocksToFetch = 0
+
+ /**
+ * The number of blocks processed by the caller. The iterator is exhausted when
+ * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+ */
+ private[this] var numBlocksProcessed = 0
+
+ private[this] val startTime = System.currentTimeMillis
+
+ /** Local blocks to fetch, excluding zero-sized blocks. */
+ private[this] val localBlocks = new ArrayBuffer[BlockId]()
+
+ /** Remote blocks to fetch, excluding zero-sized blocks. */
+ private[this] val remoteBlocks = new HashSet[BlockId]()
+
+ /**
+ * A queue to hold our results. This turns the asynchronous model provided by
+ * [[BlockTransferService]] into a synchronous model (iterator).
+ */
+ private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private[this] val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private[this] var bytesInFlight = 0L
+
+ private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+
+ initialize()
+
+ private[this] def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+ bytesInFlight += req.size
+
+ // so we can look up the size of each blockID
+ val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+ val blockIds = req.blocks.map(_._1.toString)
+
+ blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ new BlockFetchingListener {
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
+ () => serializer.newInstance().deserializeStream(
+ blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
+ ))
+ shuffleMetrics.remoteBytesRead += data.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+
+ override def onBlockFetchFailure(e: Throwable): Unit = {
+ logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ // Note that there is a chance that some blocks have been fetched successfully, but we
+ // still add them to the failed queue. This is fine because when the caller sees a
+ // FetchFailedException, it is going to fail the entire task anyway.
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ )
+ }
+
+ private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+
+ // Tracks total number of blocks (including zero sized blocks)
+ var totalBlocks = 0
+ for ((address, blockInfos) <- blocksByAddress) {
+ totalBlocks += blockInfos.size
+ if (address == blockManager.blockManagerId) {
+ // Filter out zero-sized blocks
+ localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+ numBlocksToFetch += localBlocks.size
+ } else {
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ remoteBlocks += blockId
+ numBlocksToFetch += 1
+ curRequestSize += size
+ } else if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= targetRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
+ logDebug(s"Creating fetch request of $curRequestSize at $address")
+ curRequestSize = 0
+ }
+ }
+ // Add in the final request
+ if (curBlocks.nonEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
+ remoteRequests
+ }
+
+ private[this] def fetchLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlocks) {
+ try {
+ shuffleMetrics.localBlocksFetched += 1
+ results.put(new FetchResult(
+ id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
+ logDebug("Got local block " + id)
+ } catch {
+ case e: Exception =>
+ logError(s"Error occurred while fetching local blocks", e)
+ results.put(new FetchResult(id, -1, null))
+ return
+ }
+ }
+ }
+
+ private[this] def initialize(): Unit = {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (fetchRequests.nonEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numFetches = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ fetchLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
+ numBlocksProcessed += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
+ if (!result.failed) {
+ bytesInFlight -= result.size
+ }
+ // Send fetch requests up to maxBytesInFlight
+ while (fetchRequests.nonEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+}
+
+
+private[storage]
+object ShuffleBlockFetcherIterator {
+
+ /**
+ * A request to fetch blocks from a remote BlockManager.
+ * @param address remote BlockManager to fetch from.
+ * @param blocks Sequence of tuple, where the first element is the block id,
+ * and the second element is the estimated size, used to calculate bytesInFlight.
+ */
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ /**
+ * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * @param blockId block id
+ * @param size estimated size of the block, used to calculate bytesInFlight.
+ * Note that this is NOT the exact bytes.
+ * @param deserialize closure to return the result in the form of an Iterator.
+ */
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+}
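
splitLocalRemoteBlocks caps each remote FetchRequest at roughly maxBytesInFlight / 5 so that up to five requests can be in flight in parallel. A standalone sketch of just that grouping step, with plain (blockId, size) string/long pairs standing in for the real types:

    import scala.collection.mutable.ArrayBuffer

    object FetchRequestGroupingSketch {
      // Group non-empty (blockId, size) pairs into requests of roughly maxBytesInFlight / 5 bytes,
      // as splitLocalRemoteBlocks does for each remote address.
      def groupIntoRequests(
          blocks: Seq[(String, Long)],
          maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
        val requests = new ArrayBuffer[ArrayBuffer[(String, Long)]]
        var current = new ArrayBuffer[(String, Long)]
        var currentSize = 0L
        for ((blockId, size) <- blocks if size > 0) {   // zero-sized blocks are skipped entirely
          current += ((blockId, size))
          currentSize += size
          if (currentSize >= targetRequestSize) {       // cut a request once the target size is hit
            requests += current
            current = new ArrayBuffer[(String, Long)]
            currentSize = 0L
          }
        }
        if (current.nonEmpty) requests += current       // final, possibly smaller, request
        requests.map(_.toSeq).toSeq
      }
    }
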
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index a6cbe3aa440ff..6908a59a79e60 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.Utils
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
private[spark] class TachyonBlockManager(
- shuffleManager: ShuffleBlockManager,
+ blockManager: BlockManager,
rootDirs: String,
val master: String)
extends Logging {
@@ -49,7 +49,7 @@ private[spark] class TachyonBlockManager(
private val MAX_DIR_CREATION_ATTEMPTS = 10
private val subDirsPerTachyonDir =
- shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
+ blockManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
// Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
// then, inside this directory, create multiple subdirectories that we will hash files into,
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
deleted file mode 100644
index aa83ea90ee9ee..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.concurrent.ArrayBlockingQueue
-
-import akka.actor._
-import org.apache.spark.shuffle.hash.HashShuffleManager
-import util.Random
-
-import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
-import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.KryoSerializer
-
-/**
- * This class tests the BlockManager and MemoryStore for thread safety and
- * deadlocks. It spawns a number of producer and consumer threads. Producer
- * threads continuously pushes blocks into the BlockManager and consumer
- * threads continuously retrieves the blocks form the BlockManager and tests
- * whether the block is correct or not.
- */
-private[spark] object ThreadingTest {
-
- val numProducers = 5
- val numBlocksPerProducer = 20000
-
- private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
-
- override def run() {
- for (i <- 1 to numBlocksPerProducer) {
- val blockId = TestBlockId("b-" + id + "-" + i)
- val blockSize = Random.nextInt(1000)
- val block = (1 to blockSize).map(_ => Random.nextInt())
- val level = randomLevel()
- val startTime = System.currentTimeMillis()
- manager.putIterator(blockId, block.iterator, level, tellMaster = true)
- println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
- queue.add((blockId, block))
- }
- println("Producer thread " + id + " terminated")
- }
-
- def randomLevel(): StorageLevel = {
- math.abs(Random.nextInt()) % 4 match {
- case 0 => StorageLevel.MEMORY_ONLY
- case 1 => StorageLevel.MEMORY_ONLY_SER
- case 2 => StorageLevel.MEMORY_AND_DISK
- case 3 => StorageLevel.MEMORY_AND_DISK_SER
- }
- }
- }
-
- private[spark] class ConsumerThread(
- manager: BlockManager,
- queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
- ) extends Thread {
- var numBlockConsumed = 0
-
- override def run() {
- println("Consumer thread started")
- while(numBlockConsumed < numBlocksPerProducer) {
- val (blockId, block) = queue.take()
- val startTime = System.currentTimeMillis()
- manager.get(blockId) match {
- case Some(retrievedBlock) =>
- assert(retrievedBlock.data.toList.asInstanceOf[List[Int]] == block.toList,
- "Block " + blockId + " did not match")
- println("Got block " + blockId + " in " +
- (System.currentTimeMillis - startTime) + " ms")
- case None =>
- assert(false, "Block " + blockId + " could not be retrieved")
- }
- numBlockConsumed += 1
- }
- println("Consumer thread terminated")
- }
- }
-
- def main(args: Array[String]) {
- System.setProperty("spark.kryoserializer.buffer.mb", "1")
- val actorSystem = ActorSystem("test")
- val conf = new SparkConf()
- val serializer = new KryoSerializer(conf)
- val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf)
- val blockManager = new BlockManager(
- "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
- new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf))
- val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
- val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
- producers.foreach(_.start)
- consumers.foreach(_.start)
- producers.foreach(_.join)
- consumers.foreach(_.join)
- blockManager.stop()
- blockManagerMaster.stop()
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- println("Everything stopped.")
- println(
- "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index bee6dad3387e5..f0006b42aee4f 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -232,7 +232,7 @@ private[spark] object UIUtils extends Logging {
def listingTable[T](
headers: Seq[String],
generateDataRow: T => Seq[Node],
- data: Seq[T],
+ data: Iterable[T],
fixedWidth: Boolean = false): Seq[Node] = {
var listingTableClass = TABLE_CLASS
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index 02df4e8fe61af..b0e3bb3b552fd 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -21,7 +21,6 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage}
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index 0cc51c873727d..2987dc04494a5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -24,8 +24,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils}
import org.apache.spark.ui.jobs.UIData.StageUIData
import org.apache.spark.util.Utils
-/** Page showing executor summary */
-private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) {
+/** Stage summary grouped by executors. */
+private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) {
private val listener = parent.listener
def toNodeSeq: Seq[Node] = {
@@ -65,7 +65,7 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) {
executorIdToAddress.put(executorId, address)
}
- listener.stageIdToData.get(stageId) match {
+ listener.stageIdToData.get((stageId, stageAttemptId)) match {
case Some(stageData: StageUIData) =>
stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 74cd637d88155..eaeb861f59e5a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
@@ -43,12 +43,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
// How many stages to remember
val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
- val activeStages = HashMap[Int, StageInfo]()
+ // Map from stageId to StageInfo
+ val activeStages = new HashMap[Int, StageInfo]
+
+ // Map from (stageId, attemptId) to StageUIData
+ val stageIdToData = new HashMap[(Int, Int), StageUIData]
+
val completedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
- val stageIdToData = new HashMap[Int, StageUIData]
-
+ // Map from pool name to a hash map (map from stage id to StageInfo).
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()
@@ -59,9 +63,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
val stage = stageCompleted.stageInfo
- val stageId = stage.stageId
- val stageData = stageIdToData.getOrElseUpdate(stageId, {
- logWarning("Stage completed for unknown stage " + stageId)
+ val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), {
+ logWarning("Stage completed for unknown stage " + stage.stageId)
new StageUIData
})
@@ -69,8 +72,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.accumulables(id) = info
}
- poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId))
- activeStages.remove(stageId)
+ poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap =>
+ hashMap.remove(stage.stageId)
+ }
+ activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
trimIfNecessary(completedStages)
@@ -84,7 +89,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > retainedStages) {
val toRemove = math.max(retainedStages / 10, 1)
- stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) }
+ stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) }
stages.trimStart(toRemove)
}
}
@@ -98,21 +103,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData)
+ val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData)
stageData.schedulingPool = poolName
stageData.description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
stages(stage.stageId) = stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val taskInfo = taskStart.taskInfo
if (taskInfo != null) {
- val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, {
+ val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
logWarning("Task start for unknown stage " + taskStart.stageId)
new StageUIData
})
@@ -128,8 +133,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val info = taskEnd.taskInfo
- if (info != null) {
- val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, {
+ // If the stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task
+ // completion event is for, so just drop it here. This means some speculative tasks on the web UI
+ // may never be marked as complete.
+ if (info != null && taskEnd.stageAttemptId != -1) {
+ val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), {
logWarning("Task end for unknown stage " + taskEnd.stageId)
new StageUIData
})
@@ -222,8 +230,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
}
override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
- for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
- val stageData = stageIdToData.getOrElseUpdate(sid, {
+ for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+ val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), {
logWarning("Metrics update for task in unknown stage " + sid)
new StageUIData
})
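
Keying stageIdToData by (stageId, attemptId) lets a resubmitted stage accumulate its own UI data instead of overwriting the earlier attempt. A toy sketch of the lookup pattern the listener now uses, with a simplified stand-in for StageUIData:

    import scala.collection.mutable.HashMap

    object StageKeySketch {
      // Simplified stand-in for StageUIData, just to keep the sketch self-contained.
      class StageData { var numCompleteTasks = 0 }

      val stageIdToData = new HashMap[(Int, Int), StageData]

      // Each listener event now carries both ids, so attempt 0 and a retried attempt 1
      // accumulate into separate entries instead of clobbering each other.
      def onTaskEnd(stageId: Int, stageAttemptId: Int): Unit = {
        val data = stageIdToData.getOrElseUpdate((stageId, stageAttemptId), new StageData)
        data.numCompleteTasks += 1
      }

      def main(args: Array[String]): Unit = {
        onTaskEnd(3, 0)
        onTaskEnd(3, 1)   // retry of stage 3 lands under its own key (3, 1)
        println(stageIdToData.size)  // two distinct entries: (3, 0) and (3, 1)
      }
    }
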
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index d4eb02722ad12..db01be596e073 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -34,7 +34,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val stageId = request.getParameter("id").toInt
- val stageDataOption = listener.stageIdToData.get(stageId)
+ val stageAttemptId = request.getParameter("attempt").toInt
+ val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId))
if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) {
val content =
@@ -42,14 +43,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
Summary Metrics
No tasks have started yet
Tasks
No tasks have started yet
- return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent)
+ return UIUtils.headerSparkPage(
+ s"Details for Stage $stageId (Attempt $stageAttemptId)", content, parent)
}
val stageData = stageDataOption.get
val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
val numCompleted = tasks.count(_.taskInfo.finished)
- val accumulables = listener.stageIdToData(stageId).accumulables
+ val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
val hasInput = stageData.inputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
@@ -211,7 +213,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
def quantileRow(data: Seq[Node]): Seq[Node] =
{data}
Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
- val executorTable = new ExecutorTable(stageId, parent)
+
+ val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
val maybeAccumulableTable: Seq[Node] =
if (accumulables.size > 0) { Accumulators
++ accumulableTable } else Seq()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 16ad0df45aa0d..2e67310594784 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -97,8 +97,8 @@ private[ui] class StageTableBase(
}
// scalastyle:on
- val nameLinkUri ="%s/stages/stage?id=%s"
- .format(UIUtils.prependBaseUri(parent.basePath), s.stageId)
+ val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s"
+ .format(UIUtils.prependBaseUri(parent.basePath), s.stageId, s.attemptId)
val nameLink = {s.name}
val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0)
@@ -121,7 +121,7 @@ private[ui] class StageTableBase(
}
val stageDesc = for {
- stageData <- listener.stageIdToData.get(s.stageId)
+ stageData <- listener.stageIdToData.get((s.stageId, s.attemptId))
desc <- stageData.description
} yield {
{desc}
@@ -131,7 +131,7 @@ private[ui] class StageTableBase(
}
protected def stageRow(s: StageInfo): Seq[Node] = {
- val stageDataOption = listener.stageIdToData.get(s.stageId)
+ val stageDataOption = listener.stageIdToData.get((s.stageId, s.attemptId))
if (stageDataOption.isEmpty) {
return {s.stageId} | No data available for this stage |
}
@@ -154,7 +154,11 @@ private[ui] class StageTableBase(
val shuffleWrite = stageData.shuffleWriteBytes
val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else ""
- {s.stageId} | ++
+ {if (s.attemptId > 0) {
+ {s.stageId} (retry {s.attemptId}) |
+ } else {
+ {s.stageId} |
+ }} ++
{if (isFairScheduler) {
info.numCachedPartitions > 0 }
+ // Remove all partitions that are no longer cached in current completed stage
+ val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet
+ _rddInfoMap.retain { case (id, info) =>
+ !completedRddIds.contains(id) || info.numCachedPartitions > 0
+ }
}
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index d6afb73b74242..e2d32c859bbda 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -27,7 +27,7 @@ import akka.pattern.ask
import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
/**
* Various utility classes for working with Akka.
@@ -192,10 +192,11 @@ private[spark] object AkkaUtils extends Logging {
}
def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = {
+ val driverActorSystemName = SparkEnv.driverActorSystemName
val driverHost: String = conf.get("spark.driver.host", "localhost")
val driverPort: Int = conf.getInt("spark.driver.port", 7077)
Utils.checkHost(driverHost, "Expected hostname")
- val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name"
+ val url = s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name"
val timeout = AkkaUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
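
makeDriverRef now builds the driver actor URL from SparkEnv.driverActorSystemName rather than a hard-coded "spark". A small sketch of the resulting URL shape, with placeholder values for the system name, host, and port:

    object DriverUrlSketch {
      // Placeholder values: in the patch the system name comes from SparkEnv.driverActorSystemName
      // and the host/port from spark.driver.host / spark.driver.port.
      val driverActorSystemName = "sparkDriver"
      val driverHost = "localhost"
      val driverPort = 7077

      // The driver actor URL embeds the configurable actor-system name.
      def driverRefUrl(name: String): String =
        s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name"
    }
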
diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
index 2e8fbf5a91ee7..6d1fc05a15d2c 100644
--- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala
+++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
@@ -41,18 +41,40 @@ import org.apache.spark.io.CompressionCodec
private[spark] class FileLogger(
logDir: String,
sparkConf: SparkConf,
- hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration(),
+ hadoopConf: Configuration,
outputBufferSize: Int = 8 * 1024, // 8 KB
compress: Boolean = false,
overwrite: Boolean = true,
dirPermissions: Option[FsPermission] = None)
extends Logging {
+ def this(
+ logDir: String,
+ sparkConf: SparkConf,
+ compress: Boolean = false,
+ overwrite: Boolean = true) = {
+ this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress,
+ overwrite = overwrite)
+ }
+
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}
- private val fileSystem = Utils.getHadoopFileSystem(logDir)
+ /**
+ * To guard against FileSystem#close or FileSystem.closeAll being called from other modules,
+ * create a dedicated FileSystem instance that is used only by this FileLogger.
+ */
+ private val fileSystem = {
+ val conf = SparkHadoopUtil.get.newConfiguration(sparkConf)
+ val logUri = new URI(logDir)
+ val scheme = logUri.getScheme
+ if (scheme == "hdfs") {
+ conf.setBoolean("fs.hdfs.impl.disable.cache", true)
+ }
+ FileSystem.get(logUri, conf)
+ }
+
var fileIndex = 0
// Only used if compression is enabled
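
The FileLogger change disables the HDFS FileSystem cache for its own handle so that FileSystem.closeAll from other code cannot close the logger out from under it. A sketch of that pattern in isolation:

    import java.net.URI
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.FileSystem

    object UncachedFsSketch {
      // Obtain a FileSystem for the log directory while opting out of the shared HDFS cache,
      // so that a FileSystem.closeAll() elsewhere cannot invalidate this handle.
      def uncachedFileSystemFor(logDir: String, conf: Configuration): FileSystem = {
        val logUri = new URI(logDir)
        if (logUri.getScheme == "hdfs") {
          conf.setBoolean("fs.hdfs.impl.disable.cache", true)
        }
        FileSystem.get(logUri, conf)
      }
    }
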
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 1e18ec688c40d..c4dddb2d1037e 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -96,6 +96,7 @@ private[spark] object JsonProtocol {
val taskInfo = taskStart.taskInfo
("Event" -> Utils.getFormattedClassName(taskStart)) ~
("Stage ID" -> taskStart.stageId) ~
+ ("Stage Attempt ID" -> taskStart.stageAttemptId) ~
("Task Info" -> taskInfoToJson(taskInfo))
}
@@ -112,6 +113,7 @@ private[spark] object JsonProtocol {
val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing
("Event" -> Utils.getFormattedClassName(taskEnd)) ~
("Stage ID" -> taskEnd.stageId) ~
+ ("Stage Attempt ID" -> taskEnd.stageAttemptId) ~
("Task Type" -> taskEnd.taskType) ~
("Task End Reason" -> taskEndReason) ~
("Task Info" -> taskInfoToJson(taskInfo)) ~
@@ -150,13 +152,15 @@ private[spark] object JsonProtocol {
val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId)
("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~
("Block Manager ID" -> blockManagerId) ~
- ("Maximum Memory" -> blockManagerAdded.maxMem)
+ ("Maximum Memory" -> blockManagerAdded.maxMem) ~
+ ("Timestamp" -> blockManagerAdded.time)
}
def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = {
val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId)
("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~
- ("Block Manager ID" -> blockManagerId)
+ ("Block Manager ID" -> blockManagerId) ~
+ ("Timestamp" -> blockManagerRemoved.time)
}
def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = {
@@ -167,6 +171,7 @@ private[spark] object JsonProtocol {
def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = {
("Event" -> Utils.getFormattedClassName(applicationStart)) ~
("App Name" -> applicationStart.appName) ~
+ ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~
("Timestamp" -> applicationStart.time) ~
("User" -> applicationStart.sparkUser)
}
@@ -187,6 +192,7 @@ private[spark] object JsonProtocol {
val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing)
val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing)
("Stage ID" -> stageInfo.stageId) ~
+ ("Stage Attempt ID" -> stageInfo.attemptId) ~
("Stage Name" -> stageInfo.name) ~
("Number of Tasks" -> stageInfo.numTasks) ~
("RDD Info" -> rddInfo) ~
@@ -199,7 +205,6 @@ private[spark] object JsonProtocol {
}
def taskInfoToJson(taskInfo: TaskInfo): JValue = {
- val accumUpdateMap = taskInfo.accumulables
("Task ID" -> taskInfo.taskId) ~
("Index" -> taskInfo.index) ~
("Attempt" -> taskInfo.attempt) ~
@@ -292,8 +297,7 @@ private[spark] object JsonProtocol {
def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = {
("Executor ID" -> blockManagerId.executorId) ~
("Host" -> blockManagerId.host) ~
- ("Port" -> blockManagerId.port) ~
- ("Netty Port" -> blockManagerId.nettyPort)
+ ("Port" -> blockManagerId.port)
}
def jobResultToJson(jobResult: JobResult): JValue = {
@@ -419,8 +423,9 @@ private[spark] object JsonProtocol {
def taskStartFromJson(json: JValue): SparkListenerTaskStart = {
val stageId = (json \ "Stage ID").extract[Int]
+ val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
val taskInfo = taskInfoFromJson(json \ "Task Info")
- SparkListenerTaskStart(stageId, taskInfo)
+ SparkListenerTaskStart(stageId, stageAttemptId, taskInfo)
}
def taskGettingResultFromJson(json: JValue): SparkListenerTaskGettingResult = {
@@ -430,11 +435,12 @@ private[spark] object JsonProtocol {
def taskEndFromJson(json: JValue): SparkListenerTaskEnd = {
val stageId = (json \ "Stage ID").extract[Int]
+ val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
val taskType = (json \ "Task Type").extract[String]
val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason")
val taskInfo = taskInfoFromJson(json \ "Task Info")
val taskMetrics = taskMetricsFromJson(json \ "Task Metrics")
- SparkListenerTaskEnd(stageId, taskType, taskEndReason, taskInfo, taskMetrics)
+ SparkListenerTaskEnd(stageId, stageAttemptId, taskType, taskEndReason, taskInfo, taskMetrics)
}
def jobStartFromJson(json: JValue): SparkListenerJobStart = {
@@ -462,12 +468,14 @@ private[spark] object JsonProtocol {
def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = {
val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID")
val maxMem = (json \ "Maximum Memory").extract[Long]
- SparkListenerBlockManagerAdded(blockManagerId, maxMem)
+ val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
+ SparkListenerBlockManagerAdded(time, blockManagerId, maxMem)
}
def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = {
val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID")
- SparkListenerBlockManagerRemoved(blockManagerId)
+ val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
+ SparkListenerBlockManagerRemoved(time, blockManagerId)
}
def unpersistRDDFromJson(json: JValue): SparkListenerUnpersistRDD = {
@@ -476,9 +484,10 @@ private[spark] object JsonProtocol {
def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = {
val appName = (json \ "App Name").extract[String]
+ val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String])
val time = (json \ "Timestamp").extract[Long]
val sparkUser = (json \ "User").extract[String]
- SparkListenerApplicationStart(appName, time, sparkUser)
+ SparkListenerApplicationStart(appName, appId, time, sparkUser)
}
def applicationEndFromJson(json: JValue): SparkListenerApplicationEnd = {
@@ -492,6 +501,7 @@ private[spark] object JsonProtocol {
def stageInfoFromJson(json: JValue): StageInfo = {
val stageId = (json \ "Stage ID").extract[Int]
+ val attemptId = (json \ "Attempt ID").extractOpt[Int].getOrElse(0)
val stageName = (json \ "Stage Name").extract[String]
val numTasks = (json \ "Number of Tasks").extract[Int]
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson(_))
@@ -504,7 +514,7 @@ private[spark] object JsonProtocol {
case None => Seq[AccumulableInfo]()
}
- val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details)
+ val stageInfo = new StageInfo(stageId, attemptId, stageName, numTasks, rddInfos, details)
stageInfo.submissionTime = submissionTime
stageInfo.completionTime = completionTime
stageInfo.failureReason = failureReason
@@ -638,8 +648,7 @@ private[spark] object JsonProtocol {
val executorId = (json \ "Executor ID").extract[String]
val host = (json \ "Host").extract[String]
val port = (json \ "Port").extract[Int]
- val nettyPort = (json \ "Netty Port").extract[Int]
- BlockManagerId(executorId, host, port, nettyPort)
+ BlockManagerId(executorId, host, port)
}
def jobResultFromJson(json: JValue): JobResult = {
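
The ...FromJson readers pair each new field with extractOpt/jsonOption and a default so that event logs written before this change still replay. A minimal json4s sketch of that pattern (json4s is the JSON library already in use here; the sample JSON below is made up):

    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    object BackwardCompatSketch {
      implicit val formats = DefaultFormats

      def main(args: Array[String]): Unit = {
        // An event written by an older Spark has no "Stage Attempt ID" field.
        val oldEvent = parse("""{"Stage ID": 7, "Task Type": "ResultTask"}""")
        val stageId = (oldEvent \ "Stage ID").extract[Int]
        // extractOpt turns the missing field into None; default to attempt 0 instead of failing.
        val stageAttemptId = (oldEvent \ "Stage Attempt ID").extractOpt[Int].getOrElse(0)
        println(s"stage $stageId, attempt $stageAttemptId")  // stage 7, attempt 0
      }
    }
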
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
new file mode 100644
index 0000000000000..f64e069cd1724
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Exception thrown when one or more TaskCompletionListener callbacks
+ * fail while a task is being marked as completed.
+ */
+private[spark]
+class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
+
+ override def getMessage: String = {
+ if (errorMessages.size == 1) {
+ errorMessages.head
+ } else {
+ errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+ }
+ }
+}
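
The new exception carries one message per failed callback and formats them together. A hypothetical sketch of the call-site pattern it is designed for, assuming the code is compiled inside Spark where the private[spark] class is visible:

    import scala.collection.mutable.ArrayBuffer

    object CompletionCallbackSketch {
      // Run every callback, remember each failure, and surface them together
      // once all callbacks have had a chance to run.
      def runCompletionCallbacks(callbacks: Seq[() => Unit]): Unit = {
        val errorMessages = new ArrayBuffer[String]
        callbacks.foreach { cb =>
          try cb() catch {
            case e: Exception => errorMessages += e.getMessage
          }
        }
        if (errorMessages.nonEmpty) {
          throw new TaskCompletionListenerException(errorMessages.toSeq)
        }
      }
    }
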
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 019f68b160894..ed063844323af 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -20,9 +20,11 @@ package org.apache.spark.util
import java.io._
import java.net._
import java.nio.ByteBuffer
-import java.util.{Locale, Random, UUID}
+import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import org.apache.log4j.PropertyConfigurator
+
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
@@ -34,6 +36,7 @@ import scala.util.control.{ControlThrowable, NonFatal}
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
@@ -52,11 +55,6 @@ private[spark] case class CallSite(shortForm: String, longForm: String)
private[spark] object Utils extends Logging {
val random = new Random()
- def sparkBin(sparkHome: String, which: String): File = {
- val suffix = if (isWindows) ".cmd" else ""
- new File(sparkHome + File.separator + "bin", which + suffix)
- }
-
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -162,30 +160,6 @@ private[spark] object Utils extends Logging {
}
}
- def isAlpha(c: Char): Boolean = {
- (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
- }
-
- /** Split a string into words at non-alphabetic characters */
- def splitWords(s: String): Seq[String] = {
- val buf = new ArrayBuffer[String]
- var i = 0
- while (i < s.length) {
- var j = i
- while (j < s.length && isAlpha(s.charAt(j))) {
- j += 1
- }
- if (j > i) {
- buf += s.substring(i, j)
- }
- i = j
- while (i < s.length && !isAlpha(s.charAt(i))) {
- i += 1
- }
- }
- buf
- }
-
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
@@ -347,7 +321,8 @@ private[spark] object Utils extends Logging {
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) {
+ def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager,
+ hadoopConf: Configuration) {
val filename = url.split("/").last
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
@@ -419,7 +394,7 @@ private[spark] object Utils extends Logging {
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val fs = getHadoopFileSystem(uri)
+ val fs = getHadoopFileSystem(uri, hadoopConf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
@@ -449,12 +424,71 @@ private[spark] object Utils extends Logging {
}
/**
- * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
- * return a single directory, even though the spark.local.dir property might be a list of
- * multiple paths.
+ * Get the path of a temporary directory. Spark's local directories can be configured through
+ * multiple settings, which are used with the following precedence:
+ *
+ * - If called from inside of a YARN container, this will return a directory chosen by YARN.
+ * - If the SPARK_LOCAL_DIRS environment variable is set, this will return a directory from it.
+ * - Otherwise, if the spark.local.dir is set, this will return a directory from it.
+ * - Otherwise, this will return java.io.tmpdir.
+ *
+ * Some of these configuration options might be lists of multiple paths, but this method will
+ * always return a single directory.
*/
def getLocalDir(conf: SparkConf): String = {
- conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ getOrCreateLocalRootDirs(conf)(0)
+ }
+
+ private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = {
+ // These environment variables are set by YARN.
+ // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs())
+ // For Hadoop 2.X, we check for CONTAINER_ID.
+ conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null
+ }
+
+ /**
+ * Gets or creates the directories listed in spark.local.dir or SPARK_LOCAL_DIRS,
+ * and returns only the directories that exist / could be created.
+ *
+ * If no directories could be created, this will return an empty list.
+ */
+ private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = {
+ val confValue = if (isRunningInYarnContainer(conf)) {
+ // If we are in yarn mode, systems can have different disk layouts so we must set it
+ // to what Yarn on this system said was available.
+ getYarnLocalDirs(conf)
+ } else {
+ Option(conf.getenv("SPARK_LOCAL_DIRS")).getOrElse(
+ conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ }
+ val rootDirs = confValue.split(',')
+ logDebug(s"Getting/creating local root dirs at '$confValue'")
+
+ rootDirs.flatMap { rootDir =>
+ val localDir: File = new File(rootDir)
+ val foundLocalDir = localDir.exists || localDir.mkdirs()
+ if (!foundLocalDir) {
+ logError(s"Failed to create local root dir in $rootDir. Ignoring this directory.")
+ None
+ } else {
+ Some(rootDir)
+ }
+ }
+ }
+
+ /** Get the Yarn approved local directories. */
+ private def getYarnLocalDirs(conf: SparkConf): String = {
+ // Hadoop 0.23 and 2.x use different environment variable names for the
+ // local dirs, so check both. We assume one of the two is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS"))
+ .getOrElse(Option(conf.getenv("LOCAL_DIRS"))
+ .getOrElse(""))
+
+ if (localDirs.isEmpty) {
+ throw new Exception("Yarn Local dirs can't be empty")
+ }
+ localDirs
}
/**
@@ -496,7 +530,12 @@ private[spark] object Utils extends Logging {
if (address.isLoopbackAddress) {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
- for (ni <- NetworkInterface.getNetworkInterfaces) {
+ // getNetworkInterfaces returns interfaces in reverse order compared to ifconfig output order
+ // on unix-like systems. On Windows, it returns them in index order.
+ // It is more correct to pick the IP address following the system output order.
+ val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList
+ val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse
+ for (ni <- reOrderedNetworkIFs) {
for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
!addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
@@ -771,14 +810,6 @@ private[spark] object Utils extends Logging {
}
}
- /**
- * Execute a command in the current working directory, throwing an exception if it completes
- * with an exit code other than 0.
- */
- def execute(command: Seq[String]) {
- execute(command, new File("."))
- }
-
/**
* Execute a command and get its output, throwing an exception if it yields a code other than 0.
*/
@@ -810,6 +841,7 @@ private[spark] object Utils extends Logging {
val exitCode = process.waitFor()
stdoutThread.join() // Wait for it to finish reading output
if (exitCode != 0) {
+ logError(s"Process $command exited with code $exitCode: ${output}")
throw new SparkException("Process " + command + " exited with code " + exitCode)
}
output.toString
@@ -840,8 +872,8 @@ private[spark] object Utils extends Logging {
*/
def getCallSite: CallSite = {
val trace = Thread.currentThread.getStackTrace()
- .filterNot { ste:StackTraceElement =>
- // When running under some profilers, the current stack trace might contain some bogus
+ .filterNot { ste:StackTraceElement =>
+ // When running under some profilers, the current stack trace might contain some bogus
// frames. This is intended to ensure that we don't crash in these situations by
// ignoring any frames that we can't examine.
(ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace"))
@@ -1157,15 +1189,15 @@ private[spark] object Utils extends Logging {
/**
* Return a Hadoop FileSystem with the scheme encoded in the given path.
*/
- def getHadoopFileSystem(path: URI): FileSystem = {
- FileSystem.get(path, SparkHadoopUtil.get.newConfiguration())
+ def getHadoopFileSystem(path: URI, conf: Configuration): FileSystem = {
+ FileSystem.get(path, conf)
}
/**
* Return a Hadoop FileSystem with the scheme encoded in the given path.
*/
- def getHadoopFileSystem(path: String): FileSystem = {
- getHadoopFileSystem(new URI(path))
+ def getHadoopFileSystem(path: String, conf: Configuration): FileSystem = {
+ getHadoopFileSystem(new URI(path), conf)
}
/**
@@ -1242,7 +1274,7 @@ private[spark] object Utils extends Logging {
}
}
- /**
+ /**
* Execute the given block, logging and re-throwing any uncaught exception.
* This is particularly useful for wrapping code that runs in a thread, to ensure
* that exceptions are printed, and to avoid having to catch Throwable.
@@ -1350,15 +1382,15 @@ private[spark] object Utils extends Logging {
}
/**
- * Default number of retries in binding to a port.
+ * Default maximum number of retries when binding to a port before giving up.
*/
val portMaxRetries: Int = {
if (sys.props.contains("spark.testing")) {
// Set a higher number of retries for tests...
- sys.props.get("spark.ports.maxRetries").map(_.toInt).getOrElse(100)
+ sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100)
} else {
Option(SparkEnv.get)
- .flatMap(_.conf.getOption("spark.ports.maxRetries"))
+ .flatMap(_.conf.getOption("spark.port.maxRetries"))
.map(_.toInt)
.getOrElse(16)
}
@@ -1420,4 +1452,39 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Configure log4j properties for use by the test suites
+ */
+ def configTestLog4j(level: String): Unit = {
+ val pro = new Properties()
+ pro.put("log4j.rootLogger", s"$level, console")
+ pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
+ pro.put("log4j.appender.console.target", "System.err")
+ pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
+ pro.put("log4j.appender.console.layout.ConversionPattern",
+ "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
+ PropertyConfigurator.configure(pro)
+ }
+
+}
+
+/**
+ * A utility class to redirect the child process's stdout or stderr.
+ */
+private[spark] class RedirectThread(in: InputStream, out: OutputStream, name: String)
+ extends Thread(name) {
+
+ setDaemon(true)
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
+ val buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ out.write(buf, 0, len)
+ out.flush()
+ len = in.read(buf)
+ }
+ }
+ }
}
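
getOrCreateLocalRootDirs resolves local directories with the precedence documented above (YARN container dirs, then SPARK_LOCAL_DIRS, then spark.local.dir, then java.io.tmpdir). A simplified sketch of that precedence using a plain environment Map and an optional spark.local.dir value instead of SparkConf:

    object LocalDirsSketch {
      def resolveLocalDirs(env: Map[String, String], sparkLocalDir: Option[String]): Seq[String] = {
        val confValue =
          if (env.contains("CONTAINER_ID") || env.contains("YARN_LOCAL_DIRS")) {
            // Running inside a YARN container: use whatever directories YARN handed us.
            env.getOrElse("YARN_LOCAL_DIRS", env.getOrElse("LOCAL_DIRS", ""))
          } else {
            env.getOrElse("SPARK_LOCAL_DIRS",
              sparkLocalDir.getOrElse(System.getProperty("java.io.tmpdir")))
          }
        confValue.split(',').toSeq
      }
    }
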
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 9f85b94a70800..8a015c1d26a96 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -413,7 +413,12 @@ class ExternalAppendOnlyMap[K, V, C](
extends Iterator[(K, C)]
{
private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
- assert(file.length() == batchOffsets(batchOffsets.length - 1))
+ assert(file.length() == batchOffsets.last,
+ "File length is not equal to the last batch offset:\n" +
+ s" file length = ${file.length}\n" +
+ s" last batch offset = ${batchOffsets.last}\n" +
+ s" all batch offsets = ${batchOffsets.mkString(",")}"
+ )
private var batchIndex = 0 // Which batch we're in
private var fileStream: FileInputStream = null
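
The strengthened assertion above relies on batchOffsets being the running totals of batchSizes,
so its last element must equal the spill file's length. A small self-contained sketch of that
invariant (not Spark code):

    val batchSizes = Seq(10L, 0L, 25L)
    val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Seq(0, 10, 10, 35)
    assert(batchOffsets.length == batchSizes.length + 1)
    assert(batchOffsets.last == batchSizes.sum)        // must match the file length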
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 5d8a648d9551e..782b979e2e93d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -719,20 +719,20 @@ private[spark] class ExternalSorter[K, V, C](
def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
/**
- * Write all the data added into this ExternalSorter into a file in the disk store, creating
- * an .index file for it as well with the offsets of each partition. This is called by the
- * SortShuffleWriter and can go through an efficient path of just concatenating binary files
- * if we decided to avoid merge-sorting.
+ * Write all the data added into this ExternalSorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
- def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
- val outputFile = blockManager.diskBlockManager.getFile(blockId)
+ def writePartitionedFile(
+ blockId: BlockId,
+ context: TaskContext,
+ outputFile: File): Array[Long] = {
// Track location of each range in the output file
- val offsets = new Array[Long](numPartitions + 1)
val lengths = new Array[Long](numPartitions)
if (bypassMergeSort && partitionWriters != null) {
@@ -750,7 +750,6 @@ private[spark] class ExternalSorter[K, V, C](
in.close()
in = null
lengths(i) = size
- offsets(i + 1) = offsets(i) + lengths(i)
}
} finally {
if (out != null) {
@@ -772,11 +771,7 @@ private[spark] class ExternalSorter[K, V, C](
}
writer.commitAndClose()
val segment = writer.fileSegment()
- offsets(id + 1) = segment.offset + segment.length
lengths(id) = segment.length
- } else {
- // The partition is empty; don't create a new writer to avoid writing headers, etc
- offsets(id + 1) = offsets(id)
}
}
}
@@ -784,23 +779,6 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
- // Write an index file with the offsets of each block, plus a final offset at the end for the
- // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
- // out where each block begins and ends.
-
- val diskBlockManager = blockManager.diskBlockManager
- val indexFile = diskBlockManager.getFile(blockId.name + ".index")
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
- try {
- var i = 0
- while (i < numPartitions + 1) {
- out.writeLong(offsets(i))
- i += 1
- }
- } finally {
- out.close()
- }
-
lengths
}
@@ -811,7 +789,7 @@ private[spark] class ExternalSorter[K, V, C](
if (writer.isOpen) {
writer.commitAndClose()
}
- blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+ blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
}
def stop(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
new file mode 100644
index 0000000000000..daac6f971eb20
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * An OutputStream that writes to fixed-size chunks of byte arrays.
+ *
+ * @param chunkSize size of each chunk, in bytes.
+ */
+private[spark]
+class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
+
+ private val chunks = new ArrayBuffer[Array[Byte]]
+
+ /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
+ private var lastChunkIndex = -1
+
+ /**
+ * Next position to write in the last chunk.
+ *
+ * If this equals chunkSize, a new chunk will be allocated before the next write.
+ * This can never be 0.
+ */
+ private var position = chunkSize
+
+ override def write(b: Int): Unit = {
+ allocateNewChunkIfNeeded()
+ chunks(lastChunkIndex)(position) = b.toByte
+ position += 1
+ }
+
+ override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+ var written = 0
+ while (written < len) {
+ allocateNewChunkIfNeeded()
+ val thisBatch = math.min(chunkSize - position, len - written)
+ System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
+ written += thisBatch
+ position += thisBatch
+ }
+ }
+
+ @inline
+ private def allocateNewChunkIfNeeded(): Unit = {
+ if (position == chunkSize) {
+ chunks += new Array[Byte](chunkSize)
+ lastChunkIndex += 1
+ position = 0
+ }
+ }
+
+ def toArrays: Array[Array[Byte]] = {
+ if (lastChunkIndex == -1) {
+ new Array[Array[Byte]](0)
+ } else {
+ // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
+ // An alternative would have been returning an array of ByteBuffers, with the last buffer
+ // bounded to only the last chunk's position. However, given our use case in Spark (to put
+ // the chunks in block manager), only limiting the view bound of the buffer would still
+ // require the block manager to store the whole chunk.
+ val ret = new Array[Array[Byte]](chunks.size)
+ for (i <- 0 until chunks.size - 1) {
+ ret(i) = chunks(i)
+ }
+ if (position == chunkSize) {
+ ret(lastChunkIndex) = chunks(lastChunkIndex)
+ } else {
+ ret(lastChunkIndex) = new Array[Byte](position)
+ System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
+ }
+ ret
+ }
+ }
+}
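
A short usage sketch for the new ByteArrayChunkOutputStream (illustrative only; the class is
private[spark], so this assumes code in the org.apache.spark namespace, and the 4-byte chunk
size is chosen just to make the trimming behaviour of toArrays visible):

    val out = new ByteArrayChunkOutputStream(4)
    out.write(Array[Byte](1, 2, 3, 4, 5, 6), 0, 6)
    val chunks = out.toArrays
    // Two chunks: the first is a full 4-byte chunk, the second is trimmed to 2 bytes.
    assert(chunks.length == 2)
    assert(chunks(0).length == 4 && chunks(1).length == 2)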
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e1c13de04a0be..b8574dfb42e6b 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -29,19 +29,14 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.compress.DefaultCodec;
-import org.apache.hadoop.mapred.FileSplit;
-import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
-import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.hadoop.mapreduce.Job;
import org.junit.After;
import org.junit.Assert;
@@ -49,7 +44,6 @@
import org.junit.Test;
import org.apache.spark.api.java.JavaDoubleRDD;
-import org.apache.spark.api.java.JavaHadoopRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -189,6 +183,36 @@ public void sortByKey() {
Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2));
}
+ @Test
+ public void repartitionAndSortWithinPartitions() {
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2(0, 5));
+ pairs.add(new Tuple2(3, 8));
+ pairs.add(new Tuple2(2, 6));
+ pairs.add(new Tuple2(0, 8));
+ pairs.add(new Tuple2(3, 8));
+ pairs.add(new Tuple2(1, 3));
+
+ JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+
+ Partitioner partitioner = new Partitioner() {
+ public int numPartitions() {
+ return 2;
+ }
+ public int getPartition(Object key) {
+ return ((Integer)key).intValue() % 2;
+ }
+ };
+
+ JavaPairRDD<Integer, Integer> repartitioned =
+ rdd.repartitionAndSortWithinPartitions(partitioner);
+ List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
+ Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5),
+ new Tuple2(0, 8), new Tuple2(2, 6)));
+ Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3),
+ new Tuple2(3, 8), new Tuple2(3, 8)));
+ }
+
@Test
public void emptyRDD() {
JavaRDD<String> rdd = sc.emptyRDD();
@@ -1283,23 +1307,4 @@ public void collectUnderlyingScalaRDD() {
SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
Assert.assertEquals(data.size(), collected.length);
}
-
- public void getHadoopInputSplits() {
- String outDir = new File(tempDir, "output").getAbsolutePath();
- sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2).saveAsTextFile(outDir);
-
- JavaHadoopRDD<LongWritable, Text> hadoopRDD = (JavaHadoopRDD<LongWritable, Text>)
- sc.hadoopFile(outDir, TextInputFormat.class, LongWritable.class, Text.class);
- List<String> inputPaths = hadoopRDD.mapPartitionsWithInputSplit(
- new Function2<InputSplit, Iterator<Tuple2<LongWritable, Text>>, Iterator<String>>() {
- @Override
- public Iterator<String> call(InputSplit split, Iterator<Tuple2<LongWritable, Text>> it)
- throws Exception {
- FileSplit fileSplit = (FileSplit) split;
- return Lists.newArrayList(fileSplit.getPath().toUri().getPath()).iterator();
- }
- }, true).collect();
- Assert.assertEquals(Sets.newHashSet(inputPaths),
- Sets.newHashSet(outDir + "/part-00000", outDir + "/part-00001"));
- }
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 9c5f394d3899d..90dcadcffd091 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -32,6 +32,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
var split: Partition = _
/** An RDD which returns the values [1, 2, 3, 4]. */
var rdd: RDD[Int] = _
+ var rdd2: RDD[Int] = _
+ var rdd3: RDD[Int] = _
before {
sc = new SparkContext("local", "test")
@@ -43,6 +45,16 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
override val getDependencies = List[Dependency[_]]()
override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
}
+ rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) {
+ override def getPartitions: Array[Partition] = firstParent[Int].partitions
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[Int].iterator(split, context)
+ }.cache()
+ rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) {
+ override def getPartitions: Array[Partition] = firstParent[Int].partitions
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[Int].iterator(split, context)
+ }.cache()
}
after {
@@ -87,4 +99,11 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
assert(value.toList === List(1, 2, 3, 4))
}
}
+
+ test("verify task metrics updated correctly") {
+ cacheManager = sc.env.cacheManager
+ val context = new TaskContext(0, 0, 0)
+ cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
+ assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
+ }
}
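
A hedged sketch of the lineage the new test builds (assuming an active SparkContext `sc`): two
cached RDDs stacked on one another, so computing a single partition of the top RDD stores one
block per cached level and records two entries in the task's updatedBlocks metrics.

    val base = sc.parallelize(1 to 4)
    val mid  = base.map(identity).cache()
    val top  = mid.map(identity).cache()
    top.count()  // materializes both cached levels => two updated blocks per task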
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 4bc4346c0a288..2e3fc5ef0e336 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -21,7 +21,6 @@ import java.lang.ref.WeakReference
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
-import scala.language.postfixOps
import scala.util.Random
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha
.setMaster("local[2]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
.set("spark.shuffle.manager", shuffleManager.getName)
before {
@@ -243,6 +243,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
.setMaster("local-cluster[2, 1, 512]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
.set("spark.shuffle.manager", shuffleManager.getName)
sc = new SparkContext(conf2)
@@ -319,6 +320,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor
.setMaster("local-cluster[2, 1, 512]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
.set("spark.shuffle.manager", shuffleManager.getName)
sc = new SparkContext(conf2)
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 41c294f727b3c..81b64c36ddca1 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -24,8 +24,7 @@ import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
import org.apache.spark.SparkContext._
-import org.apache.spark.network.ConnectionManagerId
-import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
@@ -136,7 +135,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
sc.parallelize(1 to 10, 2).foreach { x => if (x == 1) System.exit(42) }
}
assert(thrown.getClass === classOf[SparkException])
- System.out.println(thrown.getMessage)
assert(thrown.getMessage.contains("failed 4 times"))
}
}
@@ -202,12 +200,13 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
- blockManager.master.getLocations(blockId).foreach(id => {
- val bytes = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(id.host, id.port))
- val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
+ val blockTransfer = SparkEnv.get.blockTransferService
+ blockManager.master.getLocations(blockId).foreach { cmId =>
+ val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString)
+ val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
+ .asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
- })
+ }
}
test("compute without caching when no partitions fit in memory") {
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index a73e1ef0288a5..5265ba904032f 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -19,9 +19,6 @@ package org.apache.spark
import java.io.File
-import org.apache.log4j.Logger
-import org.apache.log4j.Level
-
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.prop.TableDrivenPropertyChecks._
@@ -29,8 +26,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.Utils
-import scala.language.postfixOps
-
class DriverSuite extends FunSuite with Timeouts {
test("driver should exit after finishing") {
@@ -54,7 +49,7 @@ class DriverSuite extends FunSuite with Timeouts {
*/
object DriverWithoutCleanup {
def main(args: Array[String]) {
- Logger.getRootLogger().setLevel(Level.WARN)
+ Utils.configTestLog4j("INFO")
val sc = new SparkContext(args(0), "DriverWithoutCleanup")
sc.parallelize(1 to 100, 4).count()
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
similarity index 66%
rename from core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala
rename to core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
index f4261c13f70a8..2acc02a54fa3d 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala
+++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
@@ -15,17 +15,19 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty
+package org.apache.spark
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.string.StringEncoder
+import org.scalatest.BeforeAndAfterAll
+class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
-class FileClientChannelInitializer(handler: FileClientHandler)
- extends ChannelInitializer[SocketChannel] {
+ // This test suite should run all tests in ShuffleSuite with hash-based shuffle.
- def initChannel(channel: SocketChannel) {
- channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler)
+ override def beforeAll() {
+ System.setProperty("spark.shuffle.manager", "hash")
+ }
+
+ override def afterAll() {
+ System.clearProperty("spark.shuffle.manager")
}
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 9702838085627..5369169811f81 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -69,13 +69,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
- (BlockManagerId("b", "hostB", 1000, 0), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
+ (BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
}
@@ -86,9 +86,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
assert(tracker.getServerStatuses(10, 0).nonEmpty)
@@ -105,14 +105,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simultaneous fetch failures
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
@@ -145,13 +145,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
- masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
@@ -174,7 +174,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// Frame size should be ~123B, and no exception should be thrown
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
+ BlockManagerId("88", "mph", 1000), Array.fill[Byte](10)(0)))
masterActor.receive(GetMapOutputStatuses(10))
}
@@ -195,7 +195,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new MapStatus(
- BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
+ BlockManagerId("999", "mps", 1000), Array.fill[Byte](4000000)(0)))
}
intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index b13ddf96bc77c..15aa4d83800fa 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
-class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
+abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val conf = new SparkConf(loadDefaults = false)
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 5c02c00586ef4..639e56c488db4 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -24,8 +24,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
override def beforeAll() {
- System.setProperty("spark.shuffle.manager",
- "org.apache.spark.shuffle.sort.SortShuffleManager")
+ System.setProperty("spark.shuffle.manager", "sort")
}
override def afterAll() {
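
Both shuffle suites above now select the shuffle implementation by its short name ("hash" or
"sort") rather than the fully qualified class name. A hedged sketch of the equivalent
SparkConf-based configuration:

    import org.apache.spark.SparkConf

    // Short aliases resolve to the hash- and sort-based shuffle managers.
    val conf = new SparkConf().set("spark.shuffle.manager", "sort")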
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 17c64455b2429..978a6ded80829 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.broadcast
-import org.apache.spark.storage.{BroadcastBlockId, _}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage._
+
+
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
// Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
assert(bm.executorId === "", "Block should only be on the driver")
@@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}
if (distributed) {
// this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
}
}
// Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)
statuses.foreach { case (_, status) =>
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- assert(blockIds.size === 1)
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ val blockId = BroadcastBlockId(broadcastId)
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
val expectedNumBlocks = if (removeFromDriver) 0 else 1
val possiblyNot = if (removeFromDriver) "" else " not"
assert(statuses.size === expectedNumBlocks,
"Block should%s be unpersisted on the driver".format(possiblyNot))
if (distributed && removeFromDriver) {
// this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
"Broadcast file should%s be deleted".format(possiblyNot))
}
}
- testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
val numSlaves = if (distributed) 2 else 0
- def getBlockIds(id: Long) = {
- val broadcastBlockId = BroadcastBlockId(id)
- val metaBlockId = BroadcastBlockId(id, "meta")
- // Assume broadcast value is small enough to fit into 1 piece
- val pieceBlockId = BroadcastBlockId(id, "piece0")
- if (distributed) {
- // the metadata and piece blocks are generated only in distributed mode
- Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
- } else {
- Seq[BroadcastBlockId](broadcastBlockId)
- }
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === (if (distributed) 1 else 0))
}
- // Verify that blocks are persisted only on the driver
- def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.executorId === "", "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
}
- }
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (blockId.field == "meta") {
- // Meta data is only on the driver
- assert(statuses.size === 1)
- statuses.head match { case (bm, _) => assert(bm.executorId === "") }
- } else {
- // Other blocks are on both the executors and the driver
- assert(statuses.size === numSlaves + 1,
- blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (distributed) {
+ assert(statuses.size === numSlaves + 1)
+ } else {
+ assert(statuses.size === 0)
}
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- blockIds.foreach { blockId =>
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- }
+ def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+ var blockId = BroadcastBlockId(broadcastId)
+ var expectedNumBlocks = if (removeFromDriver) 0 else 1
+ var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
+
+ blockId = BroadcastBlockId(broadcastId, "piece0")
+ expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+ statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks)
}
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
broadcastConf: SparkConf,
- getBlockIds: Long => Seq[BroadcastBlockId],
- afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
- afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterCreation: (Long, BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
+ afterUnpersist: (Long, BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {
sc = if (distributed) {
@@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
// Create broadcast variable
val broadcast = sc.broadcast(list)
- val blocks = getBlockIds(broadcast.id)
- afterCreation(blocks, blockManagerMaster)
+ afterCreation(broadcast.id, blockManagerMaster)
// Use broadcast variable on all executors
val partitions = 10
assert(partitions > numSlaves)
val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
- afterUsingBroadcast(blocks, blockManagerMaster)
+ afterUsingBroadcast(broadcast.id, blockManagerMaster)
// Unpersist broadcast
if (removeFromDriver) {
@@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
} else {
broadcast.unpersist(blocking = true)
}
- afterUnpersist(blocks, blockManagerMaster)
+ afterUnpersist(broadcast.id, blockManagerMaster)
// If the broadcast is removed from driver, all subsequent uses of the broadcast variable
// should throw SparkExceptions. Otherwise, the result should be the same as before.
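
The refactored helpers above identify broadcast blocks directly from the broadcast id rather
than a precomputed Seq[BroadcastBlockId]. A small sketch of the two block ids they inspect per
broadcast (the rendered names are an assumption based on BroadcastBlockId's naming scheme):

    import org.apache.spark.storage.BroadcastBlockId

    val broadcastId = 0L
    val mainBlock  = BroadcastBlockId(broadcastId)            // e.g. "broadcast_0"
    val pieceBlock = BroadcastBlockId(broadcastId, "piece0")  // e.g. "broadcast_0_piece0"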
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 31aa7ec837f43..3f1cd0752e766 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -115,14 +115,16 @@ class JsonProtocolSuite extends FunSuite {
workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis
workerInfo
}
+
def createExecutorRunner(): ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
new File("sparkHome"), new File("workDir"), "akka://worker",
new SparkConf, ExecutorState.RUNNING)
}
+
def createDriverRunner(): DriverRunner = {
- new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(),
- null, "akka://worker")
+ new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"),
+ createDriverDesc(), null, "akka://worker")
}
def assertValidJson(json: JValue) {
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 7e1ef80c84561..0c324d8bdf6a4 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -154,6 +154,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
sysProps("spark.app.name") should be ("beauty")
sysProps("spark.shuffle.spill") should be ("false")
sysProps("SPARK_SUBMIT") should be ("true")
+ sysProps.keys should not contain ("spark.jars")
}
test("handles YARN client mode") {
@@ -317,6 +318,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
object JarCreationTest {
def main(args: Array[String]) {
+ Utils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
@@ -338,6 +340,7 @@ object JarCreationTest {
object SimpleApplicationTest {
def main(args: Array[String]) {
+ Utils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
val configs = Seq("spark.master", "spark.app.name")
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
index c930839b47f11..b6f4411e0587a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala
@@ -25,14 +25,15 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
import org.apache.spark.deploy.{Command, DriverDescription}
class DriverRunnerTest extends FunSuite {
private def createDriverRunner() = {
val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq())
val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command)
- new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription,
- null, "akka://1.2.3.4/worker/")
+ new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"),
+ driverDescription, null, "akka://1.2.3.4/worker/")
}
private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = {
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
new file mode 100644
index 0000000000000..02d0ffc86f58f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+import java.util.{Collections, HashSet}
+import java.util.concurrent.{TimeUnit, Semaphore}
+
+import scala.collection.JavaConversions._
+
+import io.netty.buffer.{ByteBufUtil, Unpooled}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
+import org.apache.spark.network.netty.server.BlockServer
+import org.apache.spark.storage.{FileSegment, BlockDataProvider}
+
+
+/**
+ * Test suite that makes sure the server and the client implementations share the same protocol.
+ */
+class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
+
+ val bufSize = 100000
+ var buf: ByteBuffer = _
+ var testFile: File = _
+ var server: BlockServer = _
+ var clientFactory: BlockFetchingClientFactory = _
+
+ val bufferBlockId = "buffer_block"
+ val fileBlockId = "file_block"
+
+ val fileContent = new Array[Byte](1024)
+ scala.util.Random.nextBytes(fileContent)
+
+ override def beforeAll() = {
+ buf = ByteBuffer.allocate(bufSize)
+ for (i <- 1 to bufSize) {
+ buf.put(i.toByte)
+ }
+ buf.flip()
+
+ testFile = File.createTempFile("netty-test-file", "txt")
+ val fp = new RandomAccessFile(testFile, "rw")
+ fp.write(fileContent)
+ fp.close()
+
+ server = new BlockServer(new SparkConf, new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ if (blockId == bufferBlockId) {
+ Right(buf)
+ } else if (blockId == fileBlockId) {
+ Left(new FileSegment(testFile, 10, testFile.length - 25))
+ } else {
+ throw new Exception("Unknown block id " + blockId)
+ }
+ }
+ })
+
+ clientFactory = new BlockFetchingClientFactory(new SparkConf)
+ }
+
+ override def afterAll() = {
+ server.stop()
+ clientFactory.stop()
+ }
+
+ /** A ByteBuf for buffer_block */
+ lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
+
+ /** A ByteBuf for file_block */
+ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
+
+ def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
+ {
+ val client = clientFactory.createClient(server.hostName, server.port)
+ val sem = new Semaphore(0)
+ val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
+ val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
+ val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
+
+ client.fetchBlocks(
+ blockIds,
+ new BlockClientListener {
+ override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+ errorBlockIds.add(blockId)
+ sem.release()
+ }
+
+ override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+ receivedBlockIds.add(blockId)
+ data.retain()
+ receivedBuffers.add(data)
+ sem.release()
+ }
+ }
+ )
+ if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server")
+ }
+ client.close()
+ (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
+ }
+
+ test("fetch a ByteBuffer block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
+ assert(blockIds === Set(bufferBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch a FileSegment block via zero-copy send") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
+ assert(blockIds === Set(fileBlockId))
+ assert(buffers.map(_.underlying) === Set(fileBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch a non-existent block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
+ assert(blockIds.isEmpty)
+ assert(buffers.isEmpty)
+ assert(failBlockIds === Set("random-block"))
+ }
+
+ test("fetch both ByteBuffer block and FileSegment block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
+ assert(blockIds === Set(bufferBlockId, fileBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
+ assert(failBlockIds.isEmpty)
+ buffers.foreach(_.release())
+ }
+
+ test("fetch both ByteBuffer block and a non-existent block") {
+ val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
+ assert(blockIds === Set(bufferBlockId))
+ assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+ assert(failBlockIds === Set("random-block"))
+ buffers.foreach(_.release())
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
new file mode 100644
index 0000000000000..903ab09ae4322
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.nio.ByteBuffer
+
+import io.netty.buffer.Unpooled
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.{PrivateMethodTester, FunSuite}
+
+
+class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester {
+
+ test("handling block data (successful fetch)") {
+ val blockId = "test_block"
+ val blockData = "blahblahblahblahblah"
+ val totalLength = 4 + blockId.length + blockData.length
+
+ var parsedBlockId: String = ""
+ var parsedBlockData: String = ""
+ val handler = new BlockFetchingClientHandler
+ handler.addRequest(blockId,
+ new BlockClientListener {
+ override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
+ override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
+ parsedBlockId = bid
+ val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
+ refCntBuf.byteBuffer().get(bytes)
+ parsedBlockData = new String(bytes)
+ }
+ }
+ )
+
+ val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
+ assert(handler.invokePrivate(outstandingRequests()).size === 1)
+
+ val channel = new EmbeddedChannel(handler)
+ val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
+ buf.putInt(totalLength)
+ buf.putInt(blockId.length)
+ buf.put(blockId.getBytes)
+ buf.put(blockData.getBytes)
+ buf.flip()
+
+ channel.writeInbound(Unpooled.wrappedBuffer(buf))
+ assert(parsedBlockId === blockId)
+ assert(parsedBlockData === blockData)
+
+ assert(handler.invokePrivate(outstandingRequests()).size === 0)
+
+ channel.close()
+ }
+
+ test("handling error message (failed fetch)") {
+ val blockId = "test_block"
+ val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
+ val totalLength = 4 + blockId.length + errorMsg.length
+
+ var parsedBlockId: String = ""
+ var parsedErrorMsg: String = ""
+ val handler = new BlockFetchingClientHandler
+ handler.addRequest(blockId, new BlockClientListener {
+ override def onFetchFailure(bid: String, msg: String) = {
+ parsedBlockId = bid
+ parsedErrorMsg = msg
+ }
+ override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
+ })
+
+ val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
+ assert(handler.invokePrivate(outstandingRequests()).size === 1)
+
+ val channel = new EmbeddedChannel(handler)
+ val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
+ buf.putInt(totalLength)
+ buf.putInt(-blockId.length)
+ buf.put(blockId.getBytes)
+ buf.put(errorMsg.getBytes)
+ buf.flip()
+
+ channel.writeInbound(Unpooled.wrappedBuffer(buf))
+ assert(parsedBlockId === blockId)
+ assert(parsedErrorMsg === errorMsg)
+
+ assert(handler.invokePrivate(outstandingRequests()).size === 0)
+
+ channel.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
new file mode 100644
index 0000000000000..3ee281cb1350b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+
+class BlockHeaderEncoderSuite extends FunSuite {
+
+ test("encode normal block data") {
+ val blockId = "test_block"
+ val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+ channel.writeOutbound(new BlockHeader(17, blockId, None))
+ val out = channel.readOutbound().asInstanceOf[ByteBuf]
+ assert(out.readInt() === 4 + blockId.length + 17)
+ assert(out.readInt() === blockId.length)
+
+ val blockIdBytes = new Array[Byte](blockId.length)
+ out.readBytes(blockIdBytes)
+ assert(new String(blockIdBytes) === blockId)
+ assert(out.readableBytes() === 0)
+
+ channel.close()
+ }
+
+ test("encode error message") {
+ val blockId = "error_block"
+ val errorMsg = "error encountered"
+ val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+ channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
+ val out = channel.readOutbound().asInstanceOf[ByteBuf]
+ assert(out.readInt() === 4 + blockId.length + errorMsg.length)
+ assert(out.readInt() === -blockId.length)
+
+ val blockIdBytes = new Array[Byte](blockId.length)
+ out.readBytes(blockIdBytes)
+ assert(new String(blockIdBytes) === blockId)
+
+ val errorMsgBytes = new Array[Byte](errorMsg.length)
+ out.readBytes(errorMsgBytes)
+ assert(new String(errorMsgBytes) === errorMsg)
+ assert(out.readableBytes() === 0)
+
+ channel.close()
+ }
+}
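
The encoder test above and the client handler test earlier exercise the same frame layout: a
4-byte frame length (4 + block id length + payload length), a 4-byte block id length that is
negated for error frames, the block id bytes, and then either the block data or the error
message. A standalone sketch of that header layout (not a Spark API; for illustration only):

    import java.nio.ByteBuffer

    def encodeHeader(blockId: String, blockSize: Int, error: Option[String]): ByteBuffer = {
      val payloadLen = error.map(_.length).getOrElse(blockSize)
      val buf = ByteBuffer.allocate(8 + blockId.length + error.map(_.length).getOrElse(0))
      buf.putInt(4 + blockId.length + payloadLen)                           // frame length
      buf.putInt(if (error.isDefined) -blockId.length else blockId.length)  // negative => error
      buf.put(blockId.getBytes)
      error.foreach(msg => buf.put(msg.getBytes))                           // error text in header
      buf.flip()
      buf
    }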
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
new file mode 100644
index 0000000000000..3239c710f1639
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+
+import io.netty.buffer.{Unpooled, ByteBuf}
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.storage.{BlockDataProvider, FileSegment}
+
+
+class BlockServerHandlerSuite extends FunSuite {
+
+ test("ByteBuffer block") {
+ val expectedBlockId = "test_bytebuffer_block"
+ val buf = ByteBuffer.allocate(10000)
+ for (i <- 1 to 10000) {
+ buf.put(i.toByte)
+ }
+ buf.flip()
+
+ val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
+ }))
+
+ channel.writeInbound(expectedBlockId)
+ assert(channel.outboundMessages().size === 2)
+
+ val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+ val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
+
+ assert(out1.blockId === expectedBlockId)
+ assert(out1.blockSize === buf.remaining)
+ assert(out1.error === None)
+
+ assert(out2.equals(Unpooled.wrappedBuffer(buf)))
+
+ channel.close()
+ }
+
+ test("FileSegment block via zero-copy") {
+ val expectedBlockId = "test_file_block"
+
+ // Create random file data
+ val fileContent = new Array[Byte](1024)
+ scala.util.Random.nextBytes(fileContent)
+ val testFile = File.createTempFile("netty-test-file", "txt")
+ val fp = new RandomAccessFile(testFile, "rw")
+ fp.write(fileContent)
+ fp.close()
+
+ val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+ Left(new FileSegment(testFile, 15, testFile.length - 25))
+ }
+ }))
+
+ channel.writeInbound(expectedBlockId)
+ assert(channel.outboundMessages().size === 2)
+
+ val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+ val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
+
+ assert(out1.blockId === expectedBlockId)
+ assert(out1.blockSize === testFile.length - 25)
+ assert(out1.error === None)
+
+ assert(out2.count === testFile.length - 25)
+ assert(out2.position === 15)
+ }
+
+ test("pipeline exception propagation") {
+ val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
+ override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
+ })
+ val exceptionHandler = new SimpleChannelInboundHandler[String]() {
+ override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
+ throw new Exception("this is an error")
+ }
+ }
+
+ val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
+ assert(channel.isOpen)
+ channel.writeInbound("a message to trigger the error")
+ assert(!channel.isOpen)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
similarity index 97%
rename from core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
rename to core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index e2f4d4c57cdb5..9f49587cdc670 100644
--- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -15,23 +15,18 @@
* limitations under the License.
*/
-package org.apache.spark.network
+package org.apache.spark.network.nio
import java.io.IOException
import java.nio._
-import java.util.concurrent.TimeoutException
-import org.apache.spark.{SecurityManager, SparkConf}
-import org.scalatest.FunSuite
-
-import org.mockito.Mockito._
-import org.mockito.Matchers._
-
-import scala.concurrent.TimeoutException
-import scala.concurrent.{Await, TimeoutException}
import scala.concurrent.duration._
+import scala.concurrent.{Await, TimeoutException}
import scala.language.postfixOps
-import scala.util.{Failure, Success, Try}
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.{SecurityManager, SparkConf}
/**
* Test the ConnectionManager with various security settings.
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 28197657e9bad..3b833f2e41867 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -22,7 +22,6 @@ import java.util.concurrent.Semaphore
import scala.concurrent.{Await, TimeoutException}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
-import scala.language.postfixOps
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.concurrent.Timeouts
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
index 956c2b9cbd321..8408d7e785c65 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
@@ -38,9 +38,7 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
Iterator()
}
}
- val prunedRDD = PartitionPruningRDD.create(rdd, {
- x => if (x == 2) true else false
- })
+ val prunedRDD = PartitionPruningRDD.create(rdd, _ == 2)
assert(prunedRDD.partitions.length == 1)
val p = prunedRDD.partitions(0)
assert(p.index == 0)
@@ -62,13 +60,10 @@ class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
List(split.asInstanceOf[TestPartition].testValue).iterator
}
}
- val prunedRDD1 = PartitionPruningRDD.create(rdd, {
- x => if (x == 0) true else false
- })
+ val prunedRDD1 = PartitionPruningRDD.create(rdd, _ == 0)
- val prunedRDD2 = PartitionPruningRDD.create(rdd, {
- x => if (x == 2) true else false
- })
+
+ val prunedRDD2 = PartitionPruningRDD.create(rdd, _ == 2)
val merged = prunedRDD1 ++ prunedRDD2
assert(merged.count() == 2)
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 926d4fecb5b91..c1b501a75c8b8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -521,6 +521,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedLowerK === Array(1, 2, 3, 4, 5))
}
+ test("takeOrdered with limit 0") {
+ val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ val rdd = sc.makeRDD(nums, 2)
+ val sortedLowerK = rdd.takeOrdered(0)
+ assert(sortedLowerK.size === 0)
+ }
+
test("takeOrdered with custom ordering") {
val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
implicit val ord = implicitly[Ordering[Int]].reverse
@@ -675,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
}
+ test("repartitionAndSortWithinPartitions") {
+ val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)
+
+ val partitioner = new Partitioner {
+ def numPartitions: Int = 2
+ def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
+ }
+
+ val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
+ val partitions = repartitioned.glom().collect()
+ assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
+ assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
+ }
+
test("intersection") {
val all = sc.parallelize(1 to 10)
val evens = sc.parallelize(2 to 10 by 2)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index bd829752eb401..aa73469b6acd8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.{HashSet, HashMap, Map}
+import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
import scala.language.reflectiveCalls
import akka.actor._
@@ -27,6 +27,7 @@ import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
@@ -97,10 +98,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val sparkListener = new SparkListener() {
- val successfulStages = new HashSet[Int]()
- val failedStages = new HashSet[Int]()
+ val successfulStages = new HashSet[Int]
+ val failedStages = new ArrayBuffer[Int]
+ val stageByOrderOfExecution = new ArrayBuffer[Int]
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
+ stageByOrderOfExecution += stageInfo.stageId
if (stageInfo.failureReason.isEmpty) {
successfulStages += stageInfo.stageId
} else {
@@ -120,7 +123,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
*/
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
// stub out BlockManagerMaster.getLocations to use our cacheLocations
- val blockManagerMaster = new BlockManagerMaster(null, conf) {
+ val blockManagerMaster = new BlockManagerMaster(null, conf, true) {
override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map {
_.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
@@ -231,6 +234,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runEvent(JobCancelled(jobId))
}
+ test("[SPARK-3353] parent stage should have lower stage id") {
+ sparkListener.stageByOrderOfExecution.clear()
+ sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
+ assert(sparkListener.stageByOrderOfExecution.length === 2)
+ assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1))
+ }
+
test("zero split job") {
var numResults = 0
val fakeListener = new JobListener() {
@@ -435,6 +445,43 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assertDataStructuresEmpty
}
+ test("trivial shuffle with multiple fetch failures") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 1)),
+ (Success, makeMapStatus("hostB", 1))))
+ // The MapOutputTracker should know about both map output locations.
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+ Array("hostA", "hostB"))
+
+ // The first result task fails, with a fetch failure for the output from the first mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0),
+ null,
+ Map[Long, Any](),
+ null,
+ null))
+ assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ assert(sparkListener.failedStages.contains(1))
+
+ // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1),
+ null,
+ Map[Long, Any](),
+ null,
+ null))
+ // The SparkListener should not receive redundant failure events.
+ assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ assert(sparkListener.failedStages.size == 1)
+ }
+
test("ignore late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -478,8 +525,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// Listener bus should get told about the map stage failing, but not the reduce stage
// (since the reduce stage hasn't been started yet).
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
- assert(sparkListener.failedStages.contains(1))
- assert(sparkListener.failedStages.size === 1)
+ assert(sparkListener.failedStages.toSet === Set(0))
assertDataStructuresEmpty
}
@@ -526,14 +572,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
val stageFailureMessage = "Exception failure in map stage"
failed(taskSets(0), stageFailureMessage)
- assert(cancelledStages.contains(1))
+ assert(cancelledStages.toSet === Set(0, 2))
// Make sure the listeners got told about both failed stages.
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
assert(sparkListener.successfulStages.isEmpty)
- assert(sparkListener.failedStages.contains(1))
- assert(sparkListener.failedStages.contains(3))
- assert(sparkListener.failedStages.size === 2)
+ assert(sparkListener.failedStages.toSet === Set(0, 2))
assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
@@ -699,7 +743,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
private def makeBlockManagerId(host: String): BlockManagerId =
- BlockManagerId("exec-" + host, host, 12345, 0)
+ BlockManagerId("exec-" + host, host, 12345)
private def assertDataStructuresEmpty = {
assert(scheduler.activeJobs.isEmpty)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 10d8b299317ea..e5315bc93e217 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -26,7 +26,9 @@ import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.SPARK_VERSION
import org.apache.spark.util.{JsonProtocol, Utils}
import java.io.File
@@ -39,7 +41,8 @@ import java.io.File
* read and deserialized into actual SparkListenerEvents.
*/
class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
- private val fileSystem = Utils.getHadoopFileSystem("/")
+ private val fileSystem = Utils.getHadoopFileSystem("/",
+ SparkHadoopUtil.get.newConfiguration(new SparkConf()))
private val allCompressionCodecs = Seq[String](
"org.apache.spark.io.LZFCompressionCodec",
"org.apache.spark.io.SnappyCompressionCodec"
@@ -194,7 +197,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
def assertInfoCorrect(info: EventLoggingInfo, loggerStopped: Boolean) {
assert(info.logPaths.size > 0)
- assert(info.sparkVersion === SparkContext.SPARK_VERSION)
+ assert(info.sparkVersion === SPARK_VERSION)
assert(info.compressionCodec.isDefined === compressionCodec.isDefined)
info.compressionCodec.foreach { codec =>
assert(compressionCodec.isDefined)
@@ -227,7 +230,8 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
val conf = getLoggingConf(logDirPath, compressionCodec)
val eventLogger = new EventLoggingListener("test", conf)
val listenerBus = new LiveListenerBus
- val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey")
+ val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
+ 125L, "Mickey")
val applicationEnd = SparkListenerApplicationEnd(1000L)
// A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite
@@ -378,7 +382,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
private def assertSparkVersionIsValid(logFiles: Array[FileStatus]) {
val file = logFiles.map(_.getPath.getName).find(EventLoggingListener.isSparkVersionFile)
assert(file.isDefined)
- assert(EventLoggingListener.parseSparkVersion(file.get) === SparkContext.SPARK_VERSION)
+ assert(EventLoggingListener.parseSparkVersion(file.get) === SPARK_VERSION)
}
private def assertCompressionCodecIsValid(logFiles: Array[FileStatus], compressionCodec: String) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index d81499ac6abef..7ab351d1b4d24 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{JsonProtocol, Utils}
@@ -32,11 +33,9 @@ import org.apache.spark.util.{JsonProtocol, Utils}
* Test whether ReplayListenerBus replays events from logs correctly.
*/
class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
- private val fileSystem = Utils.getHadoopFileSystem("/")
- private val allCompressionCodecs = Seq[String](
- "org.apache.spark.io.LZFCompressionCodec",
- "org.apache.spark.io.SnappyCompressionCodec"
- )
+ private val fileSystem = Utils.getHadoopFileSystem("/",
+ SparkHadoopUtil.get.newConfiguration(new SparkConf()))
+ private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS
private var testDir: File = _
before {
@@ -84,7 +83,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
val fstream = fileSystem.create(logFilePath)
val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream)
val writer = new PrintWriter(cstream)
- val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey")
+ val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
+ 125L, "Mickey")
val applicationEnd = SparkListenerApplicationEnd(1000L)
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd))))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 3b0b8e2f68c97..ab35e8edc4ebf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -180,7 +180,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
rdd3.count()
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {2} // Shuffle map stage + result stage
- val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get
+ val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get
stageInfo3.rddInfos.size should be {1} // ShuffledRDD
stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index db2ad829a48f9..faba5508c906c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,16 +17,20 @@
package org.apache.spark.scheduler
+import org.mockito.Mockito._
+import org.mockito.Matchers.any
+
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
+
class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
- test("Calls executeOnCompleteCallbacks after failure") {
+ test("calls TaskCompletionListener after failure") {
TaskContextSuite.completed = false
sc = new SparkContext("local", "test")
val rdd = new RDD[String](sc, List()) {
@@ -45,6 +49,20 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
assert(TaskContextSuite.completed === true)
}
+
+ test("all TaskCompletionListeners should be called even if some fail") {
+ val context = new TaskContext(0, 0, 0)
+ val listener = mock(classOf[TaskCompletionListener])
+ context.addTaskCompletionListener(_ => throw new Exception("blah"))
+ context.addTaskCompletionListener(listener)
+ context.addTaskCompletionListener(_ => throw new Exception("blah"))
+
+ intercept[TaskCompletionListenerException] {
+ context.markTaskCompleted()
+ }
+
+ verify(listener, times(1)).onTaskCompletion(any())
+ }
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
new file mode 100644
index 0000000000000..ba47fe5e25b9b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{File, FileWriter}
+
+import scala.language.reflectiveCalls
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockManager
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
+
+class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
+ private val testConf = new SparkConf(false)
+
+ private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
+ assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
+ val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
+ assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
+ assert(expected.offset === segment.offset)
+ assert(expected.length === segment.length)
+ }
+
+ test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
+
+ val conf = new SparkConf(false)
+ // Reset after EACH object write. This ensures that there are bytes appended after an object
+ // is written, so if a codepath assumes writeObject marks the end of the data, this should
+ // flush those bugs out. This was a common bug in ExternalAppendOnlyMap, etc.
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
+ sc = new SparkContext("local", "test", conf)
+
+ val shuffleBlockManager =
+ SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager]
+
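+ // Write a couple of records for map task 1 of shuffle 1 (single bucket), then commit and close.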
+ val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf),
+ new ShuffleWriteMetrics)
+ for (writer <- shuffle1.writers) {
+ writer.write("test1")
+ writer.write("test2")
+ }
+ for (writer <- shuffle1.writers) {
+ writer.commitAndClose()
+ }
+
+ val shuffle1Segment = shuffle1.writers(0).fileSegment()
+ shuffle1.releaseWriters(success = true)
+
+ val shuffle2 = shuffleBlockManager.forMapTask(1, 2, 1, new JavaSerializer(conf),
+ new ShuffleWriteMetrics)
+
+ for (writer <- shuffle2.writers) {
+ writer.write("test3")
+ writer.write("test4")
+ }
+ for (writer <- shuffle2.writers) {
+ writer.commitAndClose()
+ }
+ val shuffle2Segment = shuffle2.writers(0).fileSegment()
+ shuffle2.releaseWriters(success = true)
+
+ // Now comes the test: write to shuffle 3 and close it, but before registering it, check that
+ // the file segment recorded for the previously committed task (shuffle2 here) is unchanged.
+ // Earlier, block lengths were inferred from the remaining data in the file, which could mess
+ // things up when there are concurrent reads and writes to the same shuffle group.
+
+ val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
+ new ShuffleWriteMetrics)
+ for (writer <- shuffle3.writers) {
+ writer.write("test3")
+ writer.write("test4")
+ }
+ for (writer <- shuffle3.writers) {
+ writer.commitAndClose()
+ }
+ // check before we register.
+ checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)))
+ shuffle3.releaseWriters(success = true)
+ checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)))
+ shuffleBlockManager.removeShuffle(1)
+ }
+
+ def writeToFile(file: File, numBytes: Int) {
+ val writer = new FileWriter(file, true)
+ for (i <- 0 until numBytes) writer.write(i)
+ writer.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
deleted file mode 100644
index bcbfe8baf36ad..0000000000000
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.io.IOException
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.future
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import org.scalatest.{FunSuite, Matchers}
-
-import org.mockito.Mockito._
-import org.mockito.Matchers.{any, eq => meq}
-import org.mockito.stubbing.Answer
-import org.mockito.invocation.InvocationOnMock
-
-import org.apache.spark.storage.BlockFetcherIterator._
-import org.apache.spark.network.{ConnectionManager, Message}
-import org.apache.spark.executor.ShuffleReadMetrics
-
-class BlockFetcherIteratorSuite extends FunSuite with Matchers {
-
- test("block fetch from local fails using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- doReturn(connManager).when(blockManager).connectionManager
- doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId
-
- doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
- val answer = new Answer[Option[Iterator[Any]]] {
- override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
- throw new Exception
- }
- }
-
- // 3rd block is going to fail
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
- doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
-
- val bmId = BlockManagerId("test-client", "test-client",1 , 0)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
- new ShuffleReadMetrics())
-
- iterator.initialize()
-
- // 3rd getLocalFromDisk invocation should be failed
- verify(blockManager, times(3)).getLocalFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully
- assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- // 3rd fetch should be failed
- assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator.
- // Otherwise, BasicBlockFetcherIterator hangs up.
- }
-
-
- test("block fetch from local succeed using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- doReturn(connManager).when(blockManager).connectionManager
- doReturn(BlockManagerId("test-client", "test-client", 1, 0)).when(blockManager).blockManagerId
-
- doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
-
- // All blocks should be fetched successfully
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
-
- val bmId = BlockManagerId("test-client", "test-client",1 , 0)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
- new ShuffleReadMetrics())
-
- iterator.initialize()
-
- // getLocalFromDis should be invoked for all of 5 blocks
- verify(blockManager, times(5)).getLocalFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
- assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
- }
-
- test("block fetch from remote fails using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- when(blockManager.connectionManager).thenReturn(connManager)
-
- val f = future {
- throw new IOException("Send failed or we received an error ACK")
- }
- when(connManager.sendMessageReliably(any(),
- any())).thenReturn(f)
- when(blockManager.futureExecContext).thenReturn(global)
-
- when(blockManager.blockManagerId).thenReturn(
- BlockManagerId("test-client", "test-client", 1, 0))
- when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
-
- val blId1 = ShuffleBlockId(0,0,0)
- val blId2 = ShuffleBlockId(0,1,0)
- val bmId = BlockManagerId("test-server", "test-server",1 , 0)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L)))
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager,
- blocksByAddress, null, new ShuffleReadMetrics())
-
- iterator.initialize()
- iterator.foreach{
- case (_, r) => {
- (!r.isDefined) should be(true)
- }
- }
- }
-
- test("block fetch from remote succeed using BasicBlockFetcherIterator") {
- val blockManager = mock(classOf[BlockManager])
- val connManager = mock(classOf[ConnectionManager])
- when(blockManager.connectionManager).thenReturn(connManager)
-
- val blId1 = ShuffleBlockId(0,0,0)
- val blId2 = ShuffleBlockId(0,1,0)
- val buf1 = ByteBuffer.allocate(4)
- val buf2 = ByteBuffer.allocate(4)
- buf1.putInt(1)
- buf1.flip()
- buf2.putInt(1)
- buf2.flip()
- val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
- val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
- val blockMessageArray = new BlockMessageArray(
- Seq(blockMessage1, blockMessage2))
-
- val bufferMessage = blockMessageArray.toBufferMessage
- val buffer = ByteBuffer.allocate(bufferMessage.size)
- val arrayBuffer = new ArrayBuffer[ByteBuffer]
- bufferMessage.buffers.foreach{ b =>
- buffer.put(b)
- }
- buffer.flip()
- arrayBuffer += buffer
-
- val f = future {
- Message.createBufferMessage(arrayBuffer)
- }
- when(connManager.sendMessageReliably(any(),
- any())).thenReturn(f)
- when(blockManager.futureExecContext).thenReturn(global)
-
- when(blockManager.blockManagerId).thenReturn(
- BlockManagerId("test-client", "test-client", 1, 0))
- when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
-
- val bmId = BlockManagerId("test-server", "test-server",1 , 0)
- val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L)))
- )
-
- val iterator = new BasicBlockFetcherIterator(blockManager,
- blocksByAddress, null, new ShuffleReadMetrics())
- iterator.initialize()
- iterator.foreach{
- case (_, r) => {
- (r.isDefined) should be(true)
- }
- }
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 20bac66105a69..e251660dae5de 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,15 +21,19 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
import java.util.concurrent.TimeUnit
+import org.apache.spark.network.nio.NioBlockTransferService
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
import akka.actor._
import akka.pattern.ask
import akka.util.Timeout
-import org.apache.spark.shuffle.hash.HashShuffleManager
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.Matchers.any
-import org.mockito.Mockito.{doAnswer, mock, spy, when}
-import org.mockito.stubbing.Answer
+import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
@@ -38,17 +42,12 @@ import org.scalatest.Matchers
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.network.{Message, ConnectionManagerId}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
-import scala.concurrent.duration._
-import scala.language.implicitConversions
-import scala.language.postfixOps
class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
with PrivateMethodTester {
@@ -73,8 +72,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = {
- new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr,
- mapOutputTracker, shuffleManager)
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer)
}
before {
@@ -92,7 +92,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf)
+ conf, true)
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
@@ -139,9 +139,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
}
test("BlockManagerId object caching") {
- val id1 = BlockManagerId("e1", "XXX", 1, 0)
- val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1
- val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object
+ val id1 = BlockManagerId("e1", "XXX", 1)
+ val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
assert(id2 === id1, "id2 is not same as id1")
assert(id2.eq(id1), "id2 is not the same object as id1")
assert(id3 != id1, "id3 is same as id1")
@@ -792,8 +792,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
+ val transfer = new NioBlockTransferService(conf, securityMgr)
store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
+ mapOutputTracker, shuffleManager, transfer)
// The put should fail since a1 is not serializable.
class UnserializableClass
@@ -823,12 +824,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
// be nice to refactor classes involved in disk storage in a way that
// allows for easier testing.
val blockManager = mock(classOf[BlockManager])
- val shuffleBlockManager = mock(classOf[ShuffleBlockManager])
- when(shuffleBlockManager.conf).thenReturn(conf)
- val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
- System.getProperty("java.io.tmpdir"))
-
when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString))
+ val diskBlockManager = new DiskBlockManager(blockManager, conf)
+
val diskStoreMapped = new DiskStore(blockManager, diskBlockManager)
diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY)
val mapped = diskStoreMapped.getBytes(blockId).get
@@ -1007,109 +1005,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
- test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
-
- val worker = spy(new BlockManagerWorker(store))
- val connManagerId = mock(classOf[ConnectionManagerId])
-
- // setup request block messages
- val reqBlId1 = ShuffleBlockId(0,0,0)
- val reqBlId2 = ShuffleBlockId(0,1,0)
- val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
- val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
- val reqBlockMessages = new BlockMessageArray(
- Seq(reqBlockMessage1, reqBlockMessage2))
- val reqBufferMessage = reqBlockMessages.toBufferMessage
-
- val answer = new Answer[Option[BlockMessage]] {
- override def answer(invocation: InvocationOnMock)
- :Option[BlockMessage]= {
- throw new Exception
- }
- }
-
- doAnswer(answer).when(worker).processBlockMessage(any())
-
- // Test when exception was thrown during processing block messages
- var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
-
- assert(ackMessage.isDefined, "When Exception was thrown in " +
- "BlockManagerWorker#processBlockMessage, " +
- "ackMessage should be defined")
- assert(ackMessage.get.hasError, "When Exception was thown in " +
- "BlockManagerWorker#processBlockMessage, " +
- "ackMessage should have error")
-
- val notBufferMessage = mock(classOf[Message])
-
- // Test when not BufferMessage was received
- ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId)
- assert(ackMessage.isDefined, "When not BufferMessage was passed to " +
- "BlockManagerWorker#onBlockMessageReceive, " +
- "ackMessage should be defined")
- assert(ackMessage.get.hasError, "When not BufferMessage was passed to " +
- "BlockManagerWorker#onBlockMessageReceive, " +
- "ackMessage should have error")
- }
-
- test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
-
- val worker = spy(new BlockManagerWorker(store))
- val connManagerId = mock(classOf[ConnectionManagerId])
-
- // setup request block messages
- val reqBlId1 = ShuffleBlockId(0,0,0)
- val reqBlId2 = ShuffleBlockId(0,1,0)
- val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
- val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
- val reqBlockMessages = new BlockMessageArray(
- Seq(reqBlockMessage1, reqBlockMessage2))
-
- val tmpBufferMessage = reqBlockMessages.toBufferMessage
- val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
- val arrayBuffer = new ArrayBuffer[ByteBuffer]
- tmpBufferMessage.buffers.foreach{ b =>
- buffer.put(b)
- }
- buffer.flip()
- arrayBuffer += buffer
- val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
-
- // setup ack block messages
- val buf1 = ByteBuffer.allocate(4)
- val buf2 = ByteBuffer.allocate(4)
- buf1.putInt(1)
- buf1.flip()
- buf2.putInt(1)
- buf2.flip()
- val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
- val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
-
- val answer = new Answer[Option[BlockMessage]] {
- override def answer(invocation: InvocationOnMock)
- :Option[BlockMessage]= {
- if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(
- reqBlockMessage1)) {
- return Some(ackBlockMessage1)
- } else {
- return Some(ackBlockMessage2)
- }
- }
- }
-
- doAnswer(answer).when(worker).processBlockMessage(any())
-
- val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
- assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should be defined")
- assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should not have error")
- }
-
test("reserve/release unroll memory") {
store = makeBlockManager(12000)
val memoryStore = store.memoryStore
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 777579bc570db..e4522e00a622d 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.shuffle.hash.HashShuffleManager
import scala.collection.mutable
@@ -26,6 +27,7 @@ import scala.language.reflectiveCalls
import akka.actor.Props
import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
import org.apache.spark.SparkConf
@@ -40,18 +42,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
private var rootDir1: File = _
private var rootDirs: String = _
- // This suite focuses primarily on consolidation features,
- // so we coerce consolidation if not already enabled.
- testConf.set("spark.shuffle.consolidateFiles", "true")
-
- private val shuffleManager = new HashShuffleManager(testConf.clone)
-
- val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) {
- override def conf = testConf.clone
- var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
- override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
- }
-
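+ // DiskBlockManager now takes a BlockManager; a mock that returns the test conf is sufficient.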
+ val blockManager = mock(classOf[BlockManager])
+ when(blockManager.conf).thenReturn(testConf)
var diskBlockManager: DiskBlockManager = _
override def beforeAll() {
@@ -61,7 +53,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
rootDir1 = Files.createTempDir()
rootDir1.deleteOnExit()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
- println("Created root dirs: " + rootDirs)
}
override def afterAll() {
@@ -71,22 +62,19 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
}
override def beforeEach() {
- diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
- shuffleBlockManager.idToSegmentMap.clear()
+ val conf = testConf.clone
+ conf.set("spark.local.dir", rootDirs)
+ diskBlockManager = new DiskBlockManager(blockManager, conf)
}
override def afterEach() {
diskBlockManager.stop()
- shuffleBlockManager.idToSegmentMap.clear()
}
test("basic block creation") {
val blockId = new TestBlockId("test")
- assertSegmentEquals(blockId, blockId.name, 0, 0)
-
val newFile = diskBlockManager.getFile(blockId)
writeToFile(newFile, 10)
- assertSegmentEquals(blockId, blockId.name, 0, 10)
assert(diskBlockManager.containsBlock(blockId))
newFile.delete()
assert(!diskBlockManager.containsBlock(blockId))
@@ -99,127 +87,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
}
- test("block appending") {
- val blockId = new TestBlockId("test")
- val newFile = diskBlockManager.getFile(blockId)
- writeToFile(newFile, 15)
- assertSegmentEquals(blockId, blockId.name, 0, 15)
- val newFile2 = diskBlockManager.getFile(blockId)
- assert(newFile === newFile2)
- writeToFile(newFile2, 12)
- assertSegmentEquals(blockId, blockId.name, 0, 27)
- newFile.delete()
- }
-
- test("block remapping") {
- val filename = "test"
- val blockId0 = new ShuffleBlockId(1, 2, 3)
- val newFile = diskBlockManager.getFile(filename)
- writeToFile(newFile, 15)
- shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
- assertSegmentEquals(blockId0, filename, 0, 15)
-
- val blockId1 = new ShuffleBlockId(1, 2, 4)
- val newFile2 = diskBlockManager.getFile(filename)
- writeToFile(newFile2, 12)
- shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
- assertSegmentEquals(blockId1, filename, 15, 12)
-
- assert(newFile === newFile2)
- newFile.delete()
- }
-
- private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
- assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
- assert (segment1.offset === segment2.offset)
- assert (segment1.length === segment2.length)
- }
-
- test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
-
- val serializer = new JavaSerializer(testConf)
- val confCopy = testConf.clone
- // reset after EACH object write. This is to ensure that there are bytes appended after
- // an object is written. So if the codepaths assume writeObject is end of data, this should
- // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc.
- confCopy.set("spark.serializer.objectStreamReset", "1")
-
- val securityManager = new org.apache.spark.SecurityManager(confCopy)
- // Do not use the shuffleBlockManager above !
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy,
- securityManager)
- val master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
- confCopy)
- val store = new BlockManager("", actorSystem, master , serializer, confCopy,
- securityManager, null, shuffleManager)
-
- try {
-
- val shuffleManager = store.shuffleBlockManager
-
- val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics)
- for (writer <- shuffle1.writers) {
- writer.write("test1")
- writer.write("test2")
- }
- for (writer <- shuffle1.writers) {
- writer.commitAndClose()
- }
-
- val shuffle1Segment = shuffle1.writers(0).fileSegment()
- shuffle1.releaseWriters(success = true)
-
- val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf),
- new ShuffleWriteMetrics)
-
- for (writer <- shuffle2.writers) {
- writer.write("test3")
- writer.write("test4")
- }
- for (writer <- shuffle2.writers) {
- writer.commitAndClose()
- }
- val shuffle2Segment = shuffle2.writers(0).fileSegment()
- shuffle2.releaseWriters(success = true)
-
- // Now comes the test :
- // Write to shuffle 3; and close it, but before registering it, check if the file lengths for
- // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length
- // of block based on remaining data in file : which could mess things up when there is concurrent read
- // and writes happening to the same shuffle group.
-
- val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
- new ShuffleWriteMetrics)
- for (writer <- shuffle3.writers) {
- writer.write("test3")
- writer.write("test4")
- }
- for (writer <- shuffle3.writers) {
- writer.commitAndClose()
- }
- // check before we register.
- checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
- shuffle3.releaseWriters(success = true)
- checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
- shuffleManager.removeShuffle(1)
- } finally {
-
- if (store != null) {
- store.stop()
- }
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
- }
-
- def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
- val segment = diskBlockManager.getBlockLocation(blockId)
- assert(segment.file.getName === filename)
- assert(segment.offset === offset)
- assert(segment.length === length)
- }
-
def writeToFile(file: File, numBytes: Int) {
val writer = new FileWriter(file, true)
for (i <- 0 until numBytes) writer.write(i)
diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
new file mode 100644
index 0000000000000..dae7bf0e336de
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+import org.apache.spark.util.Utils
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+
+
+/**
+ * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options.
+ */
+class LocalDirsSuite extends FunSuite {
+
+ test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
+ // Regression test for SPARK-2974
+ assert(!new File("/NONEXISTENT_DIR").exists())
+ val conf = new SparkConf(false)
+ .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
+ assert(new File(Utils.getLocalDir(conf)).exists())
+ }
+
+ test("SPARK_LOCAL_DIRS override also affects driver") {
+ // Regression test for SPARK-2975
+ assert(!new File("/NONEXISTENT_DIR").exists())
+ // SPARK_LOCAL_DIRS is a valid directory:
+ class MySparkConf extends SparkConf(false) {
+ override def getenv(name: String) = {
+ if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir")
+ else super.getenv(name)
+ }
+
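+ // Preserve the getenv override when Spark clones this conf internally.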
+ override def clone: SparkConf = {
+ new MySparkConf().setAll(settings)
+ }
+ }
+ // spark.local.dir only contains invalid directories, but that's not a problem since
+ // SPARK_LOCAL_DIRS will override it on both the driver and workers:
+ val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH")
+ assert(new File(Utils.getLocalDir(conf)).exists())
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
new file mode 100644
index 0000000000000..809bd70929656
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.TaskContext
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.scalatest.FunSuite
+
+
+class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+
+ test("handle local read failures in BlockManager") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+ val answer = new Answer[Option[Iterator[Any]]] {
+ override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
+ throw new Exception
+ }
+ }
+
+ // 3rd block is going to fail
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call
+ // getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ // the 2nd element of the tuple returned by iterator.next should be defined when
+ // fetching successfully
+ assert(iterator.next()._2.isDefined,
+ "1st element should be defined but is not actually defined")
+ verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "2nd element should be defined but is not actually defined")
+ verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ // 3rd fetch should be failed
+ intercept[Exception] {
+ iterator.next()
+ }
+ verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle local read successes") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+
+ // All blocks should be fetched successfully
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call
+ // getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 1st element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 2nd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 3rd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 4th element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 5th element is not actually defined")
+
+ verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle remote fetch failures in BlockTransferService") {
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ listener.onBlockFetchFailure(new Exception("blah"))
+ }
+ })
+
+ val blockManager = mock(classOf[BlockManager])
+
+ when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
+
+ val blId1 = ShuffleBlockId(0, 0, 0)
+ val blId2 = ShuffleBlockId(0, 1, 0)
+ val bmId = BlockManagerId("test-server", "test-server", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, Seq((blId1, 1L), (blId2, 1L))))
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ iterator.foreach { case (_, iterOption) =>
+ assert(!iterOption.isDefined)
+ }
+ }
+}
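Editor's note, sketch only: the ShuffleBlockFetcherIterator tests above combine doReturn/doAnswer stubbing with verify(..., times(n)) to prove the iterator is lazy. Here is a minimal standalone illustration of the same technique, using plain Mockito 1.x and the Scala standard library; the Fetcher trait and every name below are hypothetical, introduced only for this sketch.

import org.mockito.Matchers.anyInt
import org.mockito.Mockito.{doAnswer, doReturn, mock, times, verify}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical stand-in for BlockManager.getLocalShuffleFromDisk.
trait Fetcher {
  def fetch(id: Int): Option[Iterator[Any]]
}

object LazyFetchSketch {
  def main(args: Array[String]): Unit = {
    val fetcher = mock(classOf[Fetcher])

    // Blocks 0 and 1 succeed; block 2 throws, mirroring "3rd block is going to fail" above.
    doReturn(Some(Iterator(10))).when(fetcher).fetch(0)
    doReturn(Some(Iterator(20))).when(fetcher).fetch(1)
    doAnswer(new Answer[Option[Iterator[Any]]] {
      override def answer(invocation: InvocationOnMock): Option[Iterator[Any]] =
        throw new Exception("simulated local read failure")
    }).when(fetcher).fetch(2)

    // Iterator.map is lazy: nothing is fetched until the consumer pulls an element.
    val results = (0 to 2).iterator.map(i => (i, fetcher.fetch(i)))
    verify(fetcher, times(0)).fetch(anyInt())

    assert(results.next()._2.isDefined)  // first fetch happens only here
    verify(fetcher, times(1)).fetch(anyInt())

    assert(results.next()._2.isDefined)
    verify(fetcher, times(2)).fetch(anyInt())

    // The third pull propagates the injected failure.
    try { results.next() } catch { case e: Exception => println(s"expected: ${e.getMessage}") }
    verify(fetcher, times(3)).fetch(anyInt())
  }
}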
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
index 51fb646a3cb61..3a45875391e29 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
@@ -26,8 +26,8 @@ import org.apache.spark.scheduler._
* Test the behavior of StorageStatusListener in response to all relevant events.
*/
class StorageStatusListenerSuite extends FunSuite {
- private val bm1 = BlockManagerId("big", "dog", 1, 1)
- private val bm2 = BlockManagerId("fat", "duck", 2, 2)
+ private val bm1 = BlockManagerId("big", "dog", 1)
+ private val bm2 = BlockManagerId("fat", "duck", 2)
private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false)
private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false)
@@ -36,13 +36,13 @@ class StorageStatusListenerSuite extends FunSuite {
// Block manager add
assert(listener.executorIdToStorageStatus.size === 0)
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
assert(listener.executorIdToStorageStatus.size === 1)
assert(listener.executorIdToStorageStatus.get("big").isDefined)
assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1)
assert(listener.executorIdToStorageStatus("big").maxMem === 1000L)
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L))
assert(listener.executorIdToStorageStatus.size === 2)
assert(listener.executorIdToStorageStatus.get("fat").isDefined)
assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2)
@@ -50,11 +50,11 @@ class StorageStatusListenerSuite extends FunSuite {
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
// Block manager remove
- listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm1))
+ listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm1))
assert(listener.executorIdToStorageStatus.size === 1)
assert(!listener.executorIdToStorageStatus.get("big").isDefined)
assert(listener.executorIdToStorageStatus.get("fat").isDefined)
- listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm2))
+ listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2))
assert(listener.executorIdToStorageStatus.size === 0)
assert(!listener.executorIdToStorageStatus.get("big").isDefined)
assert(!listener.executorIdToStorageStatus.get("fat").isDefined)
@@ -62,25 +62,25 @@ class StorageStatusListenerSuite extends FunSuite {
test("task end without updated blocks") {
val listener = new StorageStatusListener
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L))
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L))
val taskMetrics = new TaskMetrics
// Task end with no updated blocks
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics))
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics))
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
}
test("task end with updated blocks") {
val listener = new StorageStatusListener
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L))
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L))
val taskMetrics1 = new TaskMetrics
val taskMetrics2 = new TaskMetrics
val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L))
@@ -92,13 +92,13 @@ class StorageStatusListenerSuite extends FunSuite {
// Task end with new blocks
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
@@ -111,13 +111,14 @@ class StorageStatusListenerSuite extends FunSuite {
val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L))
taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3))
taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3))
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1))
+
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0)))
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo2, taskMetrics2))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
@@ -127,7 +128,7 @@ class StorageStatusListenerSuite extends FunSuite {
test("unpersist RDD") {
val listener = new StorageStatusListener
- listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L))
+ listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
val taskMetrics1 = new TaskMetrics
val taskMetrics2 = new TaskMetrics
val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L))
@@ -135,8 +136,8 @@ class StorageStatusListenerSuite extends FunSuite {
val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L))
taskMetrics1.updatedBlocks = Some(Seq(block1, block2))
taskMetrics2.updatedBlocks = Some(Seq(block3))
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics1))
- listener.onTaskEnd(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo1, taskMetrics2))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2))
assert(listener.executorIdToStorageStatus("big").numBlocks === 3)
// Unpersist RDD
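Editor's note, sketch only: the event signatures exercised in this suite now carry a timestamp on block manager events and a stage attempt ID on task events. The listener below is a hypothetical orientation sketch; blockManagerId, time, and maxMem match the accessors asserted on in JsonProtocolSuite further down, while stageAttemptId and reason are assumed field names not shown verbatim in these hunks.

import org.apache.spark.scheduler._

// Minimal listener sketch (hypothetical class name) showing where the new fields surface.
class AttemptAwareListener extends SparkListener {
  override def onBlockManagerAdded(added: SparkListenerBlockManagerAdded): Unit = {
    println(s"block manager ${added.blockManagerId} added at t=${added.time}, " +
      s"maxMem=${added.maxMem}")
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // Keyed by (stageId, stageAttemptId), matching the stageIdToData keys used in
    // JobProgressListenerSuite below. Field names are assumptions, not a definitive API.
    val key = (taskEnd.stageId, taskEnd.stageAttemptId)
    println(s"task ended for stage attempt $key with reason ${taskEnd.reason}")
  }
}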
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
index 38678bbd1dd28..ef5c55f91c39a 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
@@ -27,7 +27,7 @@ class StorageSuite extends FunSuite {
// For testing add, update, and remove (for non-RDD blocks)
private def storageStatus1: StorageStatus = {
- val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L)
+ val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L)
assert(status.blocks.isEmpty)
assert(status.rddBlocks.isEmpty)
assert(status.memUsed === 0L)
@@ -78,7 +78,7 @@ class StorageSuite extends FunSuite {
// For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks
private def storageStatus2: StorageStatus = {
- val status = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L)
+ val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L)
assert(status.rddBlocks.isEmpty)
status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L, 0L))
status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L, 0L))
@@ -271,9 +271,9 @@ class StorageSuite extends FunSuite {
// For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations
private def stockStorageStatuses: Seq[StorageStatus] = {
- val status1 = new StorageStatus(BlockManagerId("big", "dog", 1, 1), 1000L)
- val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2, 2), 2000L)
- val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3, 3), 3000L)
+ val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L)
+ val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L)
+ val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L)
status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L))
status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L, 0L))
status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L, 0L))
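Editor's note, sketch only: several suites here build StorageStatus objects by hand. The sketch below uses only the constructors and accessors that appear in the hunks above (the three-argument BlockManagerId, addBlock, numBlocks, memUsed) and is placed in the storage package because some of these helpers are package-private, just as the suite itself is.

package org.apache.spark.storage

// Per-executor block bookkeeping with the three-argument BlockManagerId; the Netty port
// parameter removed by this patch is gone.
object StorageStatusSketch {
  def main(args: Array[String]): Unit = {
    val status = new StorageStatus(BlockManagerId("exec-1", "host-1", 7001), 1000L)
    val level = StorageLevel.MEMORY_AND_DISK
    status.addBlock(RDDBlockId(0, 0), BlockStatus(level, 10L, 20L, 0L))
    status.addBlock(RDDBlockId(0, 1), BlockStatus(level, 10L, 20L, 0L))
    println(s"blocks = ${status.numBlocks}, memUsed = ${status.memUsed}")
  }
}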
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 038746d2eda4b..92a21f82f3c21 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -21,10 +21,8 @@ import java.net.ServerSocket
import javax.servlet.http.HttpServletRequest
import scala.io.Source
-import scala.language.postfixOps
import scala.util.{Failure, Success, Try}
-import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.ServletContextHandler
import org.scalatest.FunSuite
import org.scalatest.concurrent.Eventually._
@@ -36,11 +34,25 @@ import scala.xml.Node
class UISuite extends FunSuite {
+ /**
+ * Create a test SparkContext with the SparkUI enabled.
+ * It is safe to `get` the SparkUI directly from the SparkContext returned here.
+ */
+ private def newSparkContext(): SparkContext = {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ .set("spark.ui.enabled", "true")
+ val sc = new SparkContext(conf)
+ assert(sc.ui.isDefined)
+ sc
+ }
+
ignore("basic ui visibility") {
- withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(newSparkContext()) { sc =>
// test if the ui is visible, and all the expected tabs are visible
eventually(timeout(10 seconds), interval(50 milliseconds)) {
- val html = Source.fromURL(sc.ui.appUIAddress).mkString
+ val html = Source.fromURL(sc.ui.get.appUIAddress).mkString
assert(!html.contains("random data that should not be present"))
assert(html.toLowerCase.contains("stages"))
assert(html.toLowerCase.contains("storage"))
@@ -51,7 +63,7 @@ class UISuite extends FunSuite {
}
ignore("visibility at localhost:4040") {
- withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(newSparkContext()) { sc =>
// test if visible from http://localhost:4040
eventually(timeout(10 seconds), interval(50 milliseconds)) {
val html = Source.fromURL("http://localhost:4040").mkString
@@ -61,8 +73,8 @@ class UISuite extends FunSuite {
}
ignore("attaching a new tab") {
- withSpark(new SparkContext("local", "test")) { sc =>
- val sparkUI = sc.ui
+ withSpark(newSparkContext()) { sc =>
+ val sparkUI = sc.ui.get
val newTab = new WebUITab(sparkUI, "foo") {
attachPage(new WebUIPage("") {
@@ -73,7 +85,7 @@ class UISuite extends FunSuite {
}
sparkUI.attachTab(newTab)
eventually(timeout(10 seconds), interval(50 milliseconds)) {
- val html = Source.fromURL(sc.ui.appUIAddress).mkString
+ val html = Source.fromURL(sparkUI.appUIAddress).mkString
assert(!html.contains("random data that should not be present"))
// check whether new page exists
@@ -87,7 +99,7 @@ class UISuite extends FunSuite {
}
eventually(timeout(10 seconds), interval(50 milliseconds)) {
- val html = Source.fromURL(sc.ui.appUIAddress.stripSuffix("/") + "/foo").mkString
+ val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString
// check whether new page exists
assert(html.contains("magic"))
}
@@ -95,14 +107,8 @@ class UISuite extends FunSuite {
}
test("jetty selects different port under contention") {
- val startPort = 4040
- val server = new Server(startPort)
-
- Try { server.start() } match {
- case Success(s) =>
- case Failure(e) =>
- // Either case server port is busy hence setup for test complete
- }
+ val server = new ServerSocket(0)
+ val startPort = server.getLocalPort
val serverInfo1 = JettyUtils.startJettyServer(
"0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf)
val serverInfo2 = JettyUtils.startJettyServer(
@@ -113,6 +119,9 @@ class UISuite extends FunSuite {
assert(boundPort1 != startPort)
assert(boundPort2 != startPort)
assert(boundPort1 != boundPort2)
+ serverInfo1.server.stop()
+ serverInfo2.server.stop()
+ server.close()
}
test("jetty binds to port 0 correctly") {
@@ -129,16 +138,20 @@ class UISuite extends FunSuite {
}
test("verify appUIAddress contains the scheme") {
- withSpark(new SparkContext("local", "test")) { sc =>
- val uiAddress = sc.ui.appUIAddress
- assert(uiAddress.equals("http://" + sc.ui.appUIHostPort))
+ withSpark(newSparkContext()) { sc =>
+ val ui = sc.ui.get
+ val uiAddress = ui.appUIAddress
+ val uiHostPort = ui.appUIHostPort
+ assert(uiAddress.equals("http://" + uiHostPort))
}
}
test("verify appUIAddress contains the port") {
- withSpark(new SparkContext("local", "test")) { sc =>
- val splitUIAddress = sc.ui.appUIAddress.split(':')
- assert(splitUIAddress(2).toInt == sc.ui.boundPort)
+ withSpark(newSparkContext()) { sc =>
+ val ui = sc.ui.get
+ val splitUIAddress = ui.appUIAddress.split(':')
+ val boundPort = ui.boundPort
+ assert(splitUIAddress(2).toInt == boundPort)
}
}
}
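Editor's note, sketch only: with the SparkUI exposed as an Option, callers guard access instead of dereferencing sc.ui directly. The sketch mirrors the newSparkContext helper and assertions above; the object name is made up, and the package placement is an assumption because sc.ui and SparkUI may be package-private.

package org.apache.spark.uisketch  // hypothetical subpackage: sc.ui and SparkUI are Spark-internal

import org.apache.spark.{SparkConf, SparkContext}

// Option-based UI access, as used by the suite above.
object UiAddressSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("ui-sketch")
      .set("spark.ui.enabled", "true")
    val sc = new SparkContext(conf)
    try {
      sc.ui match {
        case Some(ui) => println(s"UI bound at ${ui.appUIAddress} (port ${ui.boundPort})")
        case None => println("UI disabled for this context")
      }
    } finally {
      sc.stop()
    }
  }
}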
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 147ec0bc52e39..3370dd4156c3f 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -34,12 +34,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val listener = new JobProgressListener(conf)
def createStageStartEvent(stageId: Int) = {
- val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "")
+ val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "")
SparkListenerStageSubmitted(stageInfo)
}
def createStageEndEvent(stageId: Int) = {
- val stageInfo = new StageInfo(stageId, stageId.toString, 0, null, "")
+ val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "")
SparkListenerStageCompleted(stageInfo)
}
@@ -70,33 +70,37 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
taskInfo.finishTime = 1
var task = new ShuffleMapTask(0)
val taskType = Utils.getFormattedClassName(task)
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
- assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
- .shuffleRead === 1000)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToData.getOrElse((0, 0), fail())
+ .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 1000)
// finish a task with unknown executor-id, nothing should happen
taskInfo =
new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0)
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics))
assert(listener.stageIdToData.size === 1)
// finish this task, should get updated duration
taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0)
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
- assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
- .shuffleRead === 2000)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToData.getOrElse((0, 0), fail())
+ .executorSummary.getOrElse("exe-1", fail()).shuffleRead === 2000)
// finish this task, should get updated duration
taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
task = new ShuffleMapTask(0)
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
- assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail())
- .shuffleRead === 1000)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToData.getOrElse((0, 0), fail())
+ .executorSummary.getOrElse("exe-2", fail()).shuffleRead === 1000)
}
test("test task success vs failure counting for different task end reasons") {
@@ -119,16 +123,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
UnknownReason)
var failCount = 0
for (reason <- taskFailedReasons) {
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics))
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, reason, taskInfo, metrics))
failCount += 1
- assert(listener.stageIdToData(task.stageId).numCompleteTasks === 0)
- assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount)
+ assert(listener.stageIdToData((task.stageId, 0)).numCompleteTasks === 0)
+ assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount)
}
// Make sure we count success as success.
- listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics))
- assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1)
- assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount)
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics))
+ assert(listener.stageIdToData((task.stageId, 1)).numCompleteTasks === 1)
+ assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount)
}
test("test update metrics") {
@@ -163,18 +169,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
taskInfo
}
- listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1234L)))
- listener.onTaskStart(SparkListenerTaskStart(0, makeTaskInfo(1235L)))
- listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1236L)))
- listener.onTaskStart(SparkListenerTaskStart(1, makeTaskInfo(1237L)))
+ listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1234L)))
+ listener.onTaskStart(SparkListenerTaskStart(0, 0, makeTaskInfo(1235L)))
+ listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1236L)))
+ listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
- (1234L, 0, makeTaskMetrics(0)),
- (1235L, 0, makeTaskMetrics(100)),
- (1236L, 1, makeTaskMetrics(200)))))
+ (1234L, 0, 0, makeTaskMetrics(0)),
+ (1235L, 0, 0, makeTaskMetrics(100)),
+ (1236L, 1, 0, makeTaskMetrics(200)))))
- var stage0Data = listener.stageIdToData.get(0).get
- var stage1Data = listener.stageIdToData.get(1).get
+ var stage0Data = listener.stageIdToData.get((0, 0)).get
+ var stage1Data = listener.stageIdToData.get((1, 0)).get
assert(stage0Data.shuffleReadBytes == 102)
assert(stage1Data.shuffleReadBytes == 201)
assert(stage0Data.shuffleWriteBytes == 106)
@@ -195,14 +201,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
.totalBlocksFetched == 202)
// task that was included in a heartbeat
- listener.onTaskEnd(SparkListenerTaskEnd(0, taskType, Success, makeTaskInfo(1234L, 1),
+ listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1),
makeTaskMetrics(300)))
// task that wasn't included in a heartbeat
- listener.onTaskEnd(SparkListenerTaskEnd(1, taskType, Success, makeTaskInfo(1237L, 1),
+ listener.onTaskEnd(SparkListenerTaskEnd(1, 0, taskType, Success, makeTaskInfo(1237L, 1),
makeTaskMetrics(400)))
- stage0Data = listener.stageIdToData.get(0).get
- stage1Data = listener.stageIdToData.get(1).get
+ stage0Data = listener.stageIdToData.get((0, 0)).get
+ stage1Data = listener.stageIdToData.get((1, 0)).get
assert(stage0Data.shuffleReadBytes == 402)
assert(stage1Data.shuffleReadBytes == 602)
assert(stage0Data.shuffleWriteBytes == 406)
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index 6e68dcb3425aa..e1bc1379b5d80 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -34,11 +34,12 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
private val memOnly = StorageLevel.MEMORY_ONLY
private val none = StorageLevel.NONE
private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false)
+ private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false)
private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly)
private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly)
private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk)
private def rddInfo3 = new RDDInfo(3, "grace", 400, memAndDisk)
- private val bm1 = BlockManagerId("big", "dog", 1, 1)
+ private val bm1 = BlockManagerId("big", "dog", 1)
before {
bus = new LiveListenerBus
@@ -53,7 +54,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
assert(storageListener.rddInfoList.isEmpty)
// 2 RDDs are known, but none are cached
- val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0, rddInfo1), "details")
+ val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0, rddInfo1), "details")
bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
assert(storageListener._rddInfoMap.size === 2)
assert(storageListener.rddInfoList.isEmpty)
@@ -63,7 +64,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
val rddInfo3Cached = rddInfo3
rddInfo2Cached.numCachedPartitions = 1
rddInfo3Cached.numCachedPartitions = 1
- val stageInfo1 = new StageInfo(1, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details")
+ val stageInfo1 = new StageInfo(1, 0, "0", 100, Seq(rddInfo2Cached, rddInfo3Cached), "details")
bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
assert(storageListener._rddInfoMap.size === 4)
assert(storageListener.rddInfoList.size === 2)
@@ -71,7 +72,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
// Submitting RDDInfos with duplicate IDs does nothing
val rddInfo0Cached = new RDDInfo(0, "freedom", 100, StorageLevel.MEMORY_ONLY)
rddInfo0Cached.numCachedPartitions = 1
- val stageInfo0Cached = new StageInfo(0, "0", 100, Seq(rddInfo0), "details")
+ val stageInfo0Cached = new StageInfo(0, 0, "0", 100, Seq(rddInfo0), "details")
bus.postToAll(SparkListenerStageSubmitted(stageInfo0Cached))
assert(storageListener._rddInfoMap.size === 4)
assert(storageListener.rddInfoList.size === 2)
@@ -87,7 +88,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
val rddInfo1Cached = rddInfo1
rddInfo0Cached.numCachedPartitions = 1
rddInfo1Cached.numCachedPartitions = 1
- val stageInfo0 = new StageInfo(0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details")
+ val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(rddInfo0Cached, rddInfo1Cached), "details")
bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
assert(storageListener._rddInfoMap.size === 2)
assert(storageListener.rddInfoList.size === 2)
@@ -106,8 +107,8 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
val myRddInfo0 = rddInfo0
val myRddInfo1 = rddInfo1
val myRddInfo2 = rddInfo2
- val stageInfo0 = new StageInfo(0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details")
- bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L))
+ val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details")
+ bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
assert(storageListener._rddInfoMap.size === 3)
assert(storageListener.rddInfoList.size === 0) // not cached
@@ -116,7 +117,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
assert(!storageListener._rddInfoMap(2).isCached)
// Task end with no updated blocks. This should not change anything.
- bus.postToAll(SparkListenerTaskEnd(0, "obliteration", Success, taskInfo, new TaskMetrics))
+ bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics))
assert(storageListener._rddInfoMap.size === 3)
assert(storageListener.rddInfoList.size === 0)
@@ -128,7 +129,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
(RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)),
(RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L))
))
- bus.postToAll(SparkListenerTaskEnd(1, "obliteration", Success, taskInfo, metrics1))
+ bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1))
assert(storageListener._rddInfoMap(0).memSize === 800L)
assert(storageListener._rddInfoMap(0).diskSize === 400L)
assert(storageListener._rddInfoMap(0).tachyonSize === 200L)
@@ -150,7 +151,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
(RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist
(RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist
))
- bus.postToAll(SparkListenerTaskEnd(2, "obliteration", Success, taskInfo, metrics2))
+ bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2))
assert(storageListener._rddInfoMap(0).memSize === 400L)
assert(storageListener._rddInfoMap(0).diskSize === 400L)
assert(storageListener._rddInfoMap(0).tachyonSize === 200L)
@@ -162,4 +163,30 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
}
+ test("verify StorageTab contains all cached rdds") {
+
+ val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly)
+ val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly)
+ val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), "details")
+ val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), "details")
+ val taskMetrics0 = new TaskMetrics
+ val taskMetrics1 = new TaskMetrics
+ val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L))
+ val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L))
+ taskMetrics0.updatedBlocks = Some(Seq(block0))
+ taskMetrics1.updatedBlocks = Some(Seq(block1))
+ bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
+ bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
+ assert(storageListener.rddInfoList.size === 0)
+ bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerStageCompleted(stageInfo0))
+ assert(storageListener.rddInfoList.size === 1)
+ bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+ assert(storageListener.rddInfoList.size === 2)
+ bus.postToAll(SparkListenerStageCompleted(stageInfo1))
+ assert(storageListener.rddInfoList.size === 2)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index c4765e53de17b..76bf4cfd11267 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -17,13 +17,16 @@
package org.apache.spark.util
+import scala.concurrent.Await
+
import akka.actor._
+
+import org.scalatest.FunSuite
+
import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.scalatest.FunSuite
-import scala.concurrent.Await
/**
* Test the AkkaUtils with various security settings.
@@ -35,7 +38,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
- val securityManager = new SecurityManager(conf);
+ val securityManager = new SecurityManager(conf)
val hostname = "localhost"
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
conf = conf, securityManager = securityManager)
@@ -106,13 +109,13 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security off
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
actorSystem.shutdown()
slaveSystem.shutdown()
@@ -157,13 +160,13 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
// this should succeed since security on and passwords match
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000), size1000)))
actorSystem.shutdown()
slaveSystem.shutdown()
diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
index 44332fc8dbc23..c3dd156b40514 100644
--- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
@@ -26,13 +26,15 @@ import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
/**
* Test writing files through the FileLogger.
*/
class FileLoggerSuite extends FunSuite with BeforeAndAfter {
- private val fileSystem = Utils.getHadoopFileSystem("/")
+ private val fileSystem = Utils.getHadoopFileSystem("/",
+ SparkHadoopUtil.get.newConfiguration(new SparkConf()))
private val allCompressionCodecs = Seq[String](
"org.apache.spark.io.LZFCompressionCodec",
"org.apache.spark.io.SnappyCompressionCodec"
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 97ffb07662482..2b45d8b695853 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -21,6 +21,9 @@ import java.util.Properties
import scala.collection.Map
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods._
import org.scalatest.FunSuite
@@ -35,13 +38,13 @@ class JsonProtocolSuite extends FunSuite {
val stageSubmitted =
SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties)
val stageCompleted = SparkListenerStageCompleted(makeStageInfo(101, 201, 301, 401L, 501L))
- val taskStart = SparkListenerTaskStart(111, makeTaskInfo(222L, 333, 1, 444L, false))
+ val taskStart = SparkListenerTaskStart(111, 0, makeTaskInfo(222L, 333, 1, 444L, false))
val taskGettingResult =
SparkListenerTaskGettingResult(makeTaskInfo(1000L, 2000, 5, 3000L, true))
- val taskEnd = SparkListenerTaskEnd(1, "ShuffleMapTask", Success,
+ val taskEnd = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success,
makeTaskInfo(123L, 234, 67, 345L, false),
makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = false))
- val taskEndWithHadoopInput = SparkListenerTaskEnd(1, "ShuffleMapTask", Success,
+ val taskEndWithHadoopInput = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success,
makeTaskInfo(123L, 234, 67, 345L, false),
makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true))
val jobStart = SparkListenerJobStart(10, Seq[Int](1, 2, 3, 4), properties)
@@ -52,12 +55,12 @@ class JsonProtocolSuite extends FunSuite {
"System Properties" -> Seq(("Username", "guest"), ("Password", "guest")),
"Classpath Entries" -> Seq(("Super library", "/tmp/super_library"))
))
- val blockManagerAdded = SparkListenerBlockManagerAdded(
- BlockManagerId("Stars", "In your multitude...", 300, 400), 500)
- val blockManagerRemoved = SparkListenerBlockManagerRemoved(
- BlockManagerId("Scarce", "to be counted...", 100, 200))
+ val blockManagerAdded = SparkListenerBlockManagerAdded(1L,
+ BlockManagerId("Stars", "In your multitude...", 300), 500)
+ val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L,
+ BlockManagerId("Scarce", "to be counted...", 100))
val unpersistRdd = SparkListenerUnpersistRDD(12345)
- val applicationStart = SparkListenerApplicationStart("The winner of all", 42L, "Garfield")
+ val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield")
val applicationEnd = SparkListenerApplicationEnd(42L)
testEvent(stageSubmitted, stageSubmittedJsonString)
@@ -81,7 +84,7 @@ class JsonProtocolSuite extends FunSuite {
testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L))
testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false))
testTaskMetrics(makeTaskMetrics(33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false))
- testBlockManagerId(BlockManagerId("Hong", "Kong", 500, 1000))
+ testBlockManagerId(BlockManagerId("Hong", "Kong", 500))
// StorageLevel
testStorageLevel(StorageLevel.NONE)
@@ -104,7 +107,7 @@ class JsonProtocolSuite extends FunSuite {
testJobResult(jobFailed)
// TaskEndReason
- val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15, 16), 17, 18, 19)
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19)
val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None)
testTaskEndReason(Success)
testTaskEndReason(Resubmitted)
@@ -151,6 +154,35 @@ class JsonProtocolSuite extends FunSuite {
assert(newMetrics.inputMetrics.isEmpty)
}
+ test("BlockManager events backward compatibility") {
+ // SparkListenerBlockManagerAdded/Removed in Spark 1.0.0 do not have a "time" property.
+ val blockManagerAdded = SparkListenerBlockManagerAdded(1L,
+ BlockManagerId("Stars", "In your multitude...", 300), 500)
+ val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L,
+ BlockManagerId("Scarce", "to be counted...", 100))
+
+ val oldBmAdded = JsonProtocol.blockManagerAddedToJson(blockManagerAdded)
+ .removeField({ _._1 == "Timestamp" })
+
+ val deserializedBmAdded = JsonProtocol.blockManagerAddedFromJson(oldBmAdded)
+ assert(SparkListenerBlockManagerAdded(-1L, blockManagerAdded.blockManagerId,
+ blockManagerAdded.maxMem) === deserializedBmAdded)
+
+ val oldBmRemoved = JsonProtocol.blockManagerRemovedToJson(blockManagerRemoved)
+ .removeField({ _._1 == "Timestamp" })
+
+ val deserializedBmRemoved = JsonProtocol.blockManagerRemovedFromJson(oldBmRemoved)
+ assert(SparkListenerBlockManagerRemoved(-1L, blockManagerRemoved.blockManagerId) ===
+ deserializedBmRemoved)
+ }
+
+ test("SparkListenerApplicationStart backwards compatibility") {
+ // SparkListenerApplicationStart in Spark 1.0.0 does not have an "appId" property.
+ val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user")
+ val oldEvent = JsonProtocol.applicationStartToJson(applicationStart)
+ .removeField({ _._1 == "App ID" })
+ assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent))
+ }
/** -------------------------- *
| Helper test running methods |
@@ -242,8 +274,10 @@ class JsonProtocolSuite extends FunSuite {
assertEquals(e1.environmentDetails, e2.environmentDetails)
case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) =>
assert(e1.maxMem === e2.maxMem)
+ assert(e1.time === e2.time)
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) =>
+ assert(e1.time === e2.time)
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) =>
assert(e1.rddId == e2.rddId)
@@ -343,7 +377,6 @@ class JsonProtocolSuite extends FunSuite {
assert(bm1.executorId === bm2.executorId)
assert(bm1.host === bm2.host)
assert(bm1.port === bm2.port)
- assert(bm1.nettyPort === bm2.nettyPort)
}
private def assertEquals(result1: JobResult, result2: JobResult) {
@@ -397,7 +430,8 @@ class JsonProtocolSuite extends FunSuite {
private def assertJsonStringEquals(json1: String, json2: String) {
val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "")
- assert(formatJsonString(json1) === formatJsonString(json2))
+ assert(formatJsonString(json1) === formatJsonString(json2),
+ s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}")
}
private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
@@ -485,7 +519,7 @@ class JsonProtocolSuite extends FunSuite {
private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = {
val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) }
- val stageInfo = new StageInfo(a, "greetings", b, rddInfos, "details")
+ val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, "details")
val (acc1, acc2) = (makeAccumulableInfo(1), makeAccumulableInfo(2))
stageInfo.accumulables(acc1.id) = acc1
stageInfo.accumulables(acc2.id) = acc2
@@ -558,84 +592,246 @@ class JsonProtocolSuite extends FunSuite {
private val stageSubmittedJsonString =
"""
- {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name":
- "greetings","Number of Tasks":200,"RDD Info":[],"Details":"details",
- "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]},"Properties":
- {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}}
+ |{
+ | "Event": "SparkListenerStageSubmitted",
+ | "Stage Info": {
+ | "Stage ID": 100,
+ | "Stage Attempt ID": 0,
+ | "Stage Name": "greetings",
+ | "Number of Tasks": 200,
+ | "RDD Info": [],
+ | "Details": "details",
+ | "Accumulables": [
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | }
+ | ]
+ | },
+ | "Properties": {
+ | "France": "Paris",
+ | "Germany": "Berlin",
+ | "Russia": "Moscow",
+ | "Ukraine": "Kiev"
+ | }
+ |}
"""
private val stageCompletedJsonString =
"""
- {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name":
- "greetings","Number of Tasks":201,"RDD Info":[{"RDD ID":101,"Name":"mayor","Storage
- Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true,
- "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301,
- "Memory Size":401,"Tachyon Size":0,"Disk Size":501}],"Details":"details",
- "Accumulables":[{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- {"ID":1,"Name":"Accumulable1","Update":"delta1","Value":"val1"}]}}
+ |{
+ | "Event": "SparkListenerStageCompleted",
+ | "Stage Info": {
+ | "Stage ID": 101,
+ | "Stage Attempt ID": 0,
+ | "Stage Name": "greetings",
+ | "Number of Tasks": 201,
+ | "RDD Info": [
+ | {
+ | "RDD ID": 101,
+ | "Name": "mayor",
+ | "Storage Level": {
+ | "Use Disk": true,
+ | "Use Memory": true,
+ | "Use Tachyon": false,
+ | "Deserialized": true,
+ | "Replication": 1
+ | },
+ | "Number of Partitions": 201,
+ | "Number of Cached Partitions": 301,
+ | "Memory Size": 401,
+ | "Tachyon Size": 0,
+ | "Disk Size": 501
+ | }
+ | ],
+ | "Details": "details",
+ | "Accumulables": [
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | }
+ | ]
+ | }
+ |}
"""
private val taskStartJsonString =
"""
- |{"Event":"SparkListenerTaskStart","Stage ID":111,"Task Info":{"Task ID":222,
- |"Index":333,"Attempt":1,"Launch Time":444,"Executor ID":"executor","Host":"your kind sir",
- |"Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,
- |"Failed":false,"Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1",
- |"Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- |{"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]}}
+ |{
+ | "Event": "SparkListenerTaskStart",
+ | "Stage ID": 111,
+ | "Stage Attempt ID": 0,
+ | "Task Info": {
+ | "Task ID": 222,
+ | "Index": 333,
+ | "Attempt": 1,
+ | "Launch Time": 444,
+ | "Executor ID": "executor",
+ | "Host": "your kind sir",
+ | "Locality": "NODE_LOCAL",
+ | "Speculative": false,
+ | "Getting Result Time": 0,
+ | "Finish Time": 0,
+ | "Failed": false,
+ | "Accumulables": [
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | },
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 3,
+ | "Name": "Accumulable3",
+ | "Update": "delta3",
+ | "Value": "val3"
+ | }
+ | ]
+ | }
+ |}
""".stripMargin
private val taskGettingResultJsonString =
"""
- |{"Event":"SparkListenerTaskGettingResult","Task Info":
- | {"Task ID":1000,"Index":2000,"Attempt":5,"Launch Time":3000,"Executor ID":"executor",
- | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":true,"Getting Result Time":0,
- | "Finish Time":0,"Failed":false,
- | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1",
- | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]
+ |{
+ | "Event": "SparkListenerTaskGettingResult",
+ | "Task Info": {
+ | "Task ID": 1000,
+ | "Index": 2000,
+ | "Attempt": 5,
+ | "Launch Time": 3000,
+ | "Executor ID": "executor",
+ | "Host": "your kind sir",
+ | "Locality": "NODE_LOCAL",
+ | "Speculative": true,
+ | "Getting Result Time": 0,
+ | "Finish Time": 0,
+ | "Failed": false,
+ | "Accumulables": [
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | },
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 3,
+ | "Name": "Accumulable3",
+ | "Update": "delta3",
+ | "Value": "val3"
+ | }
+ | ]
| }
|}
""".stripMargin
private val taskEndJsonString =
"""
- |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask",
- |"Task End Reason":{"Reason":"Success"},
- |"Task Info":{
- | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor",
- | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false,
- | "Getting Result Time":0,"Finish Time":0,"Failed":false,
- | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1",
- | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]
- |},
- |"Task Metrics":{
- | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400,
- | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700,
- | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0,
- | "Shuffle Read Metrics":{
- | "Shuffle Finish Time":900,
- | "Remote Blocks Fetched":800,
- | "Local Blocks Fetched":700,
- | "Fetch Wait Time":900,
- | "Remote Bytes Read":1000
+ |{
+ | "Event": "SparkListenerTaskEnd",
+ | "Stage ID": 1,
+ | "Stage Attempt ID": 0,
+ | "Task Type": "ShuffleMapTask",
+ | "Task End Reason": {
+ | "Reason": "Success"
| },
- | "Shuffle Write Metrics":{
- | "Shuffle Bytes Written":1200,
- | "Shuffle Write Time":1500
+ | "Task Info": {
+ | "Task ID": 123,
+ | "Index": 234,
+ | "Attempt": 67,
+ | "Launch Time": 345,
+ | "Executor ID": "executor",
+ | "Host": "your kind sir",
+ | "Locality": "NODE_LOCAL",
+ | "Speculative": false,
+ | "Getting Result Time": 0,
+ | "Finish Time": 0,
+ | "Failed": false,
+ | "Accumulables": [
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | },
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 3,
+ | "Name": "Accumulable3",
+ | "Update": "delta3",
+ | "Value": "val3"
+ | }
+ | ]
| },
- | "Updated Blocks":[
- | {"Block ID":"rdd_0_0",
- | "Status":{
- | "Storage Level":{
- | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false,
- | "Replication":2
- | },
- | "Memory Size":0,"Tachyon Size":0,"Disk Size":0
+ | "Task Metrics": {
+ | "Host Name": "localhost",
+ | "Executor Deserialize Time": 300,
+ | "Executor Run Time": 400,
+ | "Result Size": 500,
+ | "JVM GC Time": 600,
+ | "Result Serialization Time": 700,
+ | "Memory Bytes Spilled": 800,
+ | "Disk Bytes Spilled": 0,
+ | "Shuffle Read Metrics": {
+ | "Shuffle Finish Time": 900,
+ | "Remote Blocks Fetched": 800,
+ | "Local Blocks Fetched": 700,
+ | "Fetch Wait Time": 900,
+ | "Remote Bytes Read": 1000
+ | },
+ | "Shuffle Write Metrics": {
+ | "Shuffle Bytes Written": 1200,
+ | "Shuffle Write Time": 1500
+ | },
+ | "Updated Blocks": [
+ | {
+ | "Block ID": "rdd_0_0",
+ | "Status": {
+ | "Storage Level": {
+ | "Use Disk": true,
+ | "Use Memory": true,
+ | "Use Tachyon": false,
+ | "Deserialized": false,
+ | "Replication": 2
+ | },
+ | "Memory Size": 0,
+ | "Tachyon Size": 0,
+ | "Disk Size": 0
+ | }
| }
- | }
| ]
| }
|}
@@ -643,80 +839,187 @@ class JsonProtocolSuite extends FunSuite {
private val taskEndWithHadoopInputJsonString =
"""
- |{"Event":"SparkListenerTaskEnd","Stage ID":1,"Task Type":"ShuffleMapTask",
- |"Task End Reason":{"Reason":"Success"},
- |"Task Info":{
- | "Task ID":123,"Index":234,"Attempt":67,"Launch Time":345,"Executor ID":"executor",
- | "Host":"your kind sir","Locality":"NODE_LOCAL","Speculative":false,
- | "Getting Result Time":0,"Finish Time":0,"Failed":false,
- | "Accumulables":[{"ID":1,"Name":"Accumulable1","Update":"delta1",
- | "Value":"val1"},{"ID":2,"Name":"Accumulable2","Update":"delta2","Value":"val2"},
- | {"ID":3,"Name":"Accumulable3","Update":"delta3","Value":"val3"}]
- |},
- |"Task Metrics":{
- | "Host Name":"localhost","Executor Deserialize Time":300,"Executor Run Time":400,
- | "Result Size":500,"JVM GC Time":600,"Result Serialization Time":700,
- | "Memory Bytes Spilled":800,"Disk Bytes Spilled":0,
- | "Shuffle Write Metrics":{"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},
- | "Input Metrics":{"Data Read Method":"Hadoop","Bytes Read":2100},
- | "Updated Blocks":[
- | {"Block ID":"rdd_0_0",
- | "Status":{
- | "Storage Level":{
- | "Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false,
- | "Replication":2
- | },
- | "Memory Size":0,"Tachyon Size":0,"Disk Size":0
+ |{
+ | "Event": "SparkListenerTaskEnd",
+ | "Stage ID": 1,
+ | "Stage Attempt ID": 0,
+ | "Task Type": "ShuffleMapTask",
+ | "Task End Reason": {
+ | "Reason": "Success"
+ | },
+ | "Task Info": {
+ | "Task ID": 123,
+ | "Index": 234,
+ | "Attempt": 67,
+ | "Launch Time": 345,
+ | "Executor ID": "executor",
+ | "Host": "your kind sir",
+ | "Locality": "NODE_LOCAL",
+ | "Speculative": false,
+ | "Getting Result Time": 0,
+ | "Finish Time": 0,
+ | "Failed": false,
+ | "Accumulables": [
+ | {
+ | "ID": 1,
+ | "Name": "Accumulable1",
+ | "Update": "delta1",
+ | "Value": "val1"
+ | },
+ | {
+ | "ID": 2,
+ | "Name": "Accumulable2",
+ | "Update": "delta2",
+ | "Value": "val2"
+ | },
+ | {
+ | "ID": 3,
+ | "Name": "Accumulable3",
+ | "Update": "delta3",
+ | "Value": "val3"
+ | }
+ | ]
+ | },
+ | "Task Metrics": {
+ | "Host Name": "localhost",
+ | "Executor Deserialize Time": 300,
+ | "Executor Run Time": 400,
+ | "Result Size": 500,
+ | "JVM GC Time": 600,
+ | "Result Serialization Time": 700,
+ | "Memory Bytes Spilled": 800,
+ | "Disk Bytes Spilled": 0,
+ | "Shuffle Write Metrics": {
+ | "Shuffle Bytes Written": 1200,
+ | "Shuffle Write Time": 1500
+ | },
+ | "Input Metrics": {
+ | "Data Read Method": "Hadoop",
+ | "Bytes Read": 2100
+ | },
+ | "Updated Blocks": [
+ | {
+ | "Block ID": "rdd_0_0",
+ | "Status": {
+ | "Storage Level": {
+ | "Use Disk": true,
+ | "Use Memory": true,
+ | "Use Tachyon": false,
+ | "Deserialized": false,
+ | "Replication": 2
+ | },
+ | "Memory Size": 0,
+ | "Tachyon Size": 0,
+ | "Disk Size": 0
+ | }
| }
- | }
- | ]}
+ | ]
+ | }
|}
"""
private val jobStartJsonString =
"""
- {"Event":"SparkListenerJobStart","Job ID":10,"Stage IDs":[1,2,3,4],"Properties":
- {"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}}
+ |{
+ | "Event": "SparkListenerJobStart",
+ | "Job ID": 10,
+ | "Stage IDs": [
+ | 1,
+ | 2,
+ | 3,
+ | 4
+ | ],
+ | "Properties": {
+ | "France": "Paris",
+ | "Germany": "Berlin",
+ | "Russia": "Moscow",
+ | "Ukraine": "Kiev"
+ | }
+ |}
"""
private val jobEndJsonString =
"""
- {"Event":"SparkListenerJobEnd","Job ID":20,"Job Result":{"Result":"JobSucceeded"}}
+ |{
+ | "Event": "SparkListenerJobEnd",
+ | "Job ID": 20,
+ | "Job Result": {
+ | "Result": "JobSucceeded"
+ | }
+ |}
"""
private val environmentUpdateJsonString =
"""
- {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"GC speed":"9999 objects/s",
- "Java home":"Land of coffee"},"Spark Properties":{"Job throughput":"80000 jobs/s,
- regardless of job type"},"System Properties":{"Username":"guest","Password":"guest"},
- "Classpath Entries":{"Super library":"/tmp/super_library"}}
+ |{
+ | "Event": "SparkListenerEnvironmentUpdate",
+ | "JVM Information": {
+ | "GC speed": "9999 objects/s",
+ | "Java home": "Land of coffee"
+ | },
+ | "Spark Properties": {
+ | "Job throughput": "80000 jobs/s, regardless of job type"
+ | },
+ | "System Properties": {
+ | "Username": "guest",
+ | "Password": "guest"
+ | },
+ | "Classpath Entries": {
+ | "Super library": "/tmp/super_library"
+ | }
+ |}
"""
private val blockManagerAddedJsonString =
"""
- {"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"Stars",
- "Host":"In your multitude...","Port":300,"Netty Port":400},"Maximum Memory":500}
+ |{
+ | "Event": "SparkListenerBlockManagerAdded",
+ | "Block Manager ID": {
+ | "Executor ID": "Stars",
+ | "Host": "In your multitude...",
+ | "Port": 300
+ | },
+ | "Maximum Memory": 500,
+ | "Timestamp": 1
+ |}
"""
private val blockManagerRemovedJsonString =
"""
- {"Event":"SparkListenerBlockManagerRemoved","Block Manager ID":{"Executor ID":"Scarce",
- "Host":"to be counted...","Port":100,"Netty Port":200}}
+ |{
+ | "Event": "SparkListenerBlockManagerRemoved",
+ | "Block Manager ID": {
+ | "Executor ID": "Scarce",
+ | "Host": "to be counted...",
+ | "Port": 100
+ | },
+ | "Timestamp": 2
+ |}
"""
private val unpersistRDDJsonString =
"""
- {"Event":"SparkListenerUnpersistRDD","RDD ID":12345}
+ |{
+ | "Event": "SparkListenerUnpersistRDD",
+ | "RDD ID": 12345
+ |}
"""
private val applicationStartJsonString =
"""
- {"Event":"SparkListenerApplicationStart","App Name":"The winner of all","Timestamp":42,
- "User":"Garfield"}
+ |{
+ | "Event": "SparkListenerApplicationStart",
+ | "App Name": "The winner of all",
+ | "Timestamp": 42,
+ | "User": "Garfield"
+ |}
"""
private val applicationEndJsonString =
"""
- {"Event":"SparkListenerApplicationEnd","Timestamp":42}
+ |{
+ | "Event": "SparkListenerApplicationEnd",
+ | "Timestamp": 42
+ |}
"""
}
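Editor's note, sketch only: the backward-compatibility tests above serialize the current event shape, strip the newly added field with json4s removeField to imitate JSON written by an older Spark release, and check that the reader falls back to a default. The standalone sketch below shows the same trick on a plain JSON object, not a real SparkListener event.

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Strip a field to simulate legacy JSON, then compare the two shapes.
object RemoveFieldSketch {
  def main(args: Array[String]): Unit = {
    val current: JValue =
      ("Event" -> "SparkListenerBlockManagerAdded") ~
        ("Timestamp" -> 1L) ~
        ("Maximum Memory" -> 500L)

    // Pretend this JSON was written by Spark 1.0.0, which had no "Timestamp" field.
    val legacy = current.removeField { case (name, _) => name == "Timestamp" }

    println(compact(render(current)))
    println(compact(render(legacy)))  // a reader should now fall back to a default (e.g. -1L)
  }
}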
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 04d7338488628..511d76c9144cc 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -23,37 +23,43 @@ import org.scalatest.FunSuite
import org.apache.spark._
import org.apache.spark.SparkContext._
+import org.apache.spark.io.CompressionCodec
class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
+ private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS
+ private def createCombiner[T](i: T) = ArrayBuffer[T](i)
+ private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i
+ private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
+ buf1 ++= buf2
- private def createCombiner(i: Int) = ArrayBuffer[Int](i)
- private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
- private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
+ private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
+ createCombiner[T], mergeValue[T], mergeCombiners[T])
- private def createSparkConf(loadDefaults: Boolean): SparkConf = {
+ private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = {
val conf = new SparkConf(loadDefaults)
// Make the Java serializer write a reset instruction (TC_RESET) after each object to test
// for a bug we had with bytes written past the last object in a batch (SPARK-2792)
conf.set("spark.serializer.objectStreamReset", "1")
conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+ conf.set("spark.shuffle.spill.compress", codec.isDefined.toString)
+ conf.set("spark.shuffle.compress", codec.isDefined.toString)
+ codec.foreach { c => conf.set("spark.io.compression.codec", c) }
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
conf
}
test("simple insert") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
-
- val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map = createExternalMap[Int]
// Single insert
map.insert(1, 10)
var it = map.iterator
assert(it.hasNext)
val kv = it.next()
- assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10))
+ assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10))
assert(!it.hasNext)
// Multiple insert
@@ -61,18 +67,17 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
map.insert(3, 30)
it = map.iterator
assert(it.hasNext)
- assert(it.toSet == Set[(Int, ArrayBuffer[Int])](
+ assert(it.toSet === Set[(Int, ArrayBuffer[Int])](
(1, ArrayBuffer[Int](10)),
(2, ArrayBuffer[Int](20)),
(3, ArrayBuffer[Int](30))))
+ sc.stop()
}
test("insert with collision") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
-
- val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map = createExternalMap[Int]
map.insertAll(Seq(
(1, 10),
@@ -84,30 +89,28 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val it = map.iterator
assert(it.hasNext)
val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
- assert(result == Set[(Int, Set[Int])](
+ assert(result === Set[(Int, Set[Int])](
(1, Set[Int](10, 100, 1000)),
(2, Set[Int](20, 200)),
(3, Set[Int](30))))
+ sc.stop()
}
test("ordering") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
- val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map1 = createExternalMap[Int]
map1.insert(1, 10)
map1.insert(2, 20)
map1.insert(3, 30)
- val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map2 = createExternalMap[Int]
map2.insert(2, 20)
map2.insert(3, 30)
map2.insert(1, 10)
- val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map3 = createExternalMap[Int]
map3.insert(3, 30)
map3.insert(1, 10)
map3.insert(2, 20)
@@ -119,33 +122,33 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
var kv1 = it1.next()
var kv2 = it2.next()
var kv3 = it3.next()
- assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
- assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ assert(kv1._1 === kv2._1 && kv2._1 === kv3._1)
+ assert(kv1._2 === kv2._2 && kv2._2 === kv3._2)
kv1 = it1.next()
kv2 = it2.next()
kv3 = it3.next()
- assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
- assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ assert(kv1._1 === kv2._1 && kv2._1 === kv3._1)
+ assert(kv1._2 === kv2._2 && kv2._2 === kv3._2)
kv1 = it1.next()
kv2 = it2.next()
kv3 = it3.next()
- assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
- assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ assert(kv1._1 === kv2._1 && kv2._1 === kv3._1)
+ assert(kv1._2 === kv2._2 && kv2._2 === kv3._2)
+ sc.stop()
}
test("null keys and values") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
- val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
- mergeValue, mergeCombiners)
+ val map = createExternalMap[Int]
map.insert(1, 5)
map.insert(2, 6)
map.insert(3, 7)
assert(map.size === 3)
- assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ assert(map.iterator.toSet === Set[(Int, Seq[Int])](
(1, Seq[Int](5)),
(2, Seq[Int](6)),
(3, Seq[Int](7))
@@ -155,7 +158,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val nullInt = null.asInstanceOf[Int]
map.insert(nullInt, 8)
assert(map.size === 4)
- assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ assert(map.iterator.toSet === Set[(Int, Seq[Int])](
(1, Seq[Int](5)),
(2, Seq[Int](6)),
(3, Seq[Int](7)),
@@ -167,32 +170,34 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
map.insert(nullInt, nullInt)
assert(map.size === 5)
val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
- assert(result == Set[(Int, Set[Int])](
+ assert(result === Set[(Int, Set[Int])](
(1, Set[Int](5)),
(2, Set[Int](6)),
(3, Set[Int](7)),
(4, Set[Int](nullInt)),
(nullInt, Set[Int](nullInt, 8))
))
+ sc.stop()
}
test("simple aggregator") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
// reduceByKey
val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
val result1 = rdd.reduceByKey(_+_).collect()
- assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5)))
+ assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5)))
// groupByKey
val result2 = rdd.groupByKey().collect().map(x => (x._1, x._2.toList)).toSet
- assert(result2.toSet == Set[(Int, Seq[Int])]
+ assert(result2.toSet === Set[(Int, Seq[Int])]
((0, List[Int](1, 1, 1, 1, 1)), (1, List[Int](1, 1, 1, 1, 1))))
+ sc.stop()
}
test("simple cogroup") {
- val conf = createSparkConf(false)
+ val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
val rdd1 = sc.parallelize(1 to 4).map(i => (i, i))
val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i))
@@ -200,77 +205,98 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
result.foreach { case (i, (seq1, seq2)) =>
i match {
- case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4))
- case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3))
- case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]())
- case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]())
- case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]())
+ case 0 => assert(seq1.toSet === Set[Int]() && seq2.toSet === Set[Int](2, 4))
+ case 1 => assert(seq1.toSet === Set[Int](1) && seq2.toSet === Set[Int](1, 3))
+ case 2 => assert(seq1.toSet === Set[Int](2) && seq2.toSet === Set[Int]())
+ case 3 => assert(seq1.toSet === Set[Int](3) && seq2.toSet === Set[Int]())
+ case 4 => assert(seq1.toSet === Set[Int](4) && seq2.toSet === Set[Int]())
}
}
+ sc.stop()
}
test("spilling") {
- val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ testSimpleSpilling()
+ }
+
+ test("spilling with compression") {
+ // Keep track of which compression codec we're using to report in test failure messages
+ var lastCompressionCodec: Option[String] = None
+ try {
+ allCompressionCodecs.foreach { c =>
+ lastCompressionCodec = Some(c)
+ testSimpleSpilling(Some(c))
+ }
+ } catch {
+ // Include compression codec used in test failure message
+ // We need to catch Throwable here because assertion failures are not covered by Exceptions
+ case t: Throwable =>
+ val compressionMessage = lastCompressionCodec
+ .map { c => "with compression using codec " + c }
+ .getOrElse("without compression")
+ val newException = new Exception(s"Test failed $compressionMessage:\n\n${t.getMessage}")
+ newException.setStackTrace(t.getStackTrace)
+ throw newException
+ }
+ }
+
+ /**
+ * Test spilling through simple aggregations and cogroups.
+ * If a compression codec is provided, use it. Otherwise, do not compress spills.
+ */
+ private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+ val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
// reduceByKey - should spill ~8 times
val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
val resultA = rddA.reduceByKey(math.max).collect()
- assert(resultA.length == 50000)
- resultA.foreach { case(k, v) =>
- if (v != k * 2 + 1) {
- fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
- }
+ assert(resultA.length === 50000)
+ resultA.foreach { case (k, v) =>
+ assert(v === k * 2 + 1, s"Value for $k was wrong: expected ${k * 2 + 1}, got $v")
}
// groupByKey - should spill ~17 times
val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
val resultB = rddB.groupByKey().collect()
- assert(resultB.length == 25000)
- resultB.foreach { case(i, seq) =>
+ assert(resultB.length === 25000)
+ resultB.foreach { case (i, seq) =>
val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
- if (seq.toSet != expected) {
- fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
- }
+ assert(seq.toSet === expected,
+ s"Value for $i was wrong: expected $expected, got ${seq.toSet}")
}
// cogroup - should spill ~7 times
val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
val resultC = rddC1.cogroup(rddC2).collect()
- assert(resultC.length == 10000)
- resultC.foreach { case(i, (seq1, seq2)) =>
+ assert(resultC.length === 10000)
+ resultC.foreach { case (i, (seq1, seq2)) =>
i match {
case 0 =>
- assert(seq1.toSet == Set[Int](0))
- assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ assert(seq1.toSet === Set[Int](0))
+ assert(seq2.toSet === Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
case 1 =>
- assert(seq1.toSet == Set[Int](1))
- assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+ assert(seq1.toSet === Set[Int](1))
+ assert(seq2.toSet === Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
case 5000 =>
- assert(seq1.toSet == Set[Int](5000))
- assert(seq2.toSet == Set[Int]())
+ assert(seq1.toSet === Set[Int](5000))
+ assert(seq2.toSet === Set[Int]())
case 9999 =>
- assert(seq1.toSet == Set[Int](9999))
- assert(seq2.toSet == Set[Int]())
+ assert(seq1.toSet === Set[Int](9999))
+ assert(seq2.toSet === Set[Int]())
case _ =>
}
}
+ sc.stop()
}
test("spilling with hash collisions") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
-
- def createCombiner(i: String) = ArrayBuffer[String](i)
- def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
- def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
- buffer1 ++= buffer2
-
- val map = new ExternalAppendOnlyMap[String, String, ArrayBuffer[String]](
- createCombiner, mergeValue, mergeCombiners)
+ val map = createExternalMap[String]
val collisionPairs = Seq(
("Aa", "BB"), // 2112
@@ -312,13 +338,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
count += 1
}
assert(count === 100000 + collisionPairs.size * 2)
+ sc.stop()
}
test("spilling with many hash collisions") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.0001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
-
val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
@@ -337,15 +363,14 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
count += 1
}
assert(count === 10000)
+ sc.stop()
}
test("spilling with hash collisions using the Int.MaxValue key") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
-
- val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
- createCombiner, mergeValue, mergeCombiners)
+ val map = createExternalMap[Int]
(1 to 100000).foreach { i => map.insert(i, i) }
map.insert(Int.MaxValue, Int.MaxValue)
@@ -355,15 +380,14 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
// Should not throw NoSuchElementException
it.next()
}
+ sc.stop()
}
test("spilling with null keys and values") {
- val conf = createSparkConf(true)
+ val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
-
- val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
- createCombiner, mergeValue, mergeCombiners)
+ val map = createExternalMap[Int]
map.insertAll((1 to 100000).iterator.map(i => (i, i)))
map.insert(null.asInstanceOf[Int], 1)
@@ -375,6 +399,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
// Should not throw NullPointerException
it.next()
}
+ sc.stop()
}
}
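
Taken together, the new helpers let every spilling test run once per compression codec. A minimal sketch of the codec handling on its own, under these assumptions: it sits in an org.apache.spark subpackage so any private[spark] members (such as the CompressionCodec companion referenced above) are visible, and ALL_COMPRESSION_CODECS is a sequence of codec class names; the object and method names are illustrative, not part of the patch:

    package org.apache.spark.util.collection

    import org.apache.spark.SparkConf
    import org.apache.spark.io.CompressionCodec

    object SpillCompressionConfSketch {
      // Mirrors createSparkConf: spill/shuffle compression is enabled iff a codec is supplied.
      def confFor(codec: Option[String]): SparkConf = {
        val conf = new SparkConf(false)
        conf.set("spark.shuffle.spill.compress", codec.isDefined.toString)
        conf.set("spark.shuffle.compress", codec.isDefined.toString)
        codec.foreach(c => conf.set("spark.io.compression.codec", c))
        conf
      }

      def main(args: Array[String]): Unit = {
        // One variant with compression disabled, plus one per available codec.
        val variants = None +: CompressionCodec.ALL_COMPRESSION_CODECS.map(Option(_))
        variants.foreach(c => println(confFor(c).toDebugString))
      }
    }
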
diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
new file mode 100644
index 0000000000000..f855831b8e367
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+
+class ByteArrayChunkOutputStreamSuite extends FunSuite {
+
+ test("empty output") {
+ val o = new ByteArrayChunkOutputStream(1024)
+ assert(o.toArrays.length === 0)
+ }
+
+ test("write a single byte") {
+ val o = new ByteArrayChunkOutputStream(1024)
+ o.write(10)
+ assert(o.toArrays.length === 1)
+ assert(o.toArrays.head.toSeq === Seq(10.toByte))
+ }
+
+ test("write a single near boundary") {
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(new Array[Byte](9))
+ o.write(99)
+ assert(o.toArrays.length === 1)
+ assert(o.toArrays.head(9) === 99.toByte)
+ }
+
+ test("write a single at boundary") {
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(new Array[Byte](10))
+ o.write(99)
+ assert(o.toArrays.length === 2)
+ assert(o.toArrays(1).length === 1)
+ assert(o.toArrays(1)(0) === 99.toByte)
+ }
+
+ test("single chunk output") {
+ val ref = new Array[Byte](8)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 1)
+ assert(arrays.head.length === ref.length)
+ assert(arrays.head.toSeq === ref.toSeq)
+ }
+
+ test("single chunk output at boundary size") {
+ val ref = new Array[Byte](10)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 1)
+ assert(arrays.head.length === ref.length)
+ assert(arrays.head.toSeq === ref.toSeq)
+ }
+
+ test("multiple chunk output") {
+ val ref = new Array[Byte](26)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 3)
+ assert(arrays(0).length === 10)
+ assert(arrays(1).length === 10)
+ assert(arrays(2).length === 6)
+
+ assert(arrays(0).toSeq === ref.slice(0, 10))
+ assert(arrays(1).toSeq === ref.slice(10, 20))
+ assert(arrays(2).toSeq === ref.slice(20, 26))
+ }
+
+ test("multiple chunk output at boundary size") {
+ val ref = new Array[Byte](30)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(10)
+ o.write(ref)
+ val arrays = o.toArrays
+ assert(arrays.length === 3)
+ assert(arrays(0).length === 10)
+ assert(arrays(1).length === 10)
+ assert(arrays(2).length === 10)
+
+ assert(arrays(0).toSeq === ref.slice(0, 10))
+ assert(arrays(1).toSeq === ref.slice(10, 20))
+ assert(arrays(2).toSeq === ref.slice(20, 30))
+ }
+}
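
The new suite documents how ByteArrayChunkOutputStream splits its output into fixed-size chunks. A minimal usage sketch built only on the surface exercised by the tests (the chunk-size constructor, write, and toArrays), assuming the class may be private[spark], hence the package declaration; the object name is hypothetical:

    package org.apache.spark.util.io

    object ChunkedOutputSketch {
      def main(args: Array[String]): Unit = {
        // Writing 26 bytes into 10-byte chunks yields chunk lengths 10, 10 and 6,
        // matching the "multiple chunk output" test above.
        val out = new ByteArrayChunkOutputStream(10)
        out.write(('a' to 'z').map(_.toByte).toArray)
        val chunks = out.toArrays
        assert(chunks.map(_.length).toSeq == Seq(10, 10, 6))
        // Concatenating the chunks reproduces the original byte sequence.
        assert(chunks.flatten.toSeq == ('a' to 'z').map(_.toByte))
      }
    }
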
diff --git a/dev/check-license b/dev/check-license
index 625ec161bc571..9ff0929e9a5e8 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -23,18 +23,18 @@ acquire_rat_jar () {
URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
- JAR=$rat_jar
+ JAR="$rat_jar"
if [[ ! -f "$rat_jar" ]]; then
# Download rat launch jar if it hasn't been downloaded yet
if [ ! -f "$JAR" ]; then
# Download
printf "Attempting to fetch rat\n"
- JAR_DL=${JAR}.part
+ JAR_DL="${JAR}.part"
if hash curl 2>/dev/null; then
- (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR"
elif hash wget 2>/dev/null; then
- (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ (wget --quiet ${URL1} -O "$JAR_DL" || wget --quiet ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR"
else
printf "You do not have curl or wget installed, please install rat manually.\n"
exit -1
@@ -50,7 +50,7 @@ acquire_rat_jar () {
}
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
if test -x "$JAVA_HOME/bin/java"; then
@@ -60,17 +60,17 @@ else
fi
export RAT_VERSION=0.10
-export rat_jar=$FWDIR/lib/apache-rat-${RAT_VERSION}.jar
-mkdir -p $FWDIR/lib
+export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar
+mkdir -p "$FWDIR"/lib
[[ -f "$rat_jar" ]] || acquire_rat_jar || {
echo "Download failed. Obtain the rat jar manually and place it at $rat_jar"
exit 1
}
-$java_cmd -jar $rat_jar -E $FWDIR/.rat-excludes -d $FWDIR > rat-results.txt
+$java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt
-ERRORS=$(cat rat-results.txt | grep -e "??")
+ERRORS="$(cat rat-results.txt | grep -e "??")"
if test ! -z "$ERRORS"; then
echo "Could not find Apache license headers in the following files:"
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 28f26d2368254..281e8d4de6d71 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -60,14 +60,14 @@ if [[ ! "$@" =~ --package-only ]]; then
-Dmaven.javadoc.skip=true \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
- -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
+ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
--batch-mode release:prepare
mvn -DskipTests \
-Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Dmaven.javadoc.skip=true \
- -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
+ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
release:perform
cd ..
@@ -117,12 +117,13 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" &
-make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
-make_binary_release "hadoop2" \
- "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" &
-make_binary_release "hadoop2-without-hive" \
- "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" &
+make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" &
+make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
+make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" &
+make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" &
+make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" &
+make_binary_release "mapr3" "-Pmapr3 -Phive" &
+make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" &
wait
# Copy data
diff --git a/dev/create-release/generate-changelist.py b/dev/create-release/generate-changelist.py
index de1b5d4ae1314..2e1a35a629342 100755
--- a/dev/create-release/generate-changelist.py
+++ b/dev/create-release/generate-changelist.py
@@ -125,8 +125,8 @@ def cleanup(ask=True):
pr_num = [line.split()[1].lstrip("#") for line in body_lines if "Closes #" in line][0]
github_url = "github.com/apache/spark/pull/%s" % pr_num
day = time.strptime(date.split()[0], "%Y-%m-%d")
- if day < SPARK_REPO_CHANGE_DATE1 or
- (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH):
+ if (day < SPARK_REPO_CHANGE_DATE1 or
+ (day < SPARK_REPO_CHANGE_DATE2 and pr_num < SPARK_REPO_PR_NUM_THRESH)):
github_url = "github.com/apache/incubator-spark/pull/%s" % pr_num
append_to_changelist(" %s" % subject)
diff --git a/dev/lint-python b/dev/lint-python
index 4efddad839387..772f856154ae0 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -18,10 +18,10 @@
#
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
-SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
+SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt"
-cd $SPARK_ROOT_DIR
+cd "$SPARK_ROOT_DIR"
# Get pep8 at runtime so that we don't rely on it being installed on the build server.
#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162
@@ -30,6 +30,7 @@ cd $SPARK_ROOT_DIR
#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?))
PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py"
PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py"
+PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/"
curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH"
curl_status=$?
@@ -44,7 +45,7 @@ fi
#+ first, but we do so so that the check status can
#+ be output before the report, like with the
#+ scalastyle and RAT checks.
-python $PEP8_SCRIPT_PATH ./python > "$PEP8_REPORT_PATH"
+python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH"
pep8_status=${PIPESTATUS[0]} #$?
if [ $pep8_status -ne 0 ]; then
@@ -54,7 +55,7 @@ else
echo "PEP 8 checks passed."
fi
-rm -f "$PEP8_REPORT_PATH"
+rm "$PEP8_REPORT_PATH"
rm "$PEP8_SCRIPT_PATH"
exit $pep8_status
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index d48c8bde12905..a8e92e36fe0d8 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -44,9 +44,9 @@
# Remote name which points to Apache git
PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache")
# ASF JIRA username
-JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "")
+JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell")
# ASF JIRA password
-JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "")
+JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "35500")
GITHUB_BASE = "https://github.com/apache/spark/pull"
GITHUB_API_BASE = "https://api.github.com/repos/apache/spark"
diff --git a/dev/mima b/dev/mima
index 09e4482af5f3d..40603166c21ae 100755
--- a/dev/mima
+++ b/dev/mima
@@ -21,15 +21,23 @@ set -o pipefail
set -e
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
echo -e "q\n" | sbt/sbt oldDeps/update
+rm -f .generated-mima*
+
+# GenerateMIMAIgnore is called twice: first with the latest built jars on the
+# classpath, and then again with the previous version's jars on the classpath.
+# This works around a bug in GenerateMIMAIgnore: when the old jars are ahead on
+# the classpath, it does not process the new classes (which are in the assembly jar).
+./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
-export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`
+export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`"
echo "SPARK_CLASSPATH=$SPARK_CLASSPATH"
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
+
echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
ret_val=$?
diff --git a/dev/run-tests b/dev/run-tests
index 0e24515d1376c..c3d8f49cdd993 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -18,47 +18,76 @@
#
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
- if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4"
- elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1"
- elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0"
- elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
+# Remove work directory
+rm -rf ./work
+
+# Build against the right version of Hadoop.
+{
+ if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
+ if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then
+ export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4"
+ elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then
+ export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1"
+ elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then
+ export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0"
+ elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then
+ export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
+ fi
fi
-fi
-if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
-fi
+ if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then
+ export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
+ fi
+}
export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl"
-echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\""
-
-# Remove work directory
-rm -rf ./work
-
-if test -x "$JAVA_HOME/bin/java"; then
- declare java_cmd="$JAVA_HOME/bin/java"
-else
- declare java_cmd=java
-fi
-JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
-[ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run because JDK version is < 1.8."
+# Determine Java path and version.
+{
+ if test -x "$JAVA_HOME/bin/java"; then
+ declare java_cmd="$JAVA_HOME/bin/java"
+ else
+ declare java_cmd=java
+ fi
+
+ # We can't use sed -r -e due to OS X / BSD compatibility; hence, all the parentheses.
+ JAVA_VERSION=$(
+ $java_cmd -version 2>&1 \
+ | grep -e "^java version" --max-count=1 \
+ | sed "s/java version \"\(.*\)\.\(.*\)\.\(.*\)\"/\1\2/"
+ )
+
+ if [ "$JAVA_VERSION" -lt 18 ]; then
+ echo "[warn] Java 8 tests will not run because JDK version is < 1.8."
+ fi
+}
-# Partial solution for SPARK-1455. Only run Hive tests if there are sql changes.
+# Only run Hive tests if there are sql changes.
+# Partial solution for SPARK-1455.
if [ -n "$AMPLAB_JENKINS" ]; then
git fetch origin master:master
- diffs=`git diff --name-only master | grep "^sql/"`
- if [ -n "$diffs" ]; then
- echo "Detected changes in SQL. Will run Hive test suite."
- export _RUN_SQL_TESTS=true # exported for PySpark tests
+
+ sql_diffs=$(
+ git diff --name-only master \
+ | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh"
+ )
+
+ non_sql_diffs=$(
+ git diff --name-only master \
+ | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh"
+ )
+
+ if [ -n "$sql_diffs" ]; then
+ echo "[info] Detected changes in SQL. Will run Hive test suite."
+ _RUN_SQL_TESTS=true
+
+ if [ -z "$non_sql_diffs" ]; then
+ echo "[info] Detected no changes except in SQL. Will only run SQL tests."
+ _SQL_TESTS_ONLY=true
+ fi
fi
fi
@@ -70,33 +99,77 @@ echo ""
echo "========================================================================="
echo "Running Apache RAT checks"
echo "========================================================================="
-dev/check-license
+./dev/check-license
echo ""
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
-dev/lint-scala
+./dev/lint-scala
echo ""
echo "========================================================================="
echo "Running Python style checks"
echo "========================================================================="
-dev/lint-python
+./dev/lint-python
+
+echo ""
+echo "========================================================================="
+echo "Building Spark"
+echo "========================================================================="
+
+{
+ # We always build with Hive because the PySpark Spark SQL tests need it.
+ BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+
+ echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS"
+
+ # NOTE: echo "q" is needed because sbt on encountering a build file with failure
+ #+ (either resolution or compilation) prompts the user for input either q, r, etc
+ #+ to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
+ #+ single argument!
+ # QUESTION: Why doesn't 'yes "q"' work?
+ # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
+ echo -e "q\n" \
+ | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \
+ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+}
echo ""
echo "========================================================================="
echo "Running Spark unit tests"
echo "========================================================================="
-if [ -n "$_RUN_SQL_TESTS" ]; then
- SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver"
-fi
-# echo "q" is needed because sbt on encountering a build file with failure
-# (either resolution or compilation) prompts the user for input either q, r,
-# etc to quit or retry. This echo is there to make it not block.
-echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \
- grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+{
+ # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
+  # -Phive is simply appended to the single profiles string; that string is left
+  #+ unquoted below so it expands into separate sbt arguments.
+ if [ -n "$_RUN_SQL_TESTS" ]; then
+ SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+ fi
+
+ if [ -n "$_SQL_TESTS_ONLY" ]; then
+ # This must be an array of individual arguments. Otherwise, having one long string
+ #+ will be interpreted as a single test, which doesn't work.
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test")
+ else
+ SBT_MAVEN_TEST_ARGS=("test")
+ fi
+
+ echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}"
+
+ # NOTE: echo "q" is needed because sbt on encountering a build file with failure
+ #+ (either resolution or compilation) prompts the user for input either q, r, etc
+ #+ to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a
+ #+ single argument!
+ #+ "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
+ # QUESTION: Why doesn't 'yes "q"' work?
+ # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
+ echo -e "q\n" \
+ | sbt/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \
+ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+}
echo ""
echo "========================================================================="
@@ -108,4 +181,4 @@ echo ""
echo "========================================================================="
echo "Detecting binary incompatibilites with MiMa"
echo "========================================================================="
-dev/mima
+./dev/mima
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 31506e28e05af..06c3781eb3ccf 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -33,9 +33,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
-# NOTE: Jenkins will kill the whole build after 120 minutes.
-# Tests are a large part of that, but not all of it.
-TESTS_TIMEOUT="120m"
+TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout
function post_message () {
local message=$1
@@ -93,9 +91,14 @@ function post_message () {
else
merge_note=" * This patch merges cleanly."
- non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ")
+ source_files=$(
+ git diff master --name-only \
+ | grep -v -e "\/test" `# ignore files in test directories` \
+ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
+ | tr "\n" " "
+ )
new_public_classes=$(
- git diff master ${non_test_files} `# diff this patch against master and...` \
+ git diff master ${source_files} `# diff this patch against master and...` \
| grep "^\+" `# filter in only added lines` \
| sed -r -e "s/^\+//g" `# remove the leading +` \
| grep -e "trait " -e "class " `# filter in lines with these key words` \
@@ -138,7 +141,8 @@ function post_message () {
test_result="$?"
if [ "$test_result" -eq "124" ]; then
- fail_message="**Tests timed out** after a configured wait of \`${TESTS_TIMEOUT}\`."
+ fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** after \
+ a configured wait of \`${TESTS_TIMEOUT}\`."
post_message "$fail_message"
exit $test_result
else
diff --git a/dev/scalastyle b/dev/scalastyle
index b53053a04ff42..efb5f291ea3b7 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,9 +17,9 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
# Check style with YARN alpha built too
-echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
+echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
# Check style with YARN built too
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \
diff --git a/docs/README.md b/docs/README.md
index fd7ba4e0d72ea..79708c3df9106 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -20,17 +20,22 @@ In this directory you will find textfiles formatted using Markdown, with an ".md
read those text files directly if you want. Start with index.md.
The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com).
-To use the `jekyll` command, you will need to have Jekyll installed.
-The easiest way to do this is via a Ruby Gem, see the
-[jekyll installation instructions](http://jekyllrb.com/docs/installation).
-If not already installed, you need to install `kramdown` with `sudo gem install kramdown`.
+`Jekyll` and a few dependencies must be installed for this to work. We recommend
+installing via the Ruby Gem dependency manager. Since the exact HTML output
+varies between versions of Jekyll and its dependencies, we list specific versions here
+in some cases:
+
+ $ sudo gem install jekyll -v 1.4.3
+ $ sudo gem uninstall kramdown -v 1.4.1
+ $ sudo gem install jekyll-redirect-from
+
Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory
called `_site` containing index.html as well as the rest of the compiled files.
You can modify the default Jekyll build as follows:
# Skip generating API docs (which takes a while)
- $ SKIP_SCALADOC=1 jekyll build
+ $ SKIP_API=1 jekyll build
# Serve content locally on port 4000
$ jekyll serve --watch
# Build the site with extra features used on the live page
diff --git a/docs/_config.yml b/docs/_config.yml
index 45b78fe724a50..7bc3a78e2d265 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -1,5 +1,12 @@
-pygments: true
+highlighter: pygments
markdown: kramdown
+gems:
+ - jekyll-redirect-from
+
+# For some reason kramdown seems to behave differently on different
+# OS/packages wrt encoding. So we hard code this config.
+kramdown:
+ entity_output: numeric
# These allow the documentation to be updated with new releases
# of Spark, Scala, and Mesos.
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index b30ab1e5218c0..627ed37de4a9c 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -109,8 +109,9 @@
Hardware Provisioning
3rd-Party Hadoop Distros
- Building Spark with Maven
+ Building Spark
Contributing to Spark
+ Supplemental Projects
@@ -151,7 +152,7 @@ {{ page.title }}
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
- displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
+ displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
processEscapes: true,
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
}
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 2dbbbf6feb4b8..3b02e090aec28 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -25,8 +25,8 @@
curr_dir = pwd
cd("..")
- puts "Running 'sbt/sbt compile unidoc' from " + pwd + "; this may take a few minutes..."
- puts `sbt/sbt compile unidoc`
+ puts "Running 'sbt/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..."
+ puts `sbt/sbt -Pkinesis-asl compile unidoc`
puts "Moving back into docs dir."
cd("docs")
diff --git a/docs/building-with-maven.md b/docs/building-spark.md
similarity index 84%
rename from docs/building-with-maven.md
rename to docs/building-spark.md
index 4d87ab92cec5b..2378092d4a1a8 100644
--- a/docs/building-with-maven.md
+++ b/docs/building-spark.md
@@ -1,6 +1,7 @@
---
layout: global
-title: Building Spark with Maven
+title: Building Spark
+redirect_from: "building-with-maven.html"
---
* This will become a table of contents (this text will be scraped).
@@ -96,13 +97,12 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package
mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package
{% endhighlight %}
-# Building Thrift JDBC server and CLI for Spark SQL
-
-Spark SQL supports Thrift JDBC server and CLI.
-See sql-programming-guide.md for more information about those features.
-You can use those features by setting `-Phive-thriftserver` when building Spark as follows.
+# Building With Hive and JDBC Support
+To enable Hive integration for Spark SQL along with its JDBC server and CLI,
+add the `-Phive` profile to your existing build options.
{% highlight bash %}
-mvn -Phive-thriftserver assembly
+# Apache Hadoop 2.4.X with Hive support
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
{% endhighlight %}
# Spark Tests in Maven
@@ -160,4 +160,21 @@ then ship it over to the cluster. We are investigating the exact cause for this.
The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself.
+# Building with SBT
+
+Maven is the official recommendation for packaging Spark, and is the "build of reference".
+But SBT is supported for day-to-day development since it can provide much faster iterative
+compilation. More advanced developers may wish to use SBT.
+
+The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables
+can be set to control the SBT build. For example:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 compile
+
+# Speeding up Compilation with Zinc
+[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental
+compiler. When run locally as a background process, it speeds up builds of Scala-based projects
+like Spark. Developers who regularly recompile Spark with Maven will be the most interested in
+Zinc. The project site gives instructions for building and running `zinc`; OS X users can
+install it using `brew install zinc`.
\ No newline at end of file
diff --git a/docs/configuration.md b/docs/configuration.md
index 981170d8b49b7..a6dd7245e1552 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
|